Skip to content

Commit 69e34a6

Browse files
committed
minor type fixes
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 4f18a1a commit 69e34a6

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _int8_qparams_aiu(
4747

4848

4949
def _add_defaults_and_concat(
50-
new_sd: Mapping[str, torch.Tensor],
50+
new_sd: dict[str, torch.Tensor],
5151
modules_seen: set,
5252
) -> None:
5353
"""

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# Standard
1717
from dataclasses import dataclass
1818
from functools import partial
19-
from typing import Any, Callable, Mapping, Optional, Union
19+
from typing import Any, Callable, Optional, Union
2020
import copy
2121

2222
# Third Party
@@ -199,8 +199,8 @@ def __repr__(self) -> str:
199199

200200

201201
def update_from_partial(
202-
linear_config: Mapping[Union[str, Callable], Any],
203-
) -> Mapping[Union[str, Callable], Any]:
202+
linear_config: dict[Union[str, Callable], Any],
203+
) -> dict[Union[str, Callable], Any]:
204204
"""Update linear config parameters using those of partial callable"""
205205

206206
linear_config_updated = copy.deepcopy(linear_config)
@@ -213,7 +213,7 @@ def get_int8_aiu_linear(
213213
in_features: int,
214214
out_features: int,
215215
bias: bool,
216-
linear_config: Mapping[Union[str, Callable], Any],
216+
linear_config: dict[Union[str, Callable], Any],
217217
linear_type: Optional[str] = None,
218218
use_smoothquant: bool = False,
219219
) -> torch.nn.Module:
@@ -222,7 +222,7 @@ def get_int8_aiu_linear(
222222
# Preprocess linear_config if its linear_type field is a callable
223223
# (which would not initialize correctly the dataclass parameters).
224224
# We don't want to alter the original linear_config though.
225-
linear_config_for_dataclass = None
225+
linear_config_for_dataclass: Optional[dict[Union[str, Callable], Any]] = None
226226
if callable(linear_config["linear_type"]):
227227
linear_config_for_dataclass = update_from_partial(linear_config)
228228
linear_config_for_dataclass["linear_type"] = linear_type

0 commit comments

Comments
 (0)