
Commit ce935e6

Enable INT8 LLM sharding for TP>1
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 50b6ce3 commit ce935e6

File tree

1 file changed: +100 −23 lines changed


fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 100 additions & 23 deletions
@@ -188,6 +188,25 @@ def forward(self, x):
             self.smoothquant,
         )
 
+    def re_register_qdata(self) -> None:
+        """Remove existing self.qdata tensor and register it again as a buffer.
+
+        This method is used during TP, after the quantization metadata tensors of a
+        module have been sharded.
+        """
+        del self.qdata
+        self.register_buffer(
+            "qdata",
+            torch.cat(
+                (
+                    self.w_clip_val,
+                    self.w_clip_valn,
+                    self.a_clip_val,
+                    self.a_clip_valn,
+                    self.zero_shift,
+                    self.smoothquant_scale,
+                )
+            ),
+        )
+
     def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}"
@@ -240,6 +259,37 @@ def get_int8_aiu_linear(
     return linear
 
 
+def is_w_clip_per_channel(
+    w_clip: torch.Tensor,
+) -> bool:
+    """Determine whether the weight clip value in use for INT8 quantization of the
+    provided linear module is:
+    - per-tensor (1 element, 1-dim tensor), or
+    - per-channel (out_feat elements, 1-dim tensor).
+    """
+
+    if w_clip.dim() != 1:
+        raise ValueError(
+            f"TP error: weight clip value dimensions {str(list(w_clip.size()))} are "
+            "incompatible with expected per-tensor or per-channel quantization."
+        )
+    return w_clip.numel() > 1
+
+
+def is_smoothquant_enabled(
+    smoothquant_scale: torch.Tensor,
+) -> bool:
+    """Determine whether smoothquant is enabled on a module."""
+
+    if smoothquant_scale.dim() != 1:
+        raise ValueError(
+            "TP error: smoothquant_scale array should always be 1-dimensional but "
+            f"has size {str(list(smoothquant_scale.size()))}"
+        )
+    return smoothquant_scale.numel() > 1
+
+
 def shard_int8_aiu_linear(
     tensor_values: dict[str, torch.Tensor],
     tp_module: TPModule,
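
As a quick illustration of the shape convention the two helpers above rely on (assuming they are importable from the module path shown at the top of this diff, and using invented feature sizes): a 1-element, 1-dim tensor means per-tensor clipping or a disabled smoothquant scale, a longer 1-dim tensor means per-channel clips or an active smoothquant scale, and anything not 1-dimensional raises a ValueError:

import torch

from fms_mo.aiu_addons.i8i8.i8i8_aiu_linear import (
    is_smoothquant_enabled,
    is_w_clip_per_channel,
)

out_features, in_features = 4096, 4096  # illustrative sizes only

print(is_w_clip_per_channel(torch.tensor([0.5])))       # False: per-tensor clip (1 element)
print(is_w_clip_per_channel(torch.ones(out_features)))  # True: per-channel clip
print(is_smoothquant_enabled(torch.ones(1)))            # False: 1-element scale, smoothquant off
print(is_smoothquant_enabled(torch.ones(in_features)))  # True: one scale per input feature
# is_w_clip_per_channel(torch.ones(2, 2))               # raises ValueError (not 1-dimensional)
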
@@ -259,49 +309,74 @@ def shard_int8_aiu_linear(
     | bias      | 0 | - |
     | others*   | N | - |
 
-    Other quantization parameters: w_clip_val, w_clip_valn,
-    a_clip_val, a_clip_valn, zero_shift, smoothquant_scale
-    No sharding on all these parameters, except w_clip_val and w_clip_valn when
-    per-channel quantization is used
+    Other quantization parameters: w_clip_val, w_clip_valn, a_clip_val, a_clip_valn,
+    zero_shift, smoothquant_scale
+
+    No sharding on any of these parameters (they are CLONED on each rank), with the
+    exception of:
+    - w_clip_val and w_clip_valn: column-sharding only, and only when per-channel
+      quantization is used
+    - smoothquant_scale: row-sharding only, and only if smoothquant is in use
+
+    These parameters are 1-dimensional, so if sharding is needed, it is always applied
+    on dim=0.
     """
+
     param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
+    w_clip_linear_param = None
     for module_name, module_info in module_sharding_info.items():
-        int8_aiu_mod = module_info.linear_module
+        int8_aiu_module = module_info.linear_module
+
+        # check for every module whether per-channel quantization is in use
+        # (sharding depends on the module)
+        if is_w_clip_per_channel(module_info.linear_module.w_clip_val):
+            w_clip_linear_param = LinearParameterShardingInfo(
+                0,
+                ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.CLONE,
+            )
+        else:
+            w_clip_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
+        # check for every linear module whether smoothquant is enabled
+        if is_smoothquant_enabled(module_info.linear_module.smoothquant_scale):
+            smoothquant_linear_param = LinearParameterShardingInfo(
+                0,
+                ShardType.SHARD if module_info.sharding_dim == 1 else ShardType.CLONE,
+            )
+        else:
+            smoothquant_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
         params: dict[str, LinearParameterShardingInfo] = {
             "weight": LinearParameterShardingInfo(
                 module_info.sharding_dim, ShardType.SHARD
             ),
-            # FIXME: with per-channel W, clips need to be sharded
-            # but if per-tensor w, there should be no sharding
-            # HOW CAN WE DISCRIMINATE THE TWO CASES?
-            "w_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "w_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
-            # "w_clip_val": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
-            # "w_clip_valn": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
+            "w_clip_val": w_clip_linear_param,
+            "w_clip_valn": w_clip_linear_param,
             "a_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
             "a_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
             "zero_shift": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "smooqthquant_scale": LinearParameterShardingInfo(0, ShardType.CLONE),
+            "smoothquant_scale": smoothquant_linear_param,
         }
-        if int8_aiu_mod.bias is not None:
+        if int8_aiu_module.bias is not None and int8_aiu_module.bias.numel() > 1:
             params["bias"] = LinearParameterShardingInfo(
-                module_info.sharding_dim,
+                0,
                 ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
             )
         param_sharding_info[module_name] = params
 
+    # trim qdata from the dictionary of tensors to be copied onto the sharded modules.
+    # if not trimmed, qdata wouldn't be copied but its keys would be marked as unused
+    tensor_values = {k: v for k, v in tensor_values.items() if "qdata" not in k}
+
     unused_keys = shard_base_linear(
         tensor_values, tp_module, module_sharding_info, param_sharding_info
    )
 
-    raise NotImplementedError("TP not yet supported for INT8. Work in progress")
-    # return unused_keys
+    # qdata contains all the quantization metadata to pass to the AIU and needs to be
+    # re-registered post-sharding, after the metadata tensors have changed
+    for module_name, module_info in module_sharding_info.items():
+        module_info.linear_module.re_register_qdata()
+
+    return unused_keys
 
 
 register_linear_type_to_module_map(
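
To make the sharding decisions above concrete, here is a small plain-PyTorch sketch (independent of the fms TPModule/ShardType machinery, with invented sizes) of what SHARD vs CLONE means for the 1-dimensional metadata of a column-parallel module (sharding_dim == 0) with per-channel weight clips and smoothquant enabled:

import torch

world_size, rank = 2, 0                  # hypothetical TP=2, viewed from rank 0
out_features, in_features = 8, 6         # invented sizes

w_clip_val = torch.arange(out_features, dtype=torch.float32)  # per-channel: one clip per output channel
smoothquant_scale = torch.ones(in_features)                   # one scale per input feature
a_clip_val = torch.tensor([6.0])                               # per-tensor activation clip


def shard_dim0(t: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    """SHARD on dim=0 in this sketch: each rank keeps a contiguous 1/world_size slice."""
    return t.chunk(world_size, dim=0)[rank]


# column-parallel (sharding_dim == 0): weight clips sharded, smoothquant scale and
# activation clips cloned in full on every rank
local_w_clip = shard_dim0(w_clip_val, world_size, rank)  # 4 of 8 elements on this rank
local_sq_scale = smoothquant_scale.clone()               # CLONE: full in_features copy
local_a_clip = a_clip_val.clone()                        # CLONE: per-tensor value

# row-parallel (sharding_dim == 1) flips the two cases: the clips are cloned and the
# smoothquant scale is the one that gets sharded along its single dimension
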
@@ -320,4 +395,6 @@ def shard_int8_aiu_linear(
         use_smoothquant=True,
     ),
 )
+
+# int8 linear with and w/o smoothquant share a common sharding map
 register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)
