@@ -16,7 +16,8 @@
 # Standard
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Mapping, Optional
+from typing import Any, Callable, Optional, Union
+import copy

 # Third Party
 from fms.modules.linear import (
@@ -197,16 +198,38 @@ def __repr__(self) -> str:
     )


+def update_from_partial(
+    linear_config: dict[Union[str, Callable], Any],
+) -> dict[Union[str, Callable], Any]:
+    """Update linear config parameters using those of the partial callable"""
+
+    linear_config_updated = copy.deepcopy(linear_config)
+    for k, v in linear_config["linear_type"].keywords.items():
+        linear_config_updated[k] = v
+    return linear_config_updated
+
+
 def get_int8_aiu_linear(
     in_features: int,
     out_features: int,
     bias: bool,
-    linear_config: Optional[Mapping[str, Any]] = None,
+    linear_config: dict[Union[str, Callable], Any],
+    linear_type: Optional[str] = None,
     use_smoothquant: bool = False,
 ) -> torch.nn.Module:
     """Retrieve a W8A8 Linear module"""

-    int8_config = W8A8LinearConfig(**linear_config)
+    # Preprocess linear_config if its linear_type field is a callable
+    # (which would not correctly initialize the dataclass parameters).
+    # We do not want to alter the original linear_config, though.
+    linear_config_for_dataclass: Optional[dict[Union[str, Callable], Any]] = None
+    if callable(linear_config["linear_type"]):
+        linear_config_for_dataclass = update_from_partial(linear_config)
+        linear_config_for_dataclass["linear_type"] = linear_type
+    if not linear_config_for_dataclass:
+        linear_config_for_dataclass = linear_config
+
+    int8_config = W8A8LinearConfig(**linear_config_for_dataclass)
     linear = W8A8LinearAIU(
         in_features=in_features,
         out_features=out_features,
@@ -281,9 +304,20 @@ def shard_int8_aiu_linear(
     # return unused_keys


-register_linear_type_to_module_map("int8_aiu", get_int8_aiu_linear)
+register_linear_type_to_module_map(
+    "int8_aiu",
+    partial(
+        get_int8_aiu_linear,
+        linear_type="int8_aiu",
+        use_smoothquant=False,
+    ),
+)
 register_linear_type_to_module_map(
     "int8_smoothquant_aiu",
-    partial(get_int8_aiu_linear, use_smoothquant=True),
+    partial(
+        get_int8_aiu_linear,
+        linear_type="int8_smoothquant_aiu",
+        use_smoothquant=True,
+    ),
 )
 register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)
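
For reviewers, a minimal, self-contained sketch of the resolution step this commit introduces. It restates the `update_from_partial` logic from the diff against a toy config to show how the keywords baked into a registered `partial` are flattened into the config dict before `W8A8LinearConfig` is built. `toy_linear` and the `weight_per_channel` key are invented placeholders for illustration, not part of the fms API.

import copy
from functools import partial


def toy_linear(*args, linear_type=None, use_smoothquant=False, **kwargs):
    """Hypothetical stand-in for get_int8_aiu_linear; never called here."""
    return linear_type, use_smoothquant


def update_from_partial(linear_config: dict) -> dict:
    # Same logic as the new helper in this commit: copy the config, then
    # overlay the keywords stored on the partial bound to "linear_type".
    updated = copy.deepcopy(linear_config)
    for k, v in linear_config["linear_type"].keywords.items():
        updated[k] = v
    return updated


# A registry entry shaped like the ones at the bottom of this diff:
# the module factory is a partial with linear_type/use_smoothquant pre-bound.
factory = partial(toy_linear, linear_type="int8_aiu", use_smoothquant=False)

# A config whose "linear_type" field holds that callable -- the case the new
# preprocessing in get_int8_aiu_linear guards against.
linear_config = {"linear_type": factory, "weight_per_channel": True}

resolved = update_from_partial(linear_config)
resolved["linear_type"] = "int8_aiu"  # swap the callable for its string name

print(resolved)
# {'linear_type': 'int8_aiu', 'weight_per_channel': True, 'use_smoothquant': False}
print(callable(linear_config["linear_type"]))
# True -- the original config still holds the partial, untouched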