File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -240,10 +240,16 @@ def enable_ipex_fusion(linear, x):
240240 )
241241 elif x .device .type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq (2 , 5 ):
242242 converted_weight = reverse_4bit_compress_format (linear .weight .data )
243- new_weight = converted_weight .reshape ([quant_state .shape [0 ], quant_state .shape [1 ] // 2 ])
244243 new_scales = quant_state .absmax .view (quant_state .shape [0 ], quant_state .shape [1 ] // quant_state .blocksize )
245244 new_zeros = None
246245 compensation = None
246+ new_weight = converted_weight .reshape ([quant_state .shape [0 ], quant_state .shape [1 ] // 2 ])
247+ # ipex 2.7 requires new_scales is a list of tensors
248+ if _ipex_xpu_version_prereq (2 , 7 ):
249+ new_scales = list (new_scales )
250+ # ipex 2.7 can dequant converted_weight directly.
251+ if linear .training or x .requires_grad == False :
252+ new_weight = converted_weight
247253 else :
248254 raise ValueError (
249255 "Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5"
You can’t perform that action at this time.
0 commit comments