
Commit 0d762ef

Add INT8 basic tests (wip)
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent c30dadf commit 0d762ef

7 files changed, +216 −42 lines changed
File renamed without changes.
Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 import torch.nn as nn

 # Local
-from fms_mo.aiu_addons.int8.int8_aiu_op import register_aiu_i8i8_op
+from fms_mo.aiu_addons.i8i8.i8i8_aiu_op import register_aiu_i8i8_op

 register_aiu_i8i8_op()

tests/aiu_addons/conftest.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytest configuration file with fixtures for add-ons functionality testing"""

# Third Party
import pytest
import torch

# ================================================
# GPTQ W4A16 fixtures
# ================================================

gptq_input_sizes = [
    {
        "bs": 4,
        "seq_len": 5,
        "hid_dim": 256,
        "out_feat": 512,
        "n_grp": 4,
    },
]


@pytest.fixture(scope="session", params=gptq_input_sizes)
def get_gptq_gemm_inputs(request) -> tuple[torch.Tensor, ...]:
    """pytest fixture returning test inputs for GPTQ op"""

    sizes = request.param
    compression_factor = 8  # assume 4-bits compression

    x = torch.randn(
        (sizes["bs"], sizes["seq_len"], sizes["hid_dim"]), dtype=torch.float16
    )
    qweight = torch.randint(
        low=0,
        high=torch.iinfo(torch.int32).max,
        size=(sizes["out_feat"], sizes["hid_dim"] // compression_factor),
        dtype=torch.int32,
    )
    qzeros = 8 * torch.ones(
        (sizes["n_grp"], sizes["out_feat"] // compression_factor),
        dtype=torch.int32,
    )
    scales = torch.randn(
        (sizes["n_grp"], sizes["out_feat"]),
        dtype=torch.float16,
    )
    g_idx = torch.zeros(sizes["hid_dim"], dtype=torch.int32)

    return (x, qweight, qzeros, scales, g_idx)


# ================================================
# INT8xINT8 fixtures
# ================================================

i8i8_metadata = [
    {
        "bs": 4,
        "seq_len": 7,
        "hid_dim": 256,
        "out_feat": 512,
        "dtype": torch.float16,
        "wtype": "per_tensor",  # per_channel
        "atype": "per_tensor_symm",  # per_tensor_asymm, per_token
        "smoothquant": False,
    }
]


@pytest.fixture(scope="session", params=i8i8_metadata)
def get_i8i8_gemm_inputs(
    request,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, str, bool]:
    """pytest fixture returning test inputs for INT8xINT8 op"""

    data = request.param
    x = torch.randn(
        (data["bs"], data["seq_len"], data["hid_dim"]),
        dtype=data["dtype"],
    ).clamp(-1, 1)
    w_int = torch.randint(
        low=-8,
        high=8,
        size=(data["out_feat"], data["hid_dim"]),
        dtype=torch.int8,
    )
    b = torch.zeros(data["out_feat"], dtype=data["dtype"])
    qdata = create_qdata(
        data["wtype"],
        data["atype"],
        data["hid_dim"],
        data["out_feat"],
        data["smoothquant"],
        data["dtype"],
    )

    return (x, w_int, b, qdata, data["wtype"], data["atype"], data["smoothquant"])


def create_qdata(
    wtype: str,
    atype: str,
    in_feat: int,
    out_feat: int,
    smoothquant: bool,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Generate dummy qdata tensor based on the provided quantization configuration"""

    qdata_len = 2 if wtype == "per_tensor" else 2 * out_feat  # weight clips
    qdata_len += 2  # activation clips
    qdata_len += out_feat if atype == "per_tensor_asymm" else 1  # zero shift
    qdata_len += in_feat if smoothquant else 1  # smoothquant scales

    # TODO: improve dummy generation
    qdata = torch.ones(qdata_len, dtype=dtype)
    qdata[1] = -qdata[0]  # !!! temporary solution to enforce clip symmetry
    qdata[3] = -qdata[2]
    return qdata
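
For orientation, the qdata sizing logic in create_qdata can be checked by hand: the default i8i8_metadata entry (per_tensor weights, per_tensor_symm activations, no smoothquant) yields 2 + 2 + 1 + 1 = 6 entries. A minimal sketch of that arithmetic, using a hypothetical helper name (illustration only, not part of this commit):

# Hypothetical helper mirroring the qdata length computation in create_qdata above.
def expected_qdata_len(
    wtype: str, atype: str, in_feat: int, out_feat: int, smoothquant: bool
) -> int:
    n = 2 if wtype == "per_tensor" else 2 * out_feat  # weight clips
    n += 2  # activation clips
    n += out_feat if atype == "per_tensor_asymm" else 1  # zero shift
    n += in_feat if smoothquant else 1  # smoothquant scales
    return n

# Default test configuration: 2 + 2 + 1 + 1 = 6 entries.
assert expected_qdata_len("per_tensor", "per_tensor_symm", 256, 512, False) == 6
# Largest configuration: per_channel weights, asymmetric activations, smoothquant on.
assert expected_qdata_len("per_channel", "per_tensor_asymm", 256, 512, True) == 2 * 512 + 2 + 512 + 256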

tests/aiu_addons/test_gptq_addon.py

Lines changed: 21 additions & 41 deletions
@@ -1,57 +1,37 @@
-import pytest
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test suite for FMS addon for AIU, introducing GPTQ functionalities"""
+
+# Third Party
 import torch

+# Local
 from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op


-input_sizes = [
-    {
-        "bs": 4,
-        "seq_len": 32,
-        "hid_dim": 768,
-        "out_feat": 3072,
-        "n_grp": 6,
-    },
-]
-
-
-@pytest.fixture(params=input_sizes)
-def get_gptq_gemm_inputs(request):
-    sizes = request.param
-    compression_factor = 8  # = assume 4-bits compression
-
-    x = torch.randn(
-        (sizes["bs"], sizes["seq_len"], sizes["hid_dim"]), dtype=torch.float16
-    )
-    qweight = torch.randint(
-        low=0,
-        high=torch.iinfo(torch.int32).max,
-        size=(sizes["out_feat"], sizes["hid_dim"] // compression_factor),
-        dtype=torch.int32,
-    )
-    qzeros = 8 * torch.ones(
-        (sizes["n_grp"], sizes["out_feat"] // 8), dtype = torch.int32
-    )
-    scales = torch.randn(
-        (sizes["n_grp"], sizes["out_feat"]), dtype=torch.float16,
-    )
-    g_idx = torch.zeros(sizes["hid_dim"], dtype=torch.int32)
-
-    return (x, qweight, qzeros, scales, g_idx)
-
-
 def test_gptq_registration() -> None:
-    """Call the registration function of GPTQ W4A16 operation, to add it.
-    Note: registration must be called before other GPTQ tests.
+    """Call the registration function of GPTQ W4A16 operation, adding the op to torch
+    namespace.
+    Note: registration must be called before other GPTQ tests that use this op.
     """

     register_aiu_gptq_op()
     assert hasattr(torch.ops, "gptq_gemm")
     assert hasattr(torch.ops.gptq_gemm, "i4f16_fxinputs_aiu")
-    return


-def test_gptq_op(get_gptq_gemm_inputs) -> None:
+def test_gptq_op(get_gptq_gemm_inputs: tuple[torch.Tensor, ...]) -> None:
     """Validate output shapes of GPTQ W4A16 tensors.
     Note: this AIU-compatible operation only returns a zero tensor of the
     expected shape, it does not perform a real W4A16 matmul operation.
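
As a reading aid for the GPTQ fixture shapes above: the compression factor of 8 reflects eight 4-bit weights packed into each int32. A minimal sketch (illustrative, not part of this commit) assuming the default gptq_input_sizes entry from conftest.py:

# Tensor shapes produced by get_gptq_gemm_inputs for the default entry
# (bs=4, seq_len=5, hid_dim=256, out_feat=512, n_grp=4).
import torch

bs, seq_len, hid_dim, out_feat, n_grp = 4, 5, 256, 512, 4
compression_factor = 8  # eight 4-bit weights packed per int32

x = torch.randn(bs, seq_len, hid_dim, dtype=torch.float16)
qweight = torch.zeros(out_feat, hid_dim // compression_factor, dtype=torch.int32)
qzeros = torch.zeros(n_grp, out_feat // compression_factor, dtype=torch.int32)
scales = torch.randn(n_grp, out_feat, dtype=torch.float16)

assert x.shape == (4, 5, 256)
assert qweight.shape == (512, 32)
assert qzeros.shape == (4, 64)
assert scales.shape == (4, 512)
# A W4A16 matmul consuming these inputs returns shape x.shape[:-1] + (out_feat,) == (4, 5, 512).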
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test suite for FMS addon for AIU, introducing INT8xINT8 functionalities"""

# Third Party
import torch

# Local
from fms_mo.aiu_addons.i8i8.i8i8_aiu_op import register_aiu_i8i8_op


def test_i8i8_registration() -> None:
    """Call the registration function of INT8xINT8 operation, adding the op to torch
    namespace.
    Note: registration must be called before other INT8 tests that use this op.
    """

    register_aiu_i8i8_op()
    assert hasattr(torch.ops, "fms_mo")
    assert hasattr(torch.ops.fms_mo, "i8i8_aiu")


def test_i8i8_op(
    get_i8i8_gemm_inputs: tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, str, bool
    ],
) -> None:
    """Validate output shapes of INT8xINT8 matmul.
    Computations are simulated, using quantized/dequantized tensors.
    """

    (
        x,
        weight,
        bias,
        qdata,
        weight_quant_type,
        activ_quant_type,
        smoothquant,
    ) = get_i8i8_gemm_inputs

    out = torch.ops.fms_mo.i8i8_aiu(
        x,
        weight,
        bias,
        qdata,
        weight_quant_type,
        activ_quant_type,
        smoothquant,
    )

    assert out.size() == torch.Size((x.size()[:-1] + (weight.size(0),)))
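
For reference, the shape arithmetic behind the final assertion, using the default i8i8_metadata entry from conftest.py (a minimal sketch, not part of this commit):

# Expected output size for bs=4, seq_len=7, hid_dim=256, out_feat=512.
import torch

x = torch.zeros(4, 7, 256, dtype=torch.float16)
weight = torch.zeros(512, 256, dtype=torch.int8)
expected = torch.Size(x.size()[:-1] + (weight.size(0),))
assert expected == torch.Size([4, 7, 512])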
