Skip to content

Commit 7064a2a

Browse files
Merge branch 'main' into unit_test_int8
Signed-off-by: chichun-charlie-liu <[email protected]>
2 parents: 70348a8 + 9fc7c75 — commit 7064a2a

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,11 @@ def _add_defaults_and_concat(
103103
"gpt_bigcode", "int8_qparams_aiu", _int8_qparams_aiu
104104
)
105105
serialization.register_adapter_step("roberta", "int8_qparams_aiu", _int8_qparams_aiu)
106+
serialization.register_adapter_step(
107+
"roberta_question_answering",
108+
"int8_qparams_aiu",
109+
_int8_qparams_aiu,
110+
)
106111

107112
# registration of multi-step adapter for each architecture
108113
serialization.register_adapter(
@@ -121,3 +126,12 @@ def _add_defaults_and_concat(
121126
serialization.register_adapter(
122127
"roberta", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
123128
)
129+
serialization.register_adapter(
130+
"roberta_question_answering",
131+
"fms_mo",
132+
[
133+
"hf_to_fms_names",
134+
"weight_fusion",
135+
"int8_qparams_aiu",
136+
],
137+
)

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515

1616
# Standard
1717
from dataclasses import dataclass
18+
from functools import partial
1819
from typing import Any, Mapping, Optional
1920

2021
# Third Party
@@ -200,8 +201,8 @@ def get_int8_aiu_linear(
200201
in_features: int,
201202
out_features: int,
202203
bias: bool,
203-
linear_config: Mapping[str, Any],
204-
use_smoothquant: bool = True,
204+
linear_config: Optional[Mapping[str, Any]] = None,
205+
use_smoothquant: bool = False,
205206
) -> torch.nn.Module:
206207
"""Retrieve a W8A8 Linear module"""
207208

@@ -281,4 +282,8 @@ def shard_int8_aiu_linear(
281282

282283

283284
register_linear_type_to_module_map("int8_aiu", get_int8_aiu_linear)
285+
register_linear_type_to_module_map(
286+
"int8_smoothquant_aiu",
287+
partial(get_int8_aiu_linear, use_smoothquant=True),
288+
)
284289
register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)

0 commit comments

Comments (0)