@@ -188,6 +188,27 @@ def forward(self, x):
             self.smoothquant,
         )
 
+    def re_register_qdata(self) -> None:
+        """Remove existing self.qdata tensor and register it again as a buffer.
+        This method is used during TP, after other quantization metadata have been
+        updated.
+        """
+
+        del self.qdata
+        self.register_buffer(
+            "qdata",
+            torch.cat(
+                (
+                    self.w_clip_val,
+                    self.w_clip_valn,
+                    self.a_clip_val,
+                    self.a_clip_valn,
+                    self.zero_shift,
+                    self.smoothquant_scale,
+                )
+            ),
+        )
+
     def __repr__(self) -> str:
         return (
            f"{self.__class__.__name__}"
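The hunk above works because qdata is just the concatenation of the six quantization metadata tensors, so once TP resizes any of them the previously registered buffer is stale. A minimal standalone sketch of that invariant, using plain torch and made-up sizes rather than the actual AIU module:

# Hedged sketch (not the fms module): why qdata must be rebuilt after sharding.
import torch

out_feat, world_size = 8, 2
w_clip_val = torch.rand(out_feat)       # per-channel weight clip (one value per row)
w_clip_valn = -torch.rand(out_feat)
a_clip_val = torch.rand(1)              # per-tensor activation clips
a_clip_valn = -torch.rand(1)
zero_shift = torch.zeros(1)
smoothquant_scale = torch.ones(1)       # placeholder scale, smoothquant off

def pack_qdata():
    # same concatenation order as the register_buffer call above
    return torch.cat(
        (w_clip_val, w_clip_valn, a_clip_val, a_clip_valn, zero_shift, smoothquant_scale)
    )

qdata = pack_qdata()
assert qdata.numel() == 2 * out_feat + 4

# Column-shard the per-channel clips for rank 0: their size halves, so the old
# qdata no longer matches and must be re-concatenated (the role of re_register_qdata).
w_clip_val = w_clip_val[: out_feat // world_size]
w_clip_valn = w_clip_valn[: out_feat // world_size]
qdata = pack_qdata()
assert qdata.numel() == 2 * (out_feat // world_size) + 4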
@@ -222,7 +243,7 @@ def get_int8_aiu_linear(
     # Preprocess linear_config if its linear_type field is a callable
     # (which would not initialize correctly the dataclass parameters).
     # We don't want to alter the original linear_config though.
-    linear_config_for_dataclass: Optional[dict[Union[str, Callable], Any]] = None
+    linear_config_for_dataclass = None
     if callable(linear_config["linear_type"]):
         linear_config_for_dataclass = update_from_partial(linear_config)
         linear_config_for_dataclass["linear_type"] = linear_type
@@ -240,6 +261,36 @@ def get_int8_aiu_linear(
     return linear
 
 
+def is_w_clip_per_channel(
+    w_clip: torch.Tensor,
+) -> bool:
+    """Determine whether the weight clip value used for INT8 quantization of a
+    linear module is:
+    - per-tensor (1 element, 1-dim tensor), or
+    - per-channel (out_feat elements, 1-dim tensor).
+    """
+
+    if w_clip.dim() != 1:
+        raise ValueError(
+            f"TP error: weight clip value dimensions {str(list(w_clip.size()))} are "
+            "incompatible with expected per-tensor or per-channel quantization."
+        )
+    return w_clip.numel() > 1
+
+
+def is_smoothquant_enabled(
+    smoothquant_scale: torch.Tensor,
+) -> bool:
+    """Determine whether smoothquant is enabled on a module."""
+
+    if smoothquant_scale.dim() != 1:
+        raise ValueError(
+            "TP error: smoothquant_scale array should always be 1-dimensional but "
+            f"has size {str(list(smoothquant_scale.size()))}"
+        )
+    return smoothquant_scale.numel() > 1
+
+
 def shard_int8_aiu_linear(
     tensor_values: dict[str, torch.Tensor],
     tp_module: TPModule,
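Both helpers reduce to a shape check on a 1-D tensor: a single element means per-tensor weight quantization, or a placeholder scale when smoothquant is off, while more elements mean per-channel clips or an active smoothquant scale. A quick usage sketch, assuming the two functions above are in scope and using invented sizes:

# Hedged usage sketch for the shape conventions the helpers rely on.
import torch

per_tensor_clip = torch.tensor([3.0])     # 1 element -> per-tensor weight clip
per_channel_clip = torch.rand(4096)       # out_feat elements -> per-channel clips
assert not is_w_clip_per_channel(per_tensor_clip)
assert is_w_clip_per_channel(per_channel_clip)

sq_off = torch.ones(1)                    # single-element scale -> smoothquant disabled
sq_on = torch.rand(4096)                  # one scale per input channel -> enabled
assert not is_smoothquant_enabled(sq_off)
assert is_smoothquant_enabled(sq_on)

# anything other than a 1-D tensor is rejected, e.g. a stray [out_feat, 1] clip
try:
    is_w_clip_per_channel(torch.rand(4096, 1))
except ValueError as err:
    print(err)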
@@ -259,49 +310,73 @@ def shard_int8_aiu_linear(
     | bias     | 0 | - |
     | others*  | N | - |
 
-    Other quantization parameters: w_clip_val, w_clip_valn,
-    a_clip_val, a_clip_valn, zero_shift, smoothquant_scale
-    No sharding on all these parameters, except w_clip_val and w_clip_valn when
-    per-channel quantization is used
+    Other quantization parameters: w_clip_val, w_clip_valn, a_clip_val, a_clip_valn,
+    zero_shift, smoothquant_scale
+
+    No sharding on any of these parameters (they are CLONED on each rank), with the
+    exception of:
+    - w_clip_val and w_clip_valn, only column-sharding and only when per-channel
+      quantization is used
+    - smoothquant_scale, only row-sharding and only if smoothquant is in use
+
+    These parameters are 1-dimensional, so if sharding is needed, it is always applied
+    on dim=0.
     """
+
     param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
+    w_clip_linear_param = None
     for module_name, module_info in module_sharding_info.items():
-        int8_aiu_mod = module_info.linear_module
+        int8_aiu_module = module_info.linear_module
+
+        # check every module if per-channel in use (sharding depends on module)
+        if is_w_clip_per_channel(module_info.linear_module.w_clip_val):
+            w_clip_linear_param = LinearParameterShardingInfo(
+                0,
+                ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.CLONE,
+            )
+        else:
+            w_clip_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
+        # check for every linear module if smoothquant is enabled
+        if is_smoothquant_enabled(module_info.linear_module.smoothquant_scale):
+            smoothquant_linear_param = LinearParameterShardingInfo(
+                0, ShardType.SHARD if module_info.sharding_dim == 1 else ShardType.CLONE
+            )
+        else:
+            smoothquant_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
         params: dict[str, LinearParameterShardingInfo] = {
             "weight": LinearParameterShardingInfo(
                 module_info.sharding_dim, ShardType.SHARD
             ),
-            # FIXME: with per-channel W, clips need to be sharded
-            # but if per-tensor w, there should be no sharding
-            # HOW CAN WE DISCRIMINATE THE TWO CASES?
-            "w_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "w_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
-            # "w_clip_val": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
-            # "w_clip_valn": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
+            "w_clip_val": w_clip_linear_param,
+            "w_clip_valn": w_clip_linear_param,
             "a_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
             "a_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
             "zero_shift": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "smooqthquant_scale": LinearParameterShardingInfo(0, ShardType.CLONE),
+            "smoothquant_scale": smoothquant_linear_param,
         }
-        if int8_aiu_mod.bias is not None:
+        if int8_aiu_module.bias is not None and int8_aiu_module.bias.numel() > 1:
             params["bias"] = LinearParameterShardingInfo(
-                module_info.sharding_dim,
+                0,
                 ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
             )
         param_sharding_info[module_name] = params
 
+    # trim qdata from the dictionary of tensors to be copied onto sharded modules.
+    # if not trimmed, qdata wouldn't be copied but its keys would be marked as unused
+    tensor_values = {k: v for k, v in tensor_values.items() if "qdata" not in k}
+
     unused_keys = shard_base_linear(
         tensor_values, tp_module, module_sharding_info, param_sharding_info
     )
 
-    raise NotImplementedError("TP not yet supported for INT8. Work in progress")
-    # return unused_keys
+    # qdata contains all quantization metadata to pass to the AIU and needs to be
+    # updated post-sharding, after the metadata tensors have changed
+    for module_name, module_info in module_sharding_info.items():
+        module_info.linear_module.re_register_qdata()
+
+    return unused_keys
 
 
 register_linear_type_to_module_map(
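To make the docstring rules concrete: when the weight is sharded on dim 0 (output features), per-channel w_clip_val and w_clip_valn follow that split; when it is sharded on dim 1 (input features), smoothquant_scale, with one entry per input channel, is split instead and the clips are cloned. A rough pure-torch illustration with hypothetical sizes, not using the fms TP helpers:

# Hedged sketch of the metadata sharding rules, outside the fms TP machinery.
import torch

out_feat, in_feat, world, rank = 8, 6, 2, 0
weight = torch.randn(out_feat, in_feat)
w_clip_val = torch.rand(out_feat)          # per-channel: one clip per output row
smoothquant_scale = torch.rand(in_feat)    # one scale per input column

def shard(t: torch.Tensor, dim: int) -> torch.Tensor:
    # keep this rank's contiguous slice along `dim`
    return t.chunk(world, dim=dim)[rank]

# sharding_dim == 0 case: weight split on dim 0 -> clips sharded, scale cloned
w_col = shard(weight, 0)
clip_col = shard(w_clip_val, 0)
assert clip_col.numel() == w_col.shape[0]
scale_col = smoothquant_scale.clone()       # CLONE: every rank keeps the full scale

# sharding_dim == 1 case: weight split on dim 1 -> scale sharded, clips cloned
w_row = shard(weight, 1)
scale_row = shard(smoothquant_scale, 0)     # scale is 1-D, so the split is on dim=0
assert scale_row.numel() == w_row.shape[1]
clip_row = w_clip_val.clone()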
@@ -320,4 +395,6 @@ def shard_int8_aiu_linear(
         use_smoothquant=True,
     ),
 )
+
+# int8 linear with and w/o smoothquant share a common sharding map
 register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)