
Commit 52e6007

xiaowangintel authored and mansiag05 committed
[WOQ] Add XPU kernel for _weight_int8pack_mm (pytorch#160938)
Summary: This change adds an XPU kernel for aten._weight_int8pack_mm, a weight-only quantized (WOQ) linear operation that was previously supported only on CPU and CUDA.

Motivation: Same as pytorch#159325.

Pull Request resolved: pytorch#160938

Approved by: https://github.com/EikanWang, https://github.com/ZhiweiYan-96, https://github.com/liangan1, https://github.com/jerryzh168
1 parent 83aa65c commit 52e6007

File tree: 5 files changed (+109, -2 lines)


aten/src/ATen/native/mkldnn/xpu/Blas.cpp

Lines changed: 56 additions & 0 deletions
@@ -559,4 +559,60 @@ Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
       at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
   return _int_mm_out_xpu(self, mat2, result);
 }
+
+Tensor _weight_int8pack_mm_xpu(
+    const Tensor& A,
+    const Tensor& B,
+    const Tensor& scales) {
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto K = A.size(1);
+
+  TORCH_CHECK(
+      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
+      " : expect A to be either 32-bit or 16-bit float tensor.");
+  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");
+  TORCH_CHECK(
+      A.stride(1) == 1, " : A must be contiguous on the last dimension.");
+  TORCH_CHECK(B.dtype() == kChar, " : expect B to be int8 tensor.");
+  TORCH_CHECK(B.is_contiguous(), " : expect B to be contiguous.");
+  TORCH_CHECK(B.size(1) == K, " : expect B.size(1) == ", K);
+
+  TORCH_CHECK(
+      scales.dim() == 1 && scales.size(0) == N,
+      " : expect scales to be 1d tensor with size ",
+      N);
+
+  auto C = at::empty({M, N}, A.options());
+
+  // --- Launch kernel ---
+  Tensor bias = at::Tensor();
+  Tensor mat2_zero_points = at::Tensor();
+  Tensor non_const_scales = scales;
+  auto post_op_args = torch::List<std::optional<at::Scalar>>();
+
+  at::native::onednn::quantized_matmul(
+      A.contiguous(),
+      1.0,
+      0,
+      B,
+      non_const_scales,
+      mat2_zero_points,
+      bias,
+      C,
+      1.0,
+      0,
+      C.scalar_type(),
+      /*other*/ std::nullopt,
+      /*other scale*/ 1.0,
+      /*other zp*/ 0,
+      /*binary post op*/ "none",
+      /*binary alpha*/ 1.0,
+      /*post_op_name*/ "none",
+      post_op_args,
+      /*post_op_algorithm*/ "none",
+      /*m2_trans*/ false);
+
+  return C;
+}
 } // namespace at::native
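For reference, a minimal usage sketch (not part of the commit) of how this kernel is reached from Python: torch._weight_int8pack_mm takes a bf16/fp16/fp32 activation A of shape [M, K], an int8 weight B of shape [N, K], and per-output-channel scales of shape [N]. The quantization step below is illustrative only; it assumes an XPU-enabled build with an available device and uses plain symmetric per-channel quantization.

import torch

# Sketch only: end-to-end WOQ matmul on XPU (assumes torch.xpu is available).
device = "xpu"
M, K, N = 32, 64, 48
a = torch.rand(M, K, dtype=torch.bfloat16, device=device)  # activation, [M, K]
w = torch.rand(N, K, dtype=torch.bfloat16, device=device)  # float reference weight, [N, K]

# Symmetric per-channel int8 quantization: one scale per output channel (row of w).
scales = w.abs().amax(dim=1) / 127.0                 # shape [N]
w_int8 = torch.round(w / scales.unsqueeze(1)).to(torch.int8)

c = torch._weight_int8pack_mm(a, w_int8, scales)     # [M, N], same dtype as a
ref = a @ w.t()
print(((c - ref).abs() / ref.abs()).mean())          # small mean relative error expected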

aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp

Lines changed: 3 additions & 2 deletions
@@ -110,8 +110,9 @@ void quantized_matmul(
   // [Note] Quantized Matrix Multiplication at XPU
   // The following code integrates oneDNN quantized gemm. The quantization
   // config we support:
-  // activation: s8&u8; per tensor calibrated; symmetric&asymmetric
-  // weight: s8; per_tensor/per_channel calibrated; symmetric
+  // activation: s8, u8, fp16, bf16, fp32; per tensor calibrated;
+  // symmetric&asymmetric weight: s8; per_tensor/per_channel calibrated;
+  // symmetric
   auto attr = Attr(static_cast<float>(1.0 / output_scale), output_zero_point);
   construct_attr_by_post_op(
       binary_post_op,
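The updated note means activations may either stay in fp16/bf16/fp32 (the new WOQ path) or be quantized per tensor, symmetric or asymmetric, while weights are s8 and calibrated per tensor or per channel, symmetric only. As an illustration only (not code from this commit), a per-tensor asymmetric u8 activation quantization could look like the sketch below; all names are local to the sketch.

import torch

x = torch.randn(32, 64)                      # fp32 activation
qmin, qmax = 0, 255                          # u8 range
scale = (x.max() - x.min()) / (qmax - qmin)  # single per-tensor scale
zero_point = int((qmin - x.min() / scale).round())
x_u8 = (torch.round(x / scale) + zero_point).clamp(qmin, qmax).to(torch.uint8)

x_dq = (x_u8.float() - zero_point) * scale   # dequantize to check the round trip
print((x - x_dq).abs().max())                # error bounded by roughly scale / 2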

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -4243,6 +4243,7 @@
     CPU: _weight_int8pack_mm_cpu
     CUDA: _weight_int8pack_mm_cuda
     MPS: _weight_int8pack_mm_mps
+    XPU: _weight_int8pack_mm_xpu
 
 - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
   python_module: sparse
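With this dispatch entry, the same aten op routes to whichever backend kernel matches the input device; nothing changes on the Python side. A hedged sketch (assumes an XPU-enabled build; the helper name is local to the sketch):

import torch

def woq_linear(a, w_int8, scales):
    # Dispatches to _weight_int8pack_mm_cpu / _cuda / _mps / (now) _xpu
    # based on the device of the inputs.
    return torch.ops.aten._weight_int8pack_mm(a, w_int8, scales)

a = torch.rand(8, 64, dtype=torch.bfloat16)
scales = torch.rand(16, dtype=torch.bfloat16) + 0.5
w_int8 = torch.randint(-128, 128, (16, 64), dtype=torch.int8)

cpu_out = woq_linear(a, w_int8, scales)  # CPU kernel
if torch.xpu.is_available():
    xpu_out = woq_linear(a.to("xpu"), w_int8.to("xpu"), scales.to("xpu"))  # new XPU kernel
    print((cpu_out - xpu_out.cpu()).abs().max())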

test/xpu/test_gemm.py

Lines changed: 48 additions & 0 deletions
@@ -19,8 +19,12 @@
 from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,
+    onlyNativeDeviceTypes,
     precisionOverride,
 )
+from torch.testing._internal.common_quantization import (
+    _dynamically_quantize_per_channel,
+)
 from torch.testing._internal.common_utils import (
     iter_indices,
     parametrize,
@@ -1446,6 +1450,50 @@ def forward(self, x_1, w_1):
 return out_dtype""",
         )
 
+    @onlyNativeDeviceTypes
+    @parametrize("m", [32, 64])
+    @parametrize("k", [32, 64])
+    @parametrize("n", [48, 64])
+    @parametrize("compile", [True, False])
+    @parametrize("slice", [True, False])
+    def test__int8_mm(self, device, m, k, n, compile, slice):
+        torch.manual_seed(1)
+        if slice:
+            # logits are generated from LLaMA LM head like this -
+            # the activation to LM head is a slice of final hidden state
+            # of shape (batch_size, sequence_length, hidden dim),
+            # but is non-contiguous
+            # Using arbitrary batch-size here, since it'd be converted to 2D
+            batch_size = 4
+            a = torch.rand((batch_size, m, k), dtype=torch.bfloat16, device=device)
+            # Make a non-contiguous
+            a = a[:, -1:, :]
+            a = a.view(-1, a.size(-1))
+        else:
+            a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
+
+        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
+
+        def convert_weight_to_int8pack(b):
+            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
+                b, -128, 127, torch.int8
+            )
+            return b_int8pack, b_scales
+
+        def weight_int8pack_mm(a, b_int8pack, b_scales):
+            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)
+
+        b_int8pack, b_scales = convert_weight_to_int8pack(b)
+        if compile:
+            mod = torch.compile(weight_int8pack_mm)
+        else:
+            mod = weight_int8pack_mm
+        res = mod(a, b_int8pack, b_scales)
+        ref = torch.mm(a, b.transpose(0, 1))
+
+        mean_err = ((res - ref).abs() / ref).mean()
+        self.assertTrue(mean_err < 0.05)
+
 
 instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)

torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attent
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
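This C-shim declaration is what AOTInductor-generated XPU code calls. A rough sketch only (the module, names, and packaging step are assumptions, not part of this commit; assumes an XPU-enabled build where torch._inductor.aoti_compile_and_package is available) of one way to exercise that path from Python:

import torch

class WOQLinear(torch.nn.Module):
    def __init__(self, w_int8, scales):
        super().__init__()
        self.register_buffer("w_int8", w_int8)
        self.register_buffer("scales", scales)

    def forward(self, a):
        # Lowers to aten._weight_int8pack_mm; for XPU, the generated AOTI code
        # is expected to go through aoti_torch_xpu__weight_int8pack_mm.
        return torch._weight_int8pack_mm(a, self.w_int8, self.scales)

device = "xpu"
w = torch.rand(48, 64, dtype=torch.bfloat16, device=device)
scales = w.abs().amax(dim=1) / 127.0
w_int8 = torch.round(w / scales.unsqueeze(1)).to(torch.int8)
mod = WOQLinear(w_int8, scales)

a = torch.rand(8, 64, dtype=torch.bfloat16, device=device)
ep = torch.export.export(mod, (a,))
package_path = torch._inductor.aoti_compile_and_package(ep)
print(package_path)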
