Skip to content

Commit 9fc7c75

Browse files
Merge pull request #65 from andrea-fasoli/int8_smoothquant
feat: Support for int8 smoothquant
2 parents 608068d + fb870bd commit 9fc7c75

File tree

3 files changed: +21 additions, −2 deletions

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ def _add_defaults_and_concat(
103103
"gpt_bigcode", "int8_qparams_aiu", _int8_qparams_aiu
104104
)
105105
serialization.register_adapter_step("roberta", "int8_qparams_aiu", _int8_qparams_aiu)
106+
serialization.register_adapter_step(
107+
"roberta_question_answering",
108+
"int8_qparams_aiu",
109+
_int8_qparams_aiu,
110+
)
106111

107112
# registration of multi-step adapter for each architecture
108113
serialization.register_adapter(
@@ -121,3 +126,12 @@ def _add_defaults_and_concat(
121126
serialization.register_adapter(
122127
"roberta", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
123128
)
129+
serialization.register_adapter(
130+
"roberta_question_answering",
131+
"fms_mo",
132+
[
133+
"hf_to_fms_names",
134+
"weight_fusion",
135+
"int8_qparams_aiu",
136+
],
137+
)

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
# Standard
1717
from dataclasses import dataclass
18+
from functools import partial
1819
from typing import Any, Mapping, Optional
1920

2021
# Third Party
@@ -201,7 +202,7 @@ def get_int8_aiu_linear(
201202
out_features: int,
202203
bias: bool,
203204
linear_config: Optional[Mapping[str, Any]] = None,
204-
use_smoothquant: bool = True,
205+
use_smoothquant: bool = False,
205206
) -> torch.nn.Module:
206207
"""Retrieve a W8A8 Linear module"""
207208

@@ -281,4 +282,8 @@ def shard_int8_aiu_linear(
281282

282283

283284
register_linear_type_to_module_map("int8_aiu", get_int8_aiu_linear)
285+
register_linear_type_to_module_map(
286+
"int8_smoothquant_aiu",
287+
partial(get_int8_aiu_linear, use_smoothquant=True),
288+
)
284289
register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def i8i8_aiu(
8484
x_dq = quant_dequant_activ(x, a_cv, a_cvn, sq, activ_quant_type)
8585
w_dq = dequant_weights(weight, w_cv, sq, weight_quant_type)
8686

87-
return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias)
87+
return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias.to(dtype))
8888

8989
@torch.library.impl_abstract(op_namespace_id)
9090
def i8i8_aiu_abstract(

Comments (0)