@@ -98,7 +98,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     # -- benchmark --
     fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
     fpath.parent.mkdir(parents=True, exist_ok=True)
-    x_dtype = {"bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
+    x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
     # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
     if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
         x_dtype = torch.float8_e4m3fnuz
@@ -140,17 +140,29 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
         min_time = max(min_time_flops, min_time_bytes)
         util = min_time / tot_time
     else:
-        util = "N/A"
+        util = 0.0
     tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
     tbps = tot_bytes / tot_time * 1e-3
+    print(f"Utilization: {util:.0%}; {tflops:>6.1f} TFLOPs, {tbps:.1f} TB/s")

     return util, tflops, tbps


 if __name__ == "__main__":
     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
-    qxdtype = "fp8" if has_native_mx4 else "bf16"
-    print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
-    print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
-    print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4"))
-    print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=1, name="llama4"))
+    if SPECS is None:
+        print("Current GPU has no specs provided, utilization is N/A")
+    if has_native_mx4:
+        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
+        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "mx4", TP=1, EP=1, name="dense")
+        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
+        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "mx4", TP=4, EP=1, name="llama4")
+    else:
+        # bf16/fp16 x fp8 is skipped because matmul_ogs requires x and w to have
+        # the same type when not doing mxfp operations
+        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
+        bench_mlp(8192, 8192, 8192, 1, 1, "fp16", "mx4", TP=1, EP=1, name="dense")
+        bench_mlp(8192, 8192, 8192, 1, 1, "bf16", "mx4", TP=1, EP=1, name="dense")
+        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
+        bench_mlp(2048, 5120, 8192, 128, 4, "bf16", "mx4", TP=4, EP=1, name="llama4")
+        bench_mlp(2048, 5120, 8192, 128, 4, "fp16", "mx4", TP=4, EP=1, name="llama4")