File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
intel_extension_for_pytorch/nn/utils Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -462,7 +462,11 @@ def forward(self, input: Tensor) -> Tensor:
462462
463463 if xpu_gemm_use_xetla (self .force_xetla ):
464464 # TODO input.shape[1] > 1 seems not work on gidx scenario, need to fix this bug
465- if input .shape [1 ] > 1 and not self .force_xetla :
465+ if input .dim () == 3 :
466+ m = input .size (1 )
467+ else :
468+ m = input .size (0 )
469+ if m > 1 :
466470 return dequant_gemm_block (input , self )
467471 return torch .ops .torch_ipex .mm_low_bits (
468472 input ,
@@ -578,7 +582,7 @@ def convert_qmodel_recursive(module):
578582
579583def dequant_gemm_block (input , quant_layer , output = None ):
580584 if quant_layer .g_idx is not None :
581- input = input [:, : , quant_layer .g_idx ]
585+ input = input [... , quant_layer .g_idx ]
582586 if output is None :
583587 output = torch .ops .torch_ipex .mm_common (
584588 input ,
You can’t perform that action at this time.
0 commit comments