@@ -190,7 +190,10 @@ def forward(self, x):
190190
191191 def re_register_qdata (self ) -> None :
192192 """Remove existing self.qdata tensor and register it again as a buffer.
193- This method is used during TP, after a module has been """
193+ This method is used during TP, after other quantization metadata have been
194+ updated.
195+ """
196+
194197 del self .qdata
195198 self .register_buffer (
196199 "qdata" ,
@@ -206,7 +209,6 @@ def re_register_qdata(self) -> None:
206209 ),
207210 )
208211
209-
210212 def __repr__ (self ) -> str :
211213 return (
212214 f"{ self .__class__ .__name__ } "
@@ -241,7 +243,7 @@ def get_int8_aiu_linear(
241243 # Preprocess linear_config if its linear_type field is a callable
242244 # (which would not correctly initialize the dataclass parameters).
243245 # We don't want to alter the original linear_config though.
244- linear_config_for_dataclass : Optional [ dict [ Union [ str , Callable ], Any ]] = None
246+ linear_config_for_dataclass = None
245247 if callable (linear_config ["linear_type" ]):
246248 linear_config_for_dataclass = update_from_partial (linear_config )
247249 linear_config_for_dataclass ["linear_type" ] = linear_type
@@ -279,8 +281,7 @@ def is_w_clip_per_channel(
279281def is_smoothquant_enabled (
280282 smoothquant_scale : torch .Tensor ,
281283) -> bool :
282- """Determine whether smoothquant is enabled on a module.
283- """
284+ """Determine whether smoothquant is enabled on a module."""
284285
285286 if smoothquant_scale .dim () != 1 :
286287 raise ValueError (
@@ -339,8 +340,7 @@ def shard_int8_aiu_linear(
339340 # check for every linear module if smoothquant is enabled
340341 if is_smoothquant_enabled (module_info .linear_module .smoothquant_scale ):
341342 smoothquant_linear_param = LinearParameterShardingInfo (
342- 0 ,
343- ShardType .SHARD if module_info .sharding_dim == 1 else ShardType .CLONE
343+ 0 , ShardType .SHARD if module_info .sharding_dim == 1 else ShardType .CLONE
344344 )
345345 else :
346346 smoothquant_linear_param = LinearParameterShardingInfo (0 , ShardType .CLONE )
0 commit comments