 from pathlib import Path
+import matplotlib.pyplot as plt
 import json
 import triton.profiler as proton
 import torch
 from triton_bench.numerics import InFlexData
 from triton_bench.routing import routing
 from triton_bench.target_info import is_hip, get_cdna_version
+from dataclasses import dataclass

 if torch.cuda.is_available() and not is_hip():
     from triton._C.libtriton import nvidia
@@ -66,9 +68,38 @@ def quantize(w, dtype, dev, **opt):
                          actual_weight_scale_shape=weight_scale_shape)


-def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
-              # tensor / expert parallelism
-              TP=1, EP=1, name=""):
+@dataclass
+class PerfData:
+    time: float
+    flops: float
+    bytes: float
+
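+    # `time` is in ns (summed from proton's "time (ns)" metric), so
+    # FLOP/ns (= GFLOP/s) * 1e-3 gives TFLOP/s and B/ns * 1e-3 gives TB/s
+    # in the properties below.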
+    @property
+    def tflops(self):
+        return self.flops / self.time * 1e-3
+
+    @property
+    def tbps(self):
+        return self.bytes / self.time * 1e-3
+
+    @property
+    def opint(self):
+        # operational intensity (FLOP/byte)
+        assert self.bytes > 0
+        return self.flops / self.bytes
+
+    @property
+    def util(self) -> float:
+        if SPECS is None:
+            return 0.0
+
+        peak_flops = max(SPECS["MAX_TFLOPS8"], SPECS.get("MAX_TFLOPS16", 0))
+        min_t_flop = self.flops / peak_flops * 1e-3  # minimum time (ns) at peak FLOPS
+        min_t_bw = self.bytes / SPECS["MAX_TBPS"] * 1e-3  # minimum time (ns) at peak bandwidth
+        # ratio of roofline-optimal time to measured time; 1.0 means on the roofline
+        return max(min_t_flop, min_t_bw) / self.time
+
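+# Example with hypothetical numbers: PerfData(time=2e6, flops=1e12, bytes=1e9),
+# i.e. 1e12 FLOP and 1 GB moved in 2 ms, gives tflops == 500.0, tbps == 0.5,
+# opint == 1000.0 FLOP/byte.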
+
+def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name):
     assert n_expts_tot % EP == 0
     assert dim2 % TP == 0
     dev = "cuda"
@@ -96,7 +127,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     pc2 = PrecisionConfig(mx_ctx=w2_mx, flex_ctx=FlexCtx(rhs_data=w2_flex))

     # -- benchmark --
-    fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
+    fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/profiles/batch-{batch}.hatchet")
     fpath.parent.mkdir(parents=True, exist_ok=True)
     x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
     # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
@@ -115,7 +146,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     else:
         rdata, gather_indx, scatter_indx = None, None, None
     x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
-    x = triton_bench.swiglu.swiglu(x, 1.0, pcs)
+    x = triton_bench.swiglu.swiglu(x, 1.0, pcs, routing_data=rdata)
     x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
     proton.finalize()

@@ -127,42 +158,70 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     matmuls = [
         x for x in data[0]["children"] if "_matmul" in x["frame"]["name"] and "metadata" not in x["frame"]["name"]
     ]
-    tot_bytes = sum([x["metrics"]["bytes"] for x in matmuls])
-    tot_flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
+    bytes = sum([x["metrics"]["bytes"] for x in matmuls])
+    flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
+    flops = sum([flops[w] for w in [8, 16]])  # collapse per-width (fp8/fp16) FLOP counts into one total
     # compute total time (incl. "not useful" work)
     # TODO: proton should really be recording that in the json instead of
     # relying on the user to aggregate
-    tot_time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
-    min_time_flops = min_time_bytes = 0
-    if SPECS is not None:
-        min_time_flops = sum([tot_flops[w] / SPECS[f"MAX_TFLOPS{w}"] for w in [8, 16]]) * 1e-3
-        min_time_bytes = tot_bytes / SPECS["MAX_TBPS"] * 1e-3
-        min_time = max(min_time_flops, min_time_bytes)
-        util = min_time / tot_time
-    else:
-        util = 0.0
-    tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
-    tbps = tot_bytes / tot_time * 1e-3
-    print(f"Utilization: {util:.0%}; {tflops:>6.1f} TFLOPs, {tbps:.1f} TB/s")
-
-    return util, tflops, tbps
+    time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
+    return PerfData(time, flops, bytes)
+
+
+def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
+                 verbose=True):
+    import numpy as np
+    from itertools import chain
+    from bisect import bisect_left
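+    # flatten the (start, stop, step) triples into one ascending list of batch
+    # sizes, e.g. [(1024, 32768, 1024)] -> [1024, 2048, ..., 31744]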
+    batches = list(chain(*[range(*r) for r in batch_ranges]))
+    # collect performance data
+    perfs = []
+    print(f"Benchmarking {name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})...")
+    print("===============================================================")
+    for batch in batches:
+        perfs += [bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name)]
+        if verbose:
+            print(f"Batch: {batch}; Util: {perfs[-1].util:.0%}; TFLOPS: {perfs[-1].tflops:>6.1f}; TBPS: {perfs[-1].tbps:.1f}")
+    print("===============================================================")
+    fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
+    ax.set_xlabel("batch size (toks/expt)")
+    ax.set_ylabel("performance [TFLOP/s]")
+    ax.set_title("roofline")
+    # add a tiny margin so points are not flush with the frame
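+    # x-axis: average tokens routed to each expert, batch * n_expts_act / n_expts_tot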
+    xs = [batch * n_expts_act / n_expts_tot for batch in batches]
+    perf = [p.tflops for p in perfs]
+    xmin, xmax = min(xs), max(xs)
+    dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
+    ax.set_xlim(xmin - dx, xmax + dx)
+    ax.set_ylim(100, SPECS["MAX_TFLOPS8"] + 500)  # assumes SPECS is known for this GPU
+    # machine limits
+    max_tbps = SPECS["MAX_TBPS"]
+    max_tflops = SPECS["MAX_TFLOPS8"]
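+    # plot roofline: the knee sits where opint == max_tflops / max_tbps; to its
+    # left attainable TFLOP/s is bandwidth-bound (opint * max_tbps), to its right
+    # compute-bound (max_tflops). opints grows with batch size, so bisect_left
+    # finds the first compute-bound point.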
+    opints = [p.opint for p in perfs]
+    knee = bisect_left(opints, max_tflops / max_tbps) - 1
+    x_bw, x_comp = xs[:knee], xs[knee:]
+    y_bw = [op * max_tbps for op in opints[:knee]]
+    y_comp = [max_tflops] * len(x_comp)
+    ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.0f} TB/s)")
+    ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)")
+    # plot data
+    ax.scatter(xs, perf, marker="+")
+    ax.legend(frameon=False, loc="lower right")
+    ax.grid(True, which="both", ls=":", lw=0.5)
+    fig.tight_layout()
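+    # write the figure next to the per-batch profiles produced by bench_mlp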
+    fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/roofline.png")
+    plt.savefig(fpath)


 if __name__ == "__main__":
     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
     if SPECS is None:
         print("Current GPU has no specs provided, utilization is N/A")
-    if has_native_mx4:
-        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
-        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "mx4", TP=1, EP=1, name="dense")
-        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
-        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "mx4", TP=4, EP=1, name="llama4")
-    else:
-        # bf16/fp16 x fp8 is skipped because matmul_ogs requires x and w to have
-        # the same type when not doing an mxfp operation
-        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
-        bench_mlp(8192, 8192, 8192, 1, 1, "fp16", "mx4", TP=1, EP=1, name="dense")
-        bench_mlp(8192, 8192, 8192, 1, 1, "bf16", "mx4", TP=1, EP=1, name="dense")
-        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
-        bench_mlp(2048, 5120, 8192, 128, 4, "bf16", "mx4", TP=4, EP=1, name="llama4")
-        bench_mlp(2048, 5120, 8192, 128, 4, "fp16", "mx4", TP=4, EP=1, name="llama4")
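+    # sweep batch sizes 1024, 2048, ..., 31744 for each config; mx4 weights pair
+    # with fp8 activations only where native mx4 is supported (compute capability
+    # >= 10 or CDNA4), otherwise bf16 activations are used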
+    batch_ranges = [(1024, 32768, 1024)]
+    dense_dtypes = ["fp8", "fp8"]
+    quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
+    roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
+    roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
+    roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
+    roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")