
Commit 2c75fcc

Merge pull request #94 from andrea-fasoli/int8_llm_tp
feat: INT8 LLM TP>1 enablement
2 parents 50b6ce3 + c2ee3d9 commit 2c75fcc

File tree

1 file changed (+101, -24 lines)


fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 101 additions & 24 deletions
@@ -188,6 +188,27 @@ def forward(self, x):
             self.smoothquant,
         )
 
+    def re_register_qdata(self) -> None:
+        """Remove existing self.qdata tensor and register it again as a buffer.
+        This method is used during TP, after other quantization metadata have been
+        updated.
+        """
+
+        del self.qdata
+        self.register_buffer(
+            "qdata",
+            torch.cat(
+                (
+                    self.w_clip_val,
+                    self.w_clip_valn,
+                    self.a_clip_val,
+                    self.a_clip_valn,
+                    self.zero_shift,
+                    self.smoothquant_scale,
+                )
+            ),
+        )
+
     def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}"
@@ -222,7 +243,7 @@ def get_int8_aiu_linear(
     # Preprocess linear_config if its linear_type field is a callable
     # (which would not initialize correctly the dataclass parameters).
     # We don't want to alter the original linear_config though.
-    linear_config_for_dataclass: Optional[dict[Union[str, Callable], Any]] = None
+    linear_config_for_dataclass = None
     if callable(linear_config["linear_type"]):
         linear_config_for_dataclass = update_from_partial(linear_config)
         linear_config_for_dataclass["linear_type"] = linear_type
@@ -240,6 +261,36 @@ def get_int8_aiu_linear(
     return linear
 
 
+def is_w_clip_per_channel(
+    w_clip: torch.Tensor,
+) -> bool:
+    """Determine whether the weight clip value in use for INT8 quantization of the
+    provided linear module is:
+    - per-tensor (1 element, 1-dim tensor), or
+    - per-channel (out_feat elements, 1-dim tensor).
+    """
+
+    if w_clip.dim() != 1:
+        raise ValueError(
+            f"TP error: weight clip value dimensions {str(list(w_clip.size()))} are "
+            "incompatible with expected per-tensor or per-channel quantization."
+        )
+    return w_clip.numel() > 1
+
+
+def is_smoothquant_enabled(
+    smoothquant_scale: torch.Tensor,
+) -> bool:
+    """Determine whether smoothquant is enabled on a module."""
+
+    if smoothquant_scale.dim() != 1:
+        raise ValueError(
+            "TP error: smoothquant_scale array should always be 1-dimensional but "
+            f"has size {str(list(smoothquant_scale.size()))}"
+        )
+    return smoothquant_scale.numel() > 1
+
+
 def shard_int8_aiu_linear(
     tensor_values: dict[str, torch.Tensor],
     tp_module: TPModule,
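For clarity (not from the PR), here is how the two helpers above classify their 1-dim inputs: a single element means per-tensor quantization or disabled smoothquant, more than one element means per-channel quantization or active smoothquant, and any non-1-dim tensor is rejected:

import torch

# helpers added in this diff; module path taken from the file header above
from fms_mo.aiu_addons.i8i8.i8i8_aiu_linear import (
    is_smoothquant_enabled,
    is_w_clip_per_channel,
)

per_tensor_clip = torch.tensor([4.0])  # single element -> per-tensor weight clip
per_channel_clip = torch.rand(768)     # out_feat elements -> per-channel weight clip
smoothquant_off = torch.zeros(1)       # single element -> smoothquant disabled
smoothquant_on = torch.rand(1024)      # multiple elements -> smoothquant enabled

assert not is_w_clip_per_channel(per_tensor_clip)
assert is_w_clip_per_channel(per_channel_clip)
assert not is_smoothquant_enabled(smoothquant_off)
assert is_smoothquant_enabled(smoothquant_on)

# a 2-dim tensor is rejected with a ValueError:
# is_w_clip_per_channel(torch.rand(4, 4))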
@@ -259,49 +310,73 @@ def shard_int8_aiu_linear(
     | bias | 0 | - |
     | others* | N | - |
 
-    Other quantization parameters: w_clip_val, w_clip_valn,
-    a_clip_val, a_clip_valn, zero_shift, smoothquant_scale
-    No sharding on all these parameters, except w_clip_val and w_clip_valn when
-    per-channel quantization is used
+    Other quantization parameters: w_clip_val, w_clip_valn, a_clip_val, a_clip_valn,
+    zero_shift, smoothquant_scale
+
+    No sharding on any of these parameters (they are CLONED on each rank), with the
+    exception of:
+    - w_clip_val and w_clip_valn, only column-sharding and only when per-channel
+      quantization is used
+    - smoothquant_scale, only row-sharding and only if smoothquant in use
+
+    These parameters are 1-dimensional, so if sharding is needed, it is always applied
+    on dim=0.
     """
+
     param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
+    w_clip_linear_param = None
     for module_name, module_info in module_sharding_info.items():
-        int8_aiu_mod = module_info.linear_module
+        int8_aiu_module = module_info.linear_module
+
+        # check every module if per-channel in use (sharding depends on module)
+        if is_w_clip_per_channel(module_info.linear_module.w_clip_val):
+            w_clip_linear_param = LinearParameterShardingInfo(
+                0,
+                ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.CLONE,
+            )
+        else:
+            w_clip_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
+        # check for every linear module if smoothquant is enabled
+        if is_smoothquant_enabled(module_info.linear_module.smoothquant_scale):
+            smoothquant_linear_param = LinearParameterShardingInfo(
+                0, ShardType.SHARD if module_info.sharding_dim == 1 else ShardType.CLONE
+            )
+        else:
+            smoothquant_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
         params: dict[str, LinearParameterShardingInfo] = {
             "weight": LinearParameterShardingInfo(
                 module_info.sharding_dim, ShardType.SHARD
             ),
-            # FIXME: with per-channel W, clips need to be sharded
-            # but if per-tensor w, there should be no sharding
-            # HOW CAN WE DISCRIMINATE THE TWO CASES?
-            "w_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "w_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
-            # "w_clip_val": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
-            # "w_clip_valn": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
+            "w_clip_val": w_clip_linear_param,
+            "w_clip_valn": w_clip_linear_param,
             "a_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
             "a_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
             "zero_shift": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "smooqthquant_scale": LinearParameterShardingInfo(0, ShardType.CLONE),
+            "smoothquant_scale": smoothquant_linear_param,
         }
-        if int8_aiu_mod.bias is not None:
+        if int8_aiu_module.bias is not None and int8_aiu_module.bias.numel() > 1:
             params["bias"] = LinearParameterShardingInfo(
-                module_info.sharding_dim,
+                0,
                 ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
             )
         param_sharding_info[module_name] = params
 
+    # trim qdata from dictionary of tensors to be copied on sharded modules.
+    # if not trimmed, qdata wouldn't be copied but the keys would be marked as unused
+    tensor_values = {k: v for k, v in tensor_values.items() if "qdata" not in k}
+
     unused_keys = shard_base_linear(
         tensor_values, tp_module, module_sharding_info, param_sharding_info
     )
 
-    raise NotImplementedError("TP not yet supported for INT8. Work in progress")
-    # return unused_keys
+    # qdata contains all quantization metadata to pass to the AIU and needs to be
+    # updated post-sharding, after metadata tensor have changed
+    for module_name, module_info in module_sharding_info.items():
+        module_info.linear_module.re_register_qdata()
+
+    return unused_keys
 
 
 register_linear_type_to_module_map(
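To summarize the policy encoded in the loop above: weight clips are sharded only for column-parallel modules (sharding_dim == 0) with per-channel quantization, smoothquant_scale is sharded only for row-parallel modules (sharding_dim == 1) with smoothquant active, all other metadata is cloned, and the 1-dim tensors are always sharded on dim=0. A compact, illustration-only restatement using plain strings in place of the library's ShardType / LinearParameterShardingInfo types:

def int8_metadata_policy(
    sharding_dim: int, per_channel_w: bool, smoothquant_on: bool
) -> dict[str, str]:
    # all metadata tensors are 1-dim, so any sharding happens on dim=0
    w_clip = "SHARD" if per_channel_w and sharding_dim == 0 else "CLONE"
    sq_scale = "SHARD" if smoothquant_on and sharding_dim == 1 else "CLONE"
    return {
        "w_clip_val": w_clip,
        "w_clip_valn": w_clip,
        "a_clip_val": "CLONE",
        "a_clip_valn": "CLONE",
        "zero_shift": "CLONE",
        "smoothquant_scale": sq_scale,
    }


# column-parallel module with per-channel weight quantization, no smoothquant
print(int8_metadata_policy(0, per_channel_w=True, smoothquant_on=False))
# {'w_clip_val': 'SHARD', 'w_clip_valn': 'SHARD', 'a_clip_val': 'CLONE', ...}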
@@ -320,4 +395,6 @@ def shard_int8_aiu_linear(
         use_smoothquant=True,
     ),
 )
+
+# int8 linear with and w/o smoothquant share a common sharding map
 register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)
