@@ -35,23 +35,26 @@ def process_weights_after_loading(self, layer) -> None:
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
         if self.qscheme == "per_tensor":
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
-
             if current_platform.is_rocm():
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=max_w_scale,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
                     input_scale=layer.input_scale)
-                if input_scale is not None:
-                    layer.input_scale = Parameter(input_scale,
-                                                  requires_grad=False)
+            else:
+                max_w_scale = layer.weight_scale
+                weight = layer.weight
+                input_scale = layer.input_scale
+
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=weight,
+                weight_scale=max_w_scale,
+                logical_widths=layer.logical_widths,
+            )
 
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
 
         # If channelwise, scales are already lined up, so just transpose.
         elif self.qscheme == "per_channel":
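The diff reorders two steps: on ROCm, the e4m3fn-to-e4m3fnuz normalization now runs before requantize_with_max_scale, so the max-scale pass consumes the normalized weights and scales rather than the raw checkpoint values, and the input_scale Parameter assignment is hoisted out of the ROCm-only branch so both platforms take the same path. Below is a minimal sketch of that fixed ordering. The helper bodies are simplified stand-ins written for illustration, not the real implementations in vLLM's w8a8 utilities; fp32 tensors stand in for the actual float8_e4m3fn weights.

import torch

def requantize_with_max_scale(weight, weight_scale, logical_widths):
    # Collapse the N per-shard scales of a fused module (e.g. QKV) into
    # one per-tensor scale: dequantize each shard with its own scale,
    # then requantize everything with the max scale.
    max_w_scale = weight_scale.max()
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        dq = weight[start:end] * weight_scale[idx]   # dequantize shard
        weight[start:end] = dq / max_w_scale         # requantize with max
        start = end
    return max_w_scale, weight

def normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale, input_scale=None):
    # ROCm's e4m3fnuz format has half the dynamic range of e4m3fn, so
    # halve the weights and double the scales to represent the same values.
    new_input_scale = None if input_scale is None else input_scale * 2.0
    return weight * 0.5, weight_scale * 2.0, new_input_scale

# Fixed control flow from the diff: normalize for ROCm *first*, then
# requantize, so the single per-tensor scale already reflects the
# doubled fnuz scales.
is_rocm = True                       # stand-in for current_platform.is_rocm()
weight = torch.randn(6, 16)          # three fused shards of 2 rows each
weight_scale = torch.tensor([0.5, 1.0, 2.0])
logical_widths = [2, 2, 2]
input_scale = torch.tensor(0.1)

if is_rocm:
    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
        weight, weight_scale, input_scale)
max_w_scale, weight = requantize_with_max_scale(
    weight, weight_scale, logical_widths)

In this toy model the requantization lands on the already-normalized tensors, matching the new code path; under the old ordering the max scale was computed from the un-normalized e4m3fn scales before the fnuz conversion.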