@@ -234,7 +234,9 @@ def mm_dequant_impl(
234234 out_shape = (out_shape [0 ] * out_shape [1 ], out_shape [2 ])
235235
236236 if compute_dtype not in [torch .float32 , torch .bfloat16 ]:
237- warnings .warn (f"mm_dequant_{ A .device } : compute_dtype { compute_dtype } is not supported, will use bfloat16 instead" )
237+ warnings .warn (
238+ f"mm_dequant_{ A .device } : compute_dtype { compute_dtype } is not supported, will use bfloat16 instead"
239+ )
238240 compute_dtype = torch .bfloat16
239241 A_reshaped = A .reshape (out_shape ).to (compute_dtype )
240242 row_stats = row_stats .reshape (- 1 ).unsqueeze (- 1 ).to (compute_dtype )
@@ -439,9 +441,7 @@ def dequantize_4bit_impl(
439441 raise NotImplementedError ("bnb_4bit_use_double_quant is not supported yet for CPU/XPU" )
440442
441443 if ipex_cpu_only and _ipex_cpu_version_prereq (2 , 5 ) and getattr (quant_state , "ipex" , False ):
442- A = torch .ops .ipex_prepack .woq_linear_unpack_weight (
443- A , "nf4" , quant_state .shape , 2
444- )
444+ A = torch .ops .ipex_prepack .woq_linear_unpack_weight (A , "nf4" , quant_state .shape , 2 )
445445 quant_state .ipex = False
446446
447447 # Map nf4 to [-1, 1]
@@ -466,9 +466,9 @@ def dequantize_4bit_impl(
466466 if out is None :
467467 out = torch .empty (quant_state .shape , dtype = quant_state .dtype , device = A .device )
468468 out_reshaped = out .reshape (- 1 )
469- out_reshaped [: n - rem ] = (out_dq [: n - rem ]. view ( - 1 , blocksize ) * absmax [: blocks - has_rem ]. view ( - 1 , 1 )). reshape (
470- - 1
471- )
469+ out_reshaped [: n - rem ] = (
470+ out_dq [: n - rem ]. view ( - 1 , blocksize ) * absmax [: blocks - has_rem ]. view ( - 1 , 1 )
471+ ). reshape ( - 1 )
472472 out_reshaped [n - rem :] = out_dq [n - rem :] * absmax [- 1 ]
473473 else :
474474 out = (out_dq .view (- 1 , blocksize ) * absmax .view (- 1 , 1 )).reshape (quant_state .shape ).to (quant_state .dtype )
@@ -513,9 +513,20 @@ def gemm_4bit_impl(
513513 GEMM output tensor.
514514 """
515515 if getattr (state , "ipex" , False ):
516- output = torch .ops .torch_ipex .woq_linear (A , B , "nf4" , state .shape ,
517- state .new_scales , state .new_zeros , None , None , state .blocksize ,
518- ipex_cpu .quantization .WoqLowpMode .BF16 , 1 , state .compensation )
516+ output = torch .ops .torch_ipex .woq_linear (
517+ A ,
518+ B ,
519+ "nf4" ,
520+ state .shape ,
521+ state .new_scales ,
522+ state .new_zeros ,
523+ None ,
524+ None ,
525+ state .blocksize ,
526+ ipex_cpu .quantization .WoqLowpMode .BF16 ,
527+ 1 ,
528+ state .compensation ,
529+ )
519530 else :
520531 dqB = dequantize_4bit_impl (B , state , blocksize = state .blocksize ).t ()
521532 output = torch .matmul (A , dqB .to (A .dtype ))
0 commit comments