2424 cublas = None
2525
2626
27- def quantize (w , dtype , dev , ** opt ):
27+ def quantize (w , dtype , ** opt ):
2828 if dtype == "bf16" :
2929 wq = w .to (torch .bfloat16 ).transpose (- 1 , - 2 ).contiguous ().transpose (- 1 , - 2 )
3030 return wq , InFlexData (), None
@@ -36,8 +36,9 @@ def quantize(w, dtype, dev, **opt):
3636 else :
3737 assert dtype == "mx4" , f"{ dtype = } "
3838 w , w_scale = downcast_to_mxfp (w .to (torch .bfloat16 ), torch .uint8 , axis = 1 )
39- w = convert_layout (wrap_torch_tensor (w , dtype = FP4 ), opt ["value_layout" ], ** opt ["value_layout_opts" ])
40- w_scale = convert_layout (wrap_torch_tensor (w_scale ), opt ["scale_layout" ], ** opt ["scale_layout_opts" ])
39+ if opt :
40+ w = convert_layout (wrap_torch_tensor (w , dtype = FP4 ), opt ["value_layout" ], ** opt ["value_layout_opts" ])
41+ w_scale = convert_layout (wrap_torch_tensor (w_scale ), opt ["scale_layout" ], ** opt ["scale_layout_opts" ])
4142 return w , InFlexData (), w_scale
4243
4344
@@ -109,9 +110,9 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
109110 opt1 = {"value_layout" : value_layout , "value_layout_opts" : value_layout_opts , \
110111 "scale_layout" : scale_layout , "scale_layout_opts" : scale_layout_opts }
111112 opt2 = deepcopy (opt1 )
112- wg , wg_flex , wg_scale = quantize (wg , "bf16" , dev , ** optg )
113- w1 , w1_flex , w1_scale = quantize (w1 , w_dtype , dev , ** opt1 )
114- w2 , w2_flex , w2_scale = quantize (w2 , w_dtype , dev , ** opt2 )
113+ wg , wg_flex , wg_scale = quantize (wg , "bf16" , ** optg )
114+ w1 , w1_flex , w1_scale = quantize (w1 , w_dtype , ** opt1 )
115+ w2 , w2_flex , w2_scale = quantize (w2 , w_dtype , ** opt2 )
115116 pcg = PrecisionConfig (flex_ctx = FlexCtx (rhs_data = wg_flex ), weight_scale = wg_scale )
116117 act = FusedActivation (FnSpecs ("swiglu" , triton_kernels .swiglu .swiglu_fn , ("alpha" , "limit" )), (1.0 , 1.0 ), 2 )
117118 pc1 = PrecisionConfig (flex_ctx = FlexCtx (rhs_data = w1_flex ), weight_scale = w1_scale )
0 commit comments