Skip to content

Commit c2ee3d9

Browse files
committed
lint
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent ce935e6 commit c2ee3d9

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,10 @@ def forward(self, x):
190190

191191
def re_register_qdata(self) -> None:
192192
"""Remove existing self.qdata tensor and register it again as a buffer.
193-
This method is used during TP, after a module has been """
193+
This method is used during TP, after other quantization metadata have been
194+
updated.
195+
"""
196+
194197
del self.qdata
195198
self.register_buffer(
196199
"qdata",
@@ -206,7 +209,6 @@ def re_register_qdata(self) -> None:
206209
),
207210
)
208211

209-
210212
def __repr__(self) -> str:
211213
return (
212214
f"{self.__class__.__name__}"
@@ -241,7 +243,7 @@ def get_int8_aiu_linear(
241243
# Preprocess linear_config if its linear_type field is a callable
242244
# (which would not initialize correctly the dataclass parameters).
243245
# We don't want to alter the original linear_config though.
244-
linear_config_for_dataclass: Optional[dict[Union[str, Callable], Any]] = None
246+
linear_config_for_dataclass = None
245247
if callable(linear_config["linear_type"]):
246248
linear_config_for_dataclass = update_from_partial(linear_config)
247249
linear_config_for_dataclass["linear_type"] = linear_type
@@ -279,8 +281,7 @@ def is_w_clip_per_channel(
279281
def is_smoothquant_enabled(
280282
smoothquant_scale: torch.Tensor,
281283
) -> bool:
282-
"""Determine whether smoothquant is enabled on a module.
283-
"""
284+
"""Determine whether smoothquant is enabled on a module."""
284285

285286
if smoothquant_scale.dim() != 1:
286287
raise ValueError(
@@ -339,8 +340,7 @@ def shard_int8_aiu_linear(
339340
# check for every linear module if smoothquant is enabled
340341
if is_smoothquant_enabled(module_info.linear_module.smoothquant_scale):
341342
smoothquant_linear_param = LinearParameterShardingInfo(
342-
0,
343-
ShardType.SHARD if module_info.sharding_dim == 1 else ShardType.CLONE
343+
0, ShardType.SHARD if module_info.sharding_dim == 1 else ShardType.CLONE
344344
)
345345
else:
346346
smoothquant_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)

0 commit comments

Comments
 (0)