@@ -19,8 +19,12 @@
 from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,
+    onlyNativeDeviceTypes,
     precisionOverride,
 )
+from torch.testing._internal.common_quantization import (
+    _dynamically_quantize_per_channel,
+)
 from torch.testing._internal.common_utils import (
     iter_indices,
     parametrize,
@@ -1446,6 +1450,50 @@ def forward(self, x_1, w_1):
     return out_dtype""",
         )
 
+    @onlyNativeDeviceTypes
+    @parametrize("m", [32, 64])
+    @parametrize("k", [32, 64])
+    @parametrize("n", [48, 64])
+    @parametrize("compile", [True, False])
+    @parametrize("slice", [True, False])
+    def test__int8_mm(self, device, m, k, n, compile, slice):
+        torch.manual_seed(1)
+        if slice:
+            # Logits from a LLaMA-style LM head are produced like this:
+            # the activation fed to the LM head is a slice of the final hidden
+            # state of shape (batch_size, sequence_length, hidden_dim),
+            # and is non-contiguous.
+            # The batch size is arbitrary here, since the activation is flattened to 2D anyway.
+            batch_size = 4
+            a = torch.rand((batch_size, m, k), dtype=torch.bfloat16, device=device)
+            # Make `a` non-contiguous
+            a = a[:, -1:, :]
+            a = a.view(-1, a.size(-1))
+        else:
+            a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
+
+        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
+
+        def convert_weight_to_int8pack(b):
+            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
+                b, -128, 127, torch.int8
+            )
+            return b_int8pack, b_scales
+
+        def weight_int8pack_mm(a, b_int8pack, b_scales):
+            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)
+
+        b_int8pack, b_scales = convert_weight_to_int8pack(b)
+        if compile:
+            mod = torch.compile(weight_int8pack_mm)
+        else:
+            mod = weight_int8pack_mm
+        res = mod(a, b_int8pack, b_scales)
+        ref = torch.mm(a, b.transpose(0, 1))
+
+        mean_err = ((res - ref).abs() / ref).mean()
+        self.assertTrue(mean_err < 0.05)
+
 
 instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)
 
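For context, the new test checks that the fused torch._weight_int8pack_mm(a, b_int8pack, b_scales) call stays within a 5% mean relative error of a plain bfloat16 matmul against the unquantized weight. The sketch below shows the same ops run standalone on CPU, assuming the helper performs symmetric per-output-channel quantization (the discarded third return value being the zero points); the shapes and the manual dequantization step are my own illustration, not code from this diff.

import torch
from torch.testing._internal.common_quantization import (
    _dynamically_quantize_per_channel,
)

# Arbitrary shapes for illustration: activation (m, k), weight (n, k).
m, k, n = 32, 64, 48
a = torch.rand((m, k), dtype=torch.bfloat16)
b = torch.rand((n, k), dtype=torch.bfloat16)

# Per-channel int8 quantization of the weight, as in the test above.
b_int8, b_scales, _ = _dynamically_quantize_per_channel(b, -128, 127, torch.int8)

# Fused weight-only int8 matmul: (m, k) x (n, k)^T -> (m, n).
res = torch._weight_int8pack_mm(a, b_int8, b_scales)

# Manual reference: dequantize with one scale per output channel, then matmul in bf16.
b_dequant = (b_int8.to(torch.bfloat16) * b_scales.view(-1, 1)).to(torch.bfloat16)
ref = torch.mm(a, b_dequant.t())

print(((res - ref).abs() / ref.abs()).mean())  # should sit well under the test's 0.05 bound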
|
|