1616# Standard
1717from dataclasses import dataclass
1818from functools import partial
19- from typing import Any , Callable , Mapping , Optional , Union
19+ from typing import Any , Callable , Optional , Union
2020import copy
2121
2222# Third Party
@@ -199,8 +199,8 @@ def __repr__(self) -> str:
199199
200200
201201def update_from_partial (
202- linear_config : Mapping [Union [str , Callable ], Any ],
203- ) -> Mapping [Union [str , Callable ], Any ]:
202+ linear_config : dict [Union [str , Callable ], Any ],
203+ ) -> dict [Union [str , Callable ], Any ]:
204204 """Update linear config parameters using those of partial callable"""
205205
206206 linear_config_updated = copy .deepcopy (linear_config )
@@ -213,7 +213,7 @@ def get_int8_aiu_linear(
213213 in_features : int ,
214214 out_features : int ,
215215 bias : bool ,
216- linear_config : Mapping [Union [str , Callable ], Any ],
216+ linear_config : dict [Union [str , Callable ], Any ],
217217 linear_type : Optional [str ] = None ,
218218 use_smoothquant : bool = False ,
219219) -> torch .nn .Module :
@@ -222,7 +222,7 @@ def get_int8_aiu_linear(
222222 # Preprocess linear_config if its linear_type field is a callable
223223 # (which would not initialize correctly the dataclass parameters).
224224 # We don't want to alter the original linear_config though.
225- linear_config_for_dataclass = None
225+ linear_config_for_dataclass : Optional [ dict [ Union [ str , Callable ], Any ]] = None
226226 if callable (linear_config ["linear_type" ]):
227227 linear_config_for_dataclass = update_from_partial (linear_config )
228228 linear_config_for_dataclass ["linear_type" ] = linear_type
0 commit comments