
Commit 02f5ff3

Merge pull request #62 from andrea-fasoli/unit_test_int8
test: Unit test int8
2 parents: 9fc7c75 + 7064a2a

File tree

7 files changed: +70 additions, -59 deletions


fms_mo/aiu_addons/gptq/gptq_aiu_linear.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def get_gptq_aiu_linear(
     in_features: int,
     out_features: int,
     bias: bool,
-    linear_config: Optional[Mapping[str, Any]] = None,
+    linear_config: Mapping[str, Any],
 ) -> torch.nn.Module:
     """Retrieve a GPTQ W4A16 Linear module"""

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ def _int8_qparams_aiu(


 def _add_defaults_and_concat(
-    new_sd: Mapping[str, torch.Tensor],
-    modules_seen: set,
+    new_sd: dict[str, torch.Tensor],
+    modules_seen: set[str],
 ) -> None:
     """
     Add default activation clip values, zero_shift, and smoothquant_scale (if not
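For orientation: the adapter concatenates these per-module quantization parameters into a single qdata tensor, which the op later splits apart again (see extract_qdata in the next file). A rough layout sketch for the simplest configuration, with segment sizes and default values mirroring the dummy generator this commit removes from conftest.py; the segment names after the weight clips are inferred from the docstrings and are not verified against the adapter code:

import torch

# per_tensor weights, per_tensor_symm activations, smoothquant off:
# 2 weight clips + 2 activation clips + 1 zero shift + 1 smoothquant scale
w_clip = torch.tensor([1.0, -1.0], dtype=torch.float16)   # w_clip_val, w_clip_valn
a_clip = torch.tensor([1.0, -1.0], dtype=torch.float16)   # activation clip values (assumed names)
zero_shift = torch.ones(1, dtype=torch.float16)           # single value unless per_tensor_asymm
sq_scale = torch.ones(1, dtype=torch.float16)             # single value unless smoothquant is on

qdata = torch.cat([w_clip, a_clip, zero_shift, sq_scale])
assert qdata.numel() == 6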

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 5 additions & 5 deletions
@@ -114,7 +114,7 @@ def extract_qdata(
     w_in_feat: int,
     w_out_feat: int,
     smoothquant: bool,
-) -> tuple[torch.Tensor]:
+) -> tuple[torch.Tensor, ...]:
     """6 tensors are to be de-concatenated from qdata:
     w_clip_val   [    : idx1]
     w_clip_valn  [idx1: idx2]

@@ -194,19 +194,19 @@ def quant_dequant_activ(
     """
     if activ_quant_type == "per_tensor_symm":
         scale_x = 127 / a_cv
-        x_int = torch.round(x / sq * scale_x).clamp(-127, 127)
-        return x_int / scale_x * sq
+        x_int = torch.round(x / sq * scale_x).clamp(-127, 127).to(torch.int8)
+        return x_int.div(scale_x).mul(sq)
     if activ_quant_type == "per_tensor_asymm":
         scale_x = 255 / (a_cv - a_cvn)
         zp_x = a_cvn * scale_x
         x_int = torch.round(x / sq * scale_x - zp_x).clamp(0, 255)
-        return (x_int + zp_x) / scale_x * sq
+        return x_int.add(zp_x).div(scale_x).mul(sq)
     if activ_quant_type == "per_token":
         x_sq = x / sq
         a_cv_per_token = x_sq.abs().max(dim=-1, keepdim=True)[0]
         scale_x = 127 / a_cv_per_token
         x_int = torch.round(x_sq * scale_x).clamp(-127, 127)
-        return x_int / scale_x * sq
+        return x_int.div(scale_x).mul(sq)
     raise NotImplementedError(
         f"activation quantizantion type {activ_quant_type} is not supported"
     )
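The functional change above is that the per-tensor symmetric path now materializes the rounded values as a real torch.int8 tensor before dequantizing, with the dequantization written as explicit div/mul calls. A minimal standalone sketch of that round trip, assuming illustrative shapes and an identity smoothquant scale (the helper name below is not part of the library):

import torch

def fake_quant_per_tensor_symm(x: torch.Tensor, a_cv: float, sq: torch.Tensor) -> torch.Tensor:
    """Simulated INT8 quantize/dequantize of activations (per-tensor, symmetric)."""
    scale_x = 127 / a_cv
    # round onto the integer grid, clamp to [-127, 127], and store as real int8
    x_int = torch.round(x / sq * scale_x).clamp(-127, 127).to(torch.int8)
    # dequantize back to floating point for the simulated matmul
    return x_int.div(scale_x).mul(sq)

# illustrative usage; sizes follow the values the old conftest fixture used
x = torch.randn(4, 7, 256, dtype=torch.float16).clamp(-1, 1)
sq = torch.ones(256, dtype=torch.float16)  # identity smoothquant scale
x_dq = fake_quant_per_tensor_symm(x, a_cv=1.0, sq=sq)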

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ dependencies = [
     "accelerate>=0.20.3,!=0.34,<1.4",
     "transformers>=4.45,<4.49",
     "torch>=2.2.0,<2.5",
-    "triton>=3.0,<3.2",
+    "triton>=3.0,<3.2",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",
     "ninja>=1.11.1.1,<2.0",

tests/aiu_addons/conftest.py

Lines changed: 43 additions & 47 deletions
@@ -13,6 +13,9 @@
 # limitations under the License.
 """Pytest configuration file with fixtures for add-ons functionality testing"""

+# Standard
+from pathlib import Path
+
 # Third Party
 import pytest
 import torch

@@ -67,65 +70,58 @@ def get_gptq_gemm_inputs(request) -> tuple[torch.Tensor, ...]:

 i8i8_metadata = [
     {
-        "bs": 4,
-        "seq_len": 7,
-        "hid_dim": 256,
-        "out_feat": 512,
-        "dtype": torch.float16,
         "wtype": "per_tensor",  # per_channel
         "atype": "per_tensor_symm",  # per_tensor_asymm, per_token
         "smoothquant": False,
-    }
+    },
+    # {
+    #     "wtype": "per_channel",  # per_channel
+    #     "atype": "per_tensor_symm",  # per_tensor_asymm, per_token
+    #     "smoothquant": False,
+    # },
 ]


 @pytest.fixture(scope="session", params=i8i8_metadata)
 def get_i8i8_gemm_inputs(
     request,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, str, bool]:
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    str,
+    str,
+    bool,
+    torch.Tensor,
+]:
     """pytest fixture returning test inputs for INT8xINT8 op"""

     data = request.param
-    x = torch.randn(
-        (data["bs"], data["seq_len"], data["hid_dim"]),
-        dtype=data["dtype"],
-    ).clamp(-1, 1)
-    w_int = torch.randint(
-        low=-8,
-        high=8,
-        size=(data["out_feat"], data["hid_dim"]),
-        dtype=torch.int8,
+
+    filename = (
+        f"ref_w-{data['wtype']}_"
+        f"a-{data['atype']}_"
+        f"sq-{'Y' if data['smoothquant'] else 'N'}.pt"
     )
-    b = torch.zeros(data["out_feat"], dtype=data["dtype"])
-    qdata = create_qdata(
-        data["wtype"],
-        data["atype"],
-        data["hid_dim"],
-        data["out_feat"],
-        data["smoothquant"],
-        data["dtype"],
+    addon_references = Path("tests/artifacts/aiu_addons")
+    i8i8_data = torch.load(addon_references / filename, weights_only=True)
+
+    assert isinstance(i8i8_data, dict)
+    assert data["wtype"] == i8i8_data["weight_quant_type"]
+    assert data["atype"] == i8i8_data["activ_quant_type"]
+    assert data["smoothquant"] == i8i8_data["smoothquant"]
+    assert all(
+        item in i8i8_data for item in ["x", "w_int", "bias", "qdata", "reference_out"]
     )

-    return (x, w_int, b, qdata, data["wtype"], data["atype"], data["smoothquant"])
-
-
-def create_qdata(
-    wtype: str,
-    atype: str,
-    in_feat: int,
-    out_feat: int,
-    smoothquant: bool,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    """Generate dummy qdata tensor based on the provided quantization configuration"""
-
-    qdata_len = 2 if wtype == "per_tensor" else 2 * out_feat  # weight clips
-    qdata_len += 2  # activation clips
-    qdata_len += out_feat if atype == "per_tensor_asymm" else 1  # zero shift
-    qdata_len += in_feat if smoothquant else 1  # smoothquant scales
-
-    # TODO: improve dummy generation
-    qdata = torch.ones(qdata_len, dtype=dtype)
-    qdata[1] = -qdata[0]  # !!! temporary solution to enforce clip symmetry
-    qdata[3] = -qdata[2]
-    return qdata
+    return (
+        i8i8_data["x"],
+        i8i8_data["w_int"],
+        i8i8_data["bias"],
+        i8i8_data["qdata"],
+        i8i8_data["weight_quant_type"],
+        i8i8_data["activ_quant_type"],
+        i8i8_data["smoothquant"],
+        i8i8_data["reference_out"],
+    )
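The fixture now loads a pre-computed reference artifact instead of generating dummy tensors on the fly; the artifact ships with this commit as the binary file noted at the end of the diff. The script that produced it is not part of the change, so the following is only a hypothetical sketch of how such a file could be assembled: the dict keys, filename pattern, and tensor shapes follow the old and new conftest code, while the tensor values and reference_out are placeholders:

import torch
from pathlib import Path

wtype, atype, smoothquant = "per_tensor", "per_tensor_symm", False

# shapes follow the removed fixture defaults (bs=4, seq_len=7, hid_dim=256, out_feat=512)
x = torch.randn(4, 7, 256, dtype=torch.float16).clamp(-1, 1)
w_int = torch.randint(-8, 8, (512, 256), dtype=torch.int8)
bias = torch.zeros(512, dtype=torch.float16)
qdata = torch.tensor([1.0, -1.0, 1.0, -1.0, 1.0, 1.0], dtype=torch.float16)  # placeholder qparams
reference_out = torch.zeros(4, 7, 512, dtype=torch.float16)  # placeholder: a trusted op output goes here

out_dir = Path("tests/artifacts/aiu_addons")
out_dir.mkdir(parents=True, exist_ok=True)
filename = f"ref_w-{wtype}_a-{atype}_sq-{'Y' if smoothquant else 'N'}.pt"
torch.save(
    {
        "x": x,
        "w_int": w_int,
        "bias": bias,
        "qdata": qdata,
        "weight_quant_type": wtype,
        "activ_quant_type": atype,
        "smoothquant": smoothquant,
        "reference_out": reference_out,
    },
    out_dir / filename,
)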

tests/aiu_addons/test_int8_addon.py

Lines changed: 18 additions & 3 deletions
@@ -33,10 +33,17 @@ def test_i8i8_registration() -> None:

 def test_i8i8_op(
     get_i8i8_gemm_inputs: tuple[
-        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, str, bool
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        str,
+        str,
+        bool,
+        torch.Tensor,
     ],
 ) -> None:
-    """Validate output shapes of INT8xINT8 matmul.
+    """Validate output shapes and content of INT8xINT8 matmul.
     Computations are simulated, using quantized/dequantized tensors.
     """

@@ -48,8 +55,13 @@ def test_i8i8_op(
         weight_quant_type,
         activ_quant_type,
         smoothquant,
+        reference_out,
     ) = get_i8i8_gemm_inputs

+    # enforce fp16 dtype on all fp parameters for this test
+    x = x.to(torch.float16)
+    qdata = qdata.to(torch.float16)
+
     out = torch.ops.fms_mo.i8i8_aiu(
         x,
         weight,

@@ -60,4 +72,7 @@ def test_i8i8_op(
         smoothquant,
     )

-    assert out.size() == torch.Size((x.size()[:-1] + (weight.size(0),)))
+    error_tolerance = 1e-4  # TODO: this needs adjusting
+    assert out.size() == x.size()[:-1] + (weight.size(0),)
+    assert torch.all((out - reference_out).abs() < error_tolerance)
+    # assert torch.linalg.norm(out - reference_out) < error_tolerance  # alternative check
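On the new content check: the test bounds the element-wise absolute error, while the commented-out alternative bounds the norm of the whole error tensor, which aggregates per-element errors and would therefore need a looser threshold (hence the TODO). A small illustration with made-up numbers:

import torch

error_tolerance = 1e-4
reference_out = torch.zeros(1000)
out = reference_out + 5e-5  # every element off by 5e-5

elementwise_ok = torch.all((out - reference_out).abs() < error_tolerance)  # True
norm_ok = torch.linalg.norm(out - reference_out) < error_tolerance         # False, norm is ~1.6e-3
print(elementwise_ok.item(), norm_ok.item())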
The seventh changed file is the binary reference artifact under tests/artifacts/aiu_addons and is not shown.
