
Commit b29b371

clean up
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 441d375 commit b29b371

6 files changed: +22 additions, -42 deletions


fms_mo/aiu_addons/gptq/gptq_aiu_linear.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def get_gptq_aiu_linear(
     in_features: int,
     out_features: int,
     bias: bool,
-    linear_config: Optional[Mapping[str, Any]] = None,
+    linear_config: Mapping[str, Any],
 ) -> torch.nn.Module:
     """Retrieve a GPTQ W4A16 Linear module"""
 
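For reference, a minimal usage sketch of the stricter signature: linear_config must now be passed explicitly instead of defaulting to None. The config key shown below is a hypothetical placeholder, not taken from this commit.

# Hypothetical sketch only: the linear_config key below is a placeholder.
from typing import Any, Mapping

import torch

from fms_mo.aiu_addons.gptq.gptq_aiu_linear import get_gptq_aiu_linear

linear_config: Mapping[str, Any] = {"group_size": 128}  # assumed key, for illustration
module: torch.nn.Module = get_gptq_aiu_linear(
    in_features=4096,
    out_features=4096,
    bias=False,
    linear_config=linear_config,  # now required: no Optional default
)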

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ def _int8_qparams_aiu(
 
 
 def _add_defaults_and_concat(
-    new_sd: Mapping[str, torch.Tensor],
-    modules_seen: set,
+    new_sd: dict[str, torch.Tensor],
+    modules_seen: set[str],
 ) -> None:
     """
     Add default activation clip values, zero_shift, and smoothquant_scale (if not
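A plausible reading of this annotation change, sketched below: typing.Mapping describes a read-only view, so a type checker rejects item assignment on new_sd, whereas dict (and set[str] for modules_seen) matches a function that mutates its arguments in place.

# Sketch of the typing distinction (not code from this repository).
from typing import Mapping

import torch

def add_readonly(sd: Mapping[str, torch.Tensor]) -> None:
    sd["w_clip_val"] = torch.zeros(1)  # rejected by mypy: Mapping has no __setitem__

def add_mutable(sd: dict[str, torch.Tensor]) -> None:
    sd["w_clip_val"] = torch.zeros(1)  # fine: dict supports item assignment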

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def get_int8_aiu_linear(
     in_features: int,
     out_features: int,
     bias: bool,
-    linear_config: Optional[Mapping[str, Any]] = None,
+    linear_config: Mapping[str, Any],
     use_smoothquant: bool = True,
 ) -> torch.nn.Module:
     """Retrieve a W8A8 Linear module"""

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 4 additions & 4 deletions
@@ -114,7 +114,7 @@ def extract_qdata(
     w_in_feat: int,
     w_out_feat: int,
     smoothquant: bool,
-) -> tuple[torch.Tensor]:
+) -> tuple[torch.Tensor, ...]:
     """6 tensors are to be de-concatenated from qdata:
     w_clip_val   [    : idx1]
     w_clip_valn  [idx1: idx2]
@@ -195,18 +195,18 @@ def quant_dequant_activ(
     if activ_quant_type == "per_tensor_symm":
         scale_x = 127 / a_cv
         x_int = torch.round(x / sq * scale_x).clamp(-127, 127).to(torch.int8)
-        return x_int / scale_x * sq
+        return x_int.div(scale_x).mul(sq)
     if activ_quant_type == "per_tensor_asymm":
         scale_x = 255 / (a_cv - a_cvn)
         zp_x = a_cvn * scale_x
         x_int = torch.round(x / sq * scale_x - zp_x).clamp(0, 255)
-        return (x_int + zp_x) / scale_x * sq
+        return x_int.add(zp_x).div(scale_x).mul(sq)
     if activ_quant_type == "per_token":
         x_sq = x / sq
         a_cv_per_token = x_sq.abs().max(dim=-1, keepdim=True)[0]
         scale_x = 127 / a_cv_per_token
         x_int = torch.round(x_sq * scale_x).clamp(-127, 127)
-        return x_int / scale_x * sq
+        return x_int.div(scale_x).mul(sq)
     raise NotImplementedError(
         f"activation quantizantion type {activ_quant_type} is not supported"
     )
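Two notes on this file. The return annotation fix matters because tuple[torch.Tensor] denotes a 1-tuple, while tuple[torch.Tensor, ...] denotes a homogeneous tuple of any length, matching the six tensors described in the docstring. The div/mul rewrite looks purely stylistic; a quick numerical check of that assumption for the symmetric branch:

# Sanity check (assumption: the refactor is behavior-preserving).
import torch

x_int = torch.randint(-127, 128, (4, 8), dtype=torch.int8)
scale_x = torch.tensor(127.0 / 3.0)   # stand-in per-tensor scale
sq = torch.rand(8) + 0.5              # stand-in smoothquant scales

old_out = x_int / scale_x * sq                # original expression form
new_out = x_int.div(scale_x).mul(sq)          # method-chained form from this commit
torch.testing.assert_close(old_out, new_out)  # passes: identical results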

tests/aiu_addons/conftest.py

Lines changed: 13 additions & 33 deletions
@@ -13,10 +13,12 @@
 # limitations under the License.
 """Pytest configuration file with fixtures for add-ons functionality testing"""
 
+# Standard
+from pathlib import Path
+
 # Third Party
 import pytest
 import torch
-from pathlib import Path
 
 # ================================================
 # GPTQ W4A16 fixtures
@@ -84,15 +86,15 @@ def get_gptq_gemm_inputs(request) -> tuple[torch.Tensor, ...]:
 def get_i8i8_gemm_inputs(
     request,
 ) -> tuple[
-        torch.Tensor,
-        torch.Tensor,
-        torch.Tensor,
-        torch.Tensor,
-        str,
-        str,
-        bool,
-        torch.Tensor,
-    ]:
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    str,
+    str,
+    bool,
+    torch.Tensor,
+]:
     """pytest fixture returning test inputs for INT8xINT8 op"""
 
     data = request.param
@@ -110,7 +112,7 @@ def get_i8i8_gemm_inputs(
     assert data["atype"] == i8i8_data["activ_quant_type"]
     assert data["smoothquant"] == i8i8_data["smoothquant"]
     assert all(
-        [item in i8i8_data for item in ["x", "w_int", "bias", "qdata", "reference_out"]]
+        item in i8i8_data for item in ["x", "w_int", "bias", "qdata", "reference_out"]
     )
 
     return (
@@ -123,25 +125,3 @@ def get_i8i8_gemm_inputs(
         i8i8_data["smoothquant"],
         i8i8_data["reference_out"],
     )
-
-
-def create_qdata(
-    wtype: str,
-    atype: str,
-    in_feat: int,
-    out_feat: int,
-    smoothquant: bool,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    """Generate dummy qdata tensor based on the provided quantization configuration"""
-
-    qdata_len = 2 if wtype == "per_tensor" else 2 * out_feat  # weight clips
-    qdata_len += 2  # activation clips
-    qdata_len += out_feat if atype == "per_tensor_asymm" else 1  # zero shift
-    qdata_len += in_feat if smoothquant else 1  # smoothquant scales
-
-    # TODO: improve dummy generation
-    qdata = torch.ones(qdata_len, dtype=dtype)
-    qdata[1] = -qdata[0]  # !!! temporary solution to enforce clip symmetry
-    qdata[3] = -qdata[2]
-    return qdata
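The conftest change from all([...]) to all(...) is a small idiom fix; a sketch of the difference with a toy dict (values are placeholders):

# Toy illustration: all() over a generator short-circuits instead of
# materializing a full list first.
i8i8_data = {"x": 1, "w_int": 2}  # placeholder dict, missing several keys
required = ["x", "w_int", "bias", "qdata", "reference_out"]

assert not all([item in i8i8_data for item in required])  # builds the whole list
assert not all(item in i8i8_data for item in required)    # stops at first missing key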

tests/aiu_addons/test_int8_addon.py

Lines changed: 1 addition & 1 deletion
@@ -75,4 +75,4 @@ def test_i8i8_op(
     error_tolerance = 1e-4  # TODO: this needs adjusting
     assert out.size() == x.size()[:-1] + (weight.size(0),)
     assert torch.all((out - reference_out).abs() < error_tolerance)
-    assert torch.linalg.norm(out - reference_out) < error_tolerance  # alternative check
+    # assert torch.linalg.norm(out - reference_out) < error_tolerance  # alternative check
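One plausible reason the norm-based check is commented out rather than kept: a whole-tensor norm aggregates per-element errors, so it can exceed a tolerance chosen for the element-wise check even when every element is within bounds. A small sketch of that effect:

# Illustration only; shapes and values are arbitrary.
import torch

err = torch.full((1024, 1024), 5e-5)  # every element well under 1e-4
print(err.abs().max())                # 5e-5  -> element-wise check passes
print(torch.linalg.norm(err))         # ~0.05 -> norm check fails at 1e-4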
