
Commit 9d489a8

rogerxfeng8, ganyi1996ppo, zhuwei, and wincent8 authored
add dequant+gemm for compute bound int4 gemm scenario (#4958) (#4964)
Co-authored-by: Pleaplusone <[email protected]>
Co-authored-by: zhuwei <[email protected]>
Co-authored-by: wincent8 <[email protected]>
1 parent c1ad0e7 · commit 9d489a8

File tree

1 file changed: +6 -2 lines


intel_extension_for_pytorch/nn/utils/_quantize_convert.py

Lines changed: 6 additions & 2 deletions
@@ -462,7 +462,11 @@ def forward(self, input: Tensor) -> Tensor:

         if xpu_gemm_use_xetla(self.force_xetla):
             # TODO input.shape[1] > 1 seems not work on gidx scenario, need to fix this bug
-            if input.shape[1] > 1 and not self.force_xetla:
+            if input.dim() == 3:
+                m = input.size(1)
+            else:
+                m = input.size(0)
+            if m > 1:
                 return dequant_gemm_block(input, self)
             return torch.ops.torch_ipex.mm_low_bits(
                 input,
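
The new dispatch deserves a word. The removed check read `input.shape[1]`, which is the sequence length only for a 3D `[batch, seq, hidden]` activation; for a 2D `[tokens, hidden]` activation it is the hidden size, which appears to be why the patch derives the GEMM M dimension per rank instead. Below is a minimal sketch of the rank-aware rule the added lines encode; `gemm_m_dim` and the shapes are illustrative, not part of the patch:

```python
import torch

def gemm_m_dim(input: torch.Tensor) -> int:
    # Hypothetical helper mirroring the added dispatch logic: for a 3D
    # activation [batch, seq, hidden] the GEMM M dimension is the sequence
    # length; for a 2D activation [tokens, hidden] it is the token count.
    return input.size(1) if input.dim() == 3 else input.size(0)

# m > 1 marks the compute-bound (prefill-like) case routed to dequant+gemm;
# m == 1 is the memory-bound decode case kept on mm_low_bits.
assert gemm_m_dim(torch.empty(4, 128, 4096)) == 128  # 3D, many tokens
assert gemm_m_dim(torch.empty(7, 4096)) == 7         # 2D, many tokens
assert gemm_m_dim(torch.empty(1, 4096)) == 1         # 2D, single token
```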
@@ -578,7 +582,7 @@ def convert_qmodel_recursive(module):

 def dequant_gemm_block(input, quant_layer, output=None):
     if quant_layer.g_idx is not None:
-        input = input[:, :, quant_layer.g_idx]
+        input = input[..., quant_layer.g_idx]
     if output is None:
         output = torch.ops.torch_ipex.mm_common(
             input,
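
The second hunk is the companion fix: once 2D activations can reach `dequant_gemm_block`, the `g_idx` gather must not assume three dimensions. A small sketch of why `[..., quant_layer.g_idx]` is the rank-agnostic form; the 4-element `g_idx` permutation and shapes are made up for illustration:

```python
import torch

# An ellipsis index permutes the last dimension for any input rank,
# while [:, :, g_idx] hard-codes exactly three dimensions and raises
# IndexError for the 2D activations the new m-dispatch now admits.
g_idx = torch.tensor([2, 0, 1, 3])  # illustrative group-index permutation

x2d = torch.arange(8.0).reshape(2, 4)      # [tokens, hidden]
x3d = torch.arange(24.0).reshape(2, 3, 4)  # [batch, seq, hidden]

print(x2d[..., g_idx].shape)  # torch.Size([2, 4])
print(x3d[..., g_idx].shape)  # torch.Size([2, 3, 4])
# x2d[:, :, g_idx] would raise: too many indices for a 2-D tensor
```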
