Skip to content

Commit c30dadf

Browse files
committed
Add GPTQ W4A16 initial tests
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent f858ffa commit c30dadf

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

tests/aiu_addons/__init__.py

Whitespace-only changes.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
import torch
3+
4+
from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op
5+
6+
7+
# Parametrization cases for the GPTQ GEMM input fixture. Each entry fixes the
# tensor dimensions used to build one set of W4A16 matmul inputs.
input_sizes = [
    {
        "bs": 4,  # batch size
        "seq_len": 32,  # sequence length
        "hid_dim": 768,  # input (hidden) dimension
        "out_feat": 3072,  # output features
        "n_grp": 6,  # leading dimension of qzeros/scales (quantization groups)
    },
]
16+
17+
18+
@pytest.fixture(params=input_sizes)
def get_gptq_gemm_inputs(request):
    """Construct random GPTQ W4A16 input tensors for one size configuration.

    Parametrized over ``input_sizes``. Returns a tuple
    ``(x, qweight, qzeros, scales, g_idx)`` where:

    - x:       fp16 activations, shape (bs, seq_len, hid_dim)
    - qweight: int32-packed 4-bit weights, shape
               (out_feat, hid_dim // compression_factor)
    - qzeros:  int32 zero-points filled with 8, shape
               (n_grp, out_feat // compression_factor)
    - scales:  fp16 per-group scales, shape (n_grp, out_feat)
    - g_idx:   int32 group index per input channel, shape (hid_dim,)
    """
    sizes = request.param
    # One int32 packs eight 4-bit values (32 / 4 = 8).
    compression_factor = 8

    x = torch.randn(
        (sizes["bs"], sizes["seq_len"], sizes["hid_dim"]), dtype=torch.float16
    )
    qweight = torch.randint(
        low=0,
        high=torch.iinfo(torch.int32).max,
        size=(sizes["out_feat"], sizes["hid_dim"] // compression_factor),
        dtype=torch.int32,
    )
    # Zero-points share the 8-per-int32 packing factor (was a magic "8").
    # The fill value 8 is presumably the symmetric 4-bit zero-point midpoint
    # -- TODO(review): confirm against the op's dequantization convention.
    qzeros = 8 * torch.ones(
        (sizes["n_grp"], sizes["out_feat"] // compression_factor),
        dtype=torch.int32,
    )
    scales = torch.randn(
        (sizes["n_grp"], sizes["out_feat"]), dtype=torch.float16,
    )
    g_idx = torch.zeros(sizes["hid_dim"], dtype=torch.int32)

    return (x, qweight, qzeros, scales, g_idx)
41+
42+
43+
def test_gptq_registration() -> None:
    """Register the GPTQ W4A16 operation and check it shows up in torch.ops.

    Note: registration must be called before other GPTQ tests.
    """
    register_aiu_gptq_op()

    # Both the namespace and the specific overload must be visible.
    assert hasattr(torch.ops, "gptq_gemm")
    assert hasattr(torch.ops.gptq_gemm, "i4f16_fxinputs_aiu")
52+
53+
54+
def test_gptq_op(get_gptq_gemm_inputs) -> None:
    """Validate output shapes of GPTQ W4A16 tensors.

    Note: this AIU-compatible operation only returns a zero tensor of the
    expected shape, it does not perform a real W4A16 matmul operation.
    """
    x, qweight, qzeros, scales, g_idx = get_gptq_gemm_inputs

    out = torch.ops.gptq_gemm.i4f16_fxinputs_aiu(x, qweight, qzeros, scales, g_idx)

    # Output keeps all activation dims but the last, which becomes out_feat.
    expected = torch.Size((*x.shape[:-1], qweight.shape[0]))
    assert out.shape == expected

0 commit comments

Comments
 (0)