 cublas = None


-def quantize(w, dtype, dev, **opt):
+def quantize(w, dtype, **opt):
     if dtype == "bf16":
         wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
         return wq, InFlexData(), None
@@ -36,8 +36,9 @@ def quantize(w, dtype, dev, **opt):
     else:
         assert dtype == "mx4", f"{dtype=}"
         w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
-        w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
-        w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
+        if opt:
+            w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
+            w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
     return w, InFlexData(), w_scale


@@ -109,9 +110,9 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     opt1 = {"value_layout": value_layout, "value_layout_opts": value_layout_opts, \
             "scale_layout": scale_layout, "scale_layout_opts": scale_layout_opts}
     opt2 = deepcopy(opt1)
-    wg, wg_flex, wg_scale = quantize(wg, "bf16", dev, **optg)
-    w1, w1_flex, w1_scale = quantize(w1, w_dtype, dev, **opt1)
-    w2, w2_flex, w2_scale = quantize(w2, w_dtype, dev, **opt2)
+    wg, wg_flex, wg_scale = quantize(wg, "bf16", **optg)
+    w1, w1_flex, w1_scale = quantize(w1, w_dtype, **opt1)
+    w2, w2_flex, w2_scale = quantize(w2, w_dtype, **opt2)
     pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale)
     act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.0, 1.0), 2)
     pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale)
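For reference, the new guard relies on the truthiness of the `**opt` keyword dict: calling `quantize(w, "mx4")` with no layout options leaves `opt == {}`, which is falsy, so the `convert_layout` step is skipped rather than raising a `KeyError`. Below is a minimal, self-contained sketch of that gating pattern; the dtype strings and layout values are illustrative placeholders, not the `triton_kernels` API or the benchmark's actual layouts.

```python
# Sketch of the empty-**opt gate introduced above (placeholder values,
# not the real quantize/convert_layout implementation).
def quantize(w, dtype, **opt):
    if opt:  # layout options supplied -> apply layout conversion
        return f"{dtype}: converted with {opt['value_layout']}"
    return f"{dtype}: no layout conversion"

optg = {}                                   # no layout opts: **optg expands to nothing
opt1 = {"value_layout": "SOME_LAYOUT",      # hypothetical layout name
        "value_layout_opts": {}}

print(quantize("wg", "bf16", **optg))   # bf16: no layout conversion
print(quantize("w1", "mx4", **opt1))    # mx4: converted with SOME_LAYOUT
```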