Skip to content

Commit 3c07b46

Browse files
Merge pull request #68 from andrea-fasoli/fix_int8_linear_type
fix: handle linear_type callable at int8 linear instantiation
2 parents ed76001 + 973ac6d commit 3c07b46

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _int8_qparams_aiu(
3737

3838
param_type = "w" if is_weight else "a"
3939
new_name = f"{module_name}.{param_type}_{name_split[-1]}"
40-
elif "smoothq" in name:
40+
elif "smoothq" in name and "smoothquant" not in name:
4141
new_name = name.replace("smoothq", "smoothquant")
4242

4343
new_sd[new_name] = param

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
# Standard
1717
from dataclasses import dataclass
1818
from functools import partial
19-
from typing import Any, Mapping, Optional
19+
from typing import Any, Callable, Optional, Union
20+
import copy
2021

2122
# Third Party
2223
from fms.modules.linear import (
@@ -197,16 +198,38 @@ def __repr__(self) -> str:
197198
)
198199

199200

201+
def update_from_partial(
    linear_config: dict[Union[str, Callable], Any],
) -> dict[Union[str, Callable], Any]:
    """Return a copy of ``linear_config`` overlaid with a partial's keywords.

    When ``linear_config["linear_type"]`` exposes a ``keywords`` mapping (as
    ``functools.partial`` objects do), every entry of that mapping is written
    into the returned dict, overriding any existing entry with the same key.

    Args:
        linear_config: Configuration dict; its ``"linear_type"`` entry may be
            a ``functools.partial`` carrying extra configuration keywords.

    Returns:
        A deep copy of ``linear_config`` with the partial's keywords merged
        in. The input dict is never modified. If ``"linear_type"`` is missing
        or has no ``keywords`` (e.g. a plain function or a string), the deep
        copy is returned as is.
    """

    linear_config_updated = copy.deepcopy(linear_config)

    # Only functools.partial-like objects carry `keywords`; any other value
    # simply has nothing to merge, rather than raising AttributeError.
    partial_keywords = getattr(
        linear_config.get("linear_type"), "keywords", None
    )
    if partial_keywords:
        linear_config_updated.update(partial_keywords)
    return linear_config_updated
210+
211+
200212
def get_int8_aiu_linear(
201213
in_features: int,
202214
out_features: int,
203215
bias: bool,
204-
linear_config: Optional[Mapping[str, Any]] = None,
216+
linear_config: dict[Union[str, Callable], Any],
217+
linear_type: Optional[str] = None,
205218
use_smoothquant: bool = False,
206219
) -> torch.nn.Module:
207220
"""Retrieve a W8A8 Linear module"""
208221

209-
int8_config = W8A8LinearConfig(**linear_config)
222+
# Preprocess linear_config if its linear_type field is a callable
223+
# (which would not correctly initialize the dataclass parameters).
224+
# We don't want to alter the original linear_config though.
225+
linear_config_for_dataclass: Optional[dict[Union[str, Callable], Any]] = None
226+
if callable(linear_config["linear_type"]):
227+
linear_config_for_dataclass = update_from_partial(linear_config)
228+
linear_config_for_dataclass["linear_type"] = linear_type
229+
if not linear_config_for_dataclass:
230+
linear_config_for_dataclass = linear_config
231+
232+
int8_config = W8A8LinearConfig(**linear_config_for_dataclass)
210233
linear = W8A8LinearAIU(
211234
in_features=in_features,
212235
out_features=out_features,
@@ -281,9 +304,20 @@ def shard_int8_aiu_linear(
281304
# return unused_keys
282305

283306

284-
register_linear_type_to_module_map("int8_aiu", get_int8_aiu_linear)
307+
register_linear_type_to_module_map(
308+
"int8_aiu",
309+
partial(
310+
get_int8_aiu_linear,
311+
linear_type="int8_aiu",
312+
use_smoothquant=False,
313+
),
314+
)
285315
register_linear_type_to_module_map(
286316
"int8_smoothquant_aiu",
287-
partial(get_int8_aiu_linear, use_smoothquant=True),
317+
partial(
318+
get_int8_aiu_linear,
319+
linear_type="int8_smoothquant_aiu",
320+
use_smoothquant=True,
321+
),
288322
)
289323
register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)

0 commit comments

Comments
 (0)