@@ -188,6 +188,25 @@ def forward(self, x):
             self.smoothquant,
         )
 
+    def re_register_qdata(self) -> None:
+        """Remove the existing self.qdata tensor and register it again as a buffer.
+
+        This method is used during TP, after a module has been sharded, so that
+        qdata reflects the updated quantization metadata tensors.
+        """
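+        # the buffers concatenated below may have been resharded (and thus resized),
+        # so drop the stale qdata buffer and rebuild it from the current tensors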
+        del self.qdata
+        self.register_buffer(
+            "qdata",
+            torch.cat(
+                (
+                    self.w_clip_val,
+                    self.w_clip_valn,
+                    self.a_clip_val,
+                    self.a_clip_valn,
+                    self.zero_shift,
+                    self.smoothquant_scale,
+                )
+            ),
+        )
+
+
     def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__} "
@@ -240,6 +259,37 @@ def get_int8_aiu_linear(
     return linear
 
 
+def is_w_clip_per_channel(
+    w_clip: torch.Tensor,
+) -> bool:
+    """Determine whether the weight clip value in use for INT8 quantization of the
+    provided linear module is:
+    - per-tensor (1 element, 1-dim tensor), or
+    - per-channel (out_feat elements, 1-dim tensor).
+    """
+
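+    # illustrative example (shapes assumed, not prescribed by the checkpoint format):
+    # a per-tensor clip looks like torch.tensor([4.0]) (1 element), while a
+    # per-channel clip has out_feat elements; only the latter returns True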
+    if w_clip.dim() != 1:
+        raise ValueError(
+            f"TP error: weight clip value dimensions {str(list(w_clip.size()))} are "
+            "incompatible with expected per-tensor or per-channel quantization."
+        )
+    return w_clip.numel() > 1
+
+
+def is_smoothquant_enabled(
+    smoothquant_scale: torch.Tensor,
+) -> bool:
+    """Determine whether smoothquant is enabled on a module."""
+
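+    # smoothquant is treated as enabled when the scale carries more than one element
+    # (a single-element scale is assumed here to be a disabled/placeholder value)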
+    if smoothquant_scale.dim() != 1:
+        raise ValueError(
+            "TP error: smoothquant_scale array should always be 1-dimensional but "
+            f"has size {str(list(smoothquant_scale.size()))}"
+        )
+    return smoothquant_scale.numel() > 1
+
+
 def shard_int8_aiu_linear(
     tensor_values: dict[str, torch.Tensor],
     tp_module: TPModule,
@@ -259,49 +309,74 @@ def shard_int8_aiu_linear(
     | bias    | 0 | - |
     | others* | N | - |
 
-    Other quantization parameters: w_clip_val, w_clip_valn,
-    a_clip_val, a_clip_valn, zero_shift, smoothquant_scale
-    No sharding on all these parameters, except w_clip_val and w_clip_valn when
-    per-channel quantization is used
+    Other quantization parameters: w_clip_val, w_clip_valn, a_clip_val, a_clip_valn,
+    zero_shift, smoothquant_scale
+
+    No sharding on any of these parameters (they are CLONED on each rank), with the
+    exception of:
+    - w_clip_val and w_clip_valn: only column-sharding, and only when per-channel
+      quantization is used
+    - smoothquant_scale: only row-sharding, and only if smoothquant is in use
+
+    These parameters are 1-dimensional, so if sharding is needed, it is always applied
+    on dim=0.
     """
+
     param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
+    w_clip_linear_param = None
     for module_name, module_info in module_sharding_info.items():
-        int8_aiu_mod = module_info.linear_module
+        int8_aiu_module = module_info.linear_module
+
+        # check each module for per-channel weight quantization
+        # (sharding of the weight clips depends on the module)
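+        # NOTE: per-channel clips hold one value per output feature, so they are
+        # sharded only on column-parallel modules (sharding_dim == 0), where the
+        # output features are split across ranks; otherwise they are cloned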
+        if is_w_clip_per_channel(module_info.linear_module.w_clip_val):
+            w_clip_linear_param = LinearParameterShardingInfo(
+                0,
+                ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.CLONE,
+            )
+        else:
+            w_clip_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
+        # check for every linear module if smoothquant is enabled
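+        # NOTE: smoothquant_scale follows the input-feature dimension, which
+        # row-parallel modules (sharding_dim == 1) split across ranks, hence
+        # row-sharding only; on column-parallel modules it is cloned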
+        if is_smoothquant_enabled(module_info.linear_module.smoothquant_scale):
+            smoothquant_linear_param = LinearParameterShardingInfo(
+                0,
+                ShardType.SHARD if module_info.sharding_dim == 1 else ShardType.CLONE,
+            )
+        else:
+            smoothquant_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
         params: dict[str, LinearParameterShardingInfo] = {
             "weight": LinearParameterShardingInfo(
                 module_info.sharding_dim, ShardType.SHARD
             ),
-            # FIXME: with per-channel W, clips need to be sharded
-            # but if per-tensor w, there should be no sharding
-            # HOW CAN WE DISCRIMINATE THE TWO CASES?
-            "w_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "w_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
-            # "w_clip_val": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
-            # "w_clip_valn": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
352+ "w_clip_val" : w_clip_linear_param ,
353+ "w_clip_valn" : w_clip_linear_param ,
287354 "a_clip_val" : LinearParameterShardingInfo (0 , ShardType .CLONE ),
288355 "a_clip_valn" : LinearParameterShardingInfo (0 , ShardType .CLONE ),
289356 "zero_shift" : LinearParameterShardingInfo (0 , ShardType .CLONE ),
290- "smooqthquant_scale " : LinearParameterShardingInfo ( 0 , ShardType . CLONE ) ,
357+ "smoothquant_scale " : smoothquant_linear_param ,
291358 }
-        if int8_aiu_mod.bias is not None:
+        if int8_aiu_module.bias is not None and int8_aiu_module.bias.numel() > 1:
             params["bias"] = LinearParameterShardingInfo(
-                module_info.sharding_dim,
+                0,
                 ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
             )
         param_sharding_info[module_name] = params
 
+    # trim qdata from the dict of tensors to be copied onto the sharded modules;
+    # if not trimmed, qdata would not be copied, but its keys would be marked as unused
+    tensor_values = {k: v for k, v in tensor_values.items() if "qdata" not in k}
+
     unused_keys = shard_base_linear(
         tensor_values, tp_module, module_sharding_info, param_sharding_info
     )
 
-    raise NotImplementedError("TP not yet supported for INT8. Work in progress")
-    # return unused_keys
+    # qdata contains all quantization metadata to pass to the AIU and needs to be
+    # updated post-sharding, after the metadata tensors have changed
+    for module_name, module_info in module_sharding_info.items():
+        module_info.linear_module.re_register_qdata()
+
+    return unused_keys
 
 
 register_linear_type_to_module_map(
@@ -320,4 +395,6 @@ def shard_int8_aiu_linear(
         use_smoothquant=True,
     ),
 )
+
+# int8 linear with and w/o smoothquant share a common sharding map
 register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)