Commit f8b7a76

Merge branch 'foundation-model-stack:main' into main
2 parents 45c8ded + 4705c75

File tree: 8 files changed (+291, -143 lines)


.github/workflows/labelpr.yaml

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@ jobs:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           script: |
             // https://github.com/commitizen/conventional-commit-types
-            const valid_pr_types = ['feat', 'fix', 'docs', 'style', 'refactor', 'perf', 'test', 'build', 'ci', 'chore', 'revert'];
+            const valid_pr_types = ['feat', 'fix', 'docs', 'style', 'refactor', 'perf', 'test', 'build', 'ci', 'chore', 'revert', 'dependencies'];


             const title = context.payload.pull_request.title;
@@ -28,4 +28,4 @@ jobs:
             const labels = context.payload.pull_request.labels;
             const new_labels = labels.filter(label => !valid_pr_types.includes(label.name)); // keep all labels that are not in valid_pr_types
             new_labels.push({name: pr_type});
-            await github.rest.issues.update({ ...context.repo, issue_number: context.payload.number, labels: new_labels });
+            await github.rest.issues.update({ ...context.repo, issue_number: context.payload.number, labels: new_labels });
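The second hunk relies on a `pr_type` value parsed from the PR title earlier in the script, outside the visible context. As a rough illustration of the conventional-commit convention being enforced, here is a hypothetical Python equivalent of that title check; the real check runs as JavaScript inside `actions/github-script`, and the regex and function name below are assumptions, not taken from the workflow:

```python
import re
from typing import Optional

# Hypothetical Python rendering of the PR-title check; the workflow itself runs
# equivalent JavaScript inside actions/github-script. Regex and names are assumptions.
VALID_PR_TYPES = [
    "feat", "fix", "docs", "style", "refactor", "perf",
    "test", "build", "ci", "chore", "revert", "dependencies",
]


def extract_pr_type(title: str) -> Optional[str]:
    """Return the conventional-commit type prefix of a PR title, or None if invalid."""
    # Accepts e.g. "fix: ...", "feat(scope): ...", "chore!: ..."
    match = re.match(r"^(\w+)(\([^)]*\))?!?:", title)
    if match and match.group(1) in VALID_PR_TYPES:
        return match.group(1)
    return None


print(extract_pr_type("fix(aiu): correct smoothquant_scale key name"))  # fix
print(extract_pr_type("update readme"))                                 # None
```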

.github/workflows/pypi.yml

Lines changed: 3 additions & 3 deletions

@@ -44,7 +44,7 @@ jobs:
           # for setuptools-scm
           fetch-depth: 0

-      - uses: hynek/build-and-inspect-python-package@v2
+      - uses: hynek/build-and-inspect-python-package@b5076c307dc91924a82ad150cdd1533b444d3310 # v2.12.0

   # push to Test PyPI on
   # - a new GitHub release is published
@@ -77,7 +77,7 @@ jobs:
           path: dist

       - name: Upload to Test PyPI
-        uses: pypa/gh-action-pypi-publish@v12.2.4
+        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
         with:
           repository-url: https://test.pypi.org/legacy/

@@ -122,4 +122,4 @@ jobs:
         run: rm ./dist/*.sigstore.json

       - name: Upload to PyPI
-        uses: pypa/gh-action-pypi-publish@v12.2.4
+        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
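Both workflow edits replace floating version tags with full commit SHAs plus a human-readable version comment, so the action that runs is immutable even if the tag is later moved. The sketch below is a hypothetical stand-alone checker, not part of this repository, for spotting workflow steps that are still unpinned:

```python
import re
from pathlib import Path

# Hypothetical stand-alone checker (not part of this repository): list `uses:` entries
# in workflow files that are not pinned to a full 40-character commit SHA, the practice
# this commit adopts. Local and docker:// references are also flagged; good enough for
# a quick audit.
PINNED = re.compile(r"uses:\s*[\w.-]+/[\w.-]+@[0-9a-f]{40}\b")
ANY_USES = re.compile(r"uses:\s*(\S+)")


def unpinned_actions(workflow_dir: str = ".github/workflows") -> list[str]:
    findings = []
    for path in Path(workflow_dir).glob("*.y*ml"):
        for lineno, line in enumerate(path.read_text().splitlines(), start=1):
            used = ANY_USES.search(line)
            if used and not PINNED.search(line):
                findings.append(f"{path}:{lineno}: {used.group(1)}")
    return findings


if __name__ == "__main__":
    print("\n".join(unpinned_actions()))
```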

fms_mo/aiu_addons/gptq/gptq_aiu_op.py

Lines changed: 59 additions & 10 deletions

@@ -17,6 +17,7 @@
 import logging

 # Third Party
+from packaging.version import Version
 import torch

 # pylint: disable=unused-argument
@@ -25,6 +26,36 @@
 logger = logging.getLogger(__name__)


+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+
+
 def register_aiu_gptq_op():
     """Register AIU-specific op to enable torch compile without graph break.
     The op preserves I/O shapes of a `X @ W^T` matmul but performs no operation.
@@ -36,17 +67,33 @@ def register_aiu_gptq_op():
     ):
         logger.warning("AIU op has already been registered")
         return
-
     op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor qw, Tensor qzeros, Tensor scales, Tensor g_idx) -> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor qw, Tensor qzeros, "
+            "Tensor scales, Tensor g_idx) -> Tensor",
+        )

     # Add implementations for the operator
-    @torch.library.impl(op_namespace_id, "default")
-    def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
-        # on AIU, GPTQ qw is [out_feat, in_feat]
+    @implement_op_decorator(op_namespace_id)
+    def i4f16_fxinputs_aiu(
+        x: torch.Tensor,
+        qw: torch.Tensor,
+        qzeros: torch.Tensor,
+        scales: torch.Tensor,
+        g_idx: torch.Tensor,
+    ) -> torch.Tensor:
+        """Implement fake processing of GPTQ W4A16 matmul. The purpose is to create a
+        node on the computational graph to be captured during compiling for AIU.
+
+        Instead of computing the weight decompression and matmul, this function returns
+        a zero tensor with the expected shape.
+
+        NOTE: on AIU, GPTQ qw is [out_feat, in_feat], while AutoGPTQ saves the quantized
+        weights as [in_feat, out_feat]
+        """
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         x = x.view(-1, x.shape[-1])
         output = torch.zeros(
@@ -56,8 +103,10 @@ def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
         )
         return output.view(outshape)

-    @torch.library.impl_abstract(op_namespace_id)
-    def i4f16_fxinputs_aiu_abstract(x, qw, qzeros, scales, g_idx):
+    @register_op_decorator(op_namespace_id)
+    def _(x, qw, qzeros, scales, g_idx):
+        """OP template of I/O sizes"""
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         return torch.empty(
             outshape,
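The registration path is now gated on the installed PyTorch version: before 2.4 the op is declared with `torch.library.define` and wired up via `torch.library.impl` / `impl_abstract`, while 2.4+ uses `torch.library.custom_op` and `torch.library.register_fake`, which infer the schema from the added type annotations. Either way the op lands under `torch.ops.gptq_gemm.i4f16_fxinputs_aiu`. A minimal usage sketch follows; the tensor shapes and dtypes are illustrative assumptions, not AIU-validated values:

```python
# Minimal usage sketch; shapes and dtypes are illustrative assumptions, not
# AIU-validated values. The op performs no real computation and returns zeros
# with the shape x.shape[:-1] + (qw.shape[0],).
import torch

from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

register_aiu_gptq_op()

bs, seq, in_feat, out_feat = 2, 8, 128, 256
x = torch.randn(bs, seq, in_feat, dtype=torch.float16)
qw = torch.zeros(out_feat, in_feat, dtype=torch.int32)  # [out_feat, in_feat] layout (AIU convention)
qzeros = torch.zeros(out_feat, dtype=torch.int32)       # placeholder quantization metadata
scales = torch.ones(out_feat, dtype=torch.float16)
g_idx = torch.zeros(in_feat, dtype=torch.int32)

out = torch.ops.gptq_gemm.i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx)
print(out.shape)  # torch.Size([2, 8, 256])
```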

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 17 additions & 36 deletions

@@ -97,41 +97,22 @@ def _add_defaults_and_concat(
     )


-# registration of new adapter steps for each architecture
-serialization.register_adapter_step("llama", "int8_qparams_aiu", _int8_qparams_aiu)
-serialization.register_adapter_step(
-    "gpt_bigcode", "int8_qparams_aiu", _int8_qparams_aiu
-)
-serialization.register_adapter_step("roberta", "int8_qparams_aiu", _int8_qparams_aiu)
-serialization.register_adapter_step(
-    "roberta_question_answering",
-    "int8_qparams_aiu",
-    _int8_qparams_aiu,
-)
-
-# registration of multi-step adapter for each architecture
-serialization.register_adapter(
+# registration of new adapter step and adapter for each architecture
+for arch in [
     "llama",
-    "fms_mo",
-    [
-        "hf_to_fms_names",
-        "hf_to_fms_rope",
-        "weight_fusion",
-        "int8_qparams_aiu",
-    ],
-)
-serialization.register_adapter(
-    "gpt_bigcode", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
-)
-serialization.register_adapter(
-    "roberta", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
-)
-serialization.register_adapter(
+    "gpt_bigcode",
+    "granite",
+    "roberta",
     "roberta_question_answering",
-    "fms_mo",
-    [
-        "hf_to_fms_names",
-        "weight_fusion",
-        "int8_qparams_aiu",
-    ],
-)
+]:
+    serialization.register_adapter_step(arch, "int8_qparams_aiu", _int8_qparams_aiu)
+    if arch in ["llama", "granite"]:
+        steps_to_register = [
+            "hf_to_fms_names",
+            "hf_to_fms_rope",
+            "weight_fusion",
+            "int8_qparams_aiu",
+        ]
+    else:
+        steps_to_register = ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
+    serialization.register_adapter(arch, "fms_mo", steps_to_register)
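The refactor folds eight near-identical registration calls into one loop and adds the `granite` architecture; `hf_to_fms_rope` is included only for `llama` and `granite`, presumably because only those architectures use rotary position embeddings. A dict-driven rendering of the same mapping, illustrative only since the module keeps the loop form:

```python
# Dict-driven rendering of the same registration mapping (illustrative only; the
# module keeps the for-loop form above, and `serialization` / `_int8_qparams_aiu`
# are its own objects).
ADAPTER_STEPS: dict[str, list[str]] = {
    arch: (
        ["hf_to_fms_names", "hf_to_fms_rope", "weight_fusion", "int8_qparams_aiu"]
        if arch in ("llama", "granite")
        else ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
    )
    for arch in (
        "llama",
        "gpt_bigcode",
        "granite",
        "roberta",
        "roberta_question_answering",
    )
}

# for arch, steps in ADAPTER_STEPS.items():
#     serialization.register_adapter_step(arch, "int8_qparams_aiu", _int8_qparams_aiu)
#     serialization.register_adapter(arch, "fms_mo", steps)
```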

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 106 additions & 29 deletions

@@ -84,10 +84,10 @@ def __init__(
             "weight",
             torch.zeros(out_features, in_features, dtype=torch.int8),
         )
-        if bias:
-            self.register_buffer(
-                "bias", torch.zeros((out_features), dtype=torch.float16)
-            )
+
+        self.has_bias = bias
+        bias_size = out_features if self.has_bias else 1
+        self.register_buffer("bias", torch.zeros((bias_size), dtype=torch.float16))

         if config.weight_per_channel:
             w_clip_size = out_features
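The constructor now always registers a `bias` buffer, using a 1-element placeholder when the layer is configured without bias, and records the real setting in `self.has_bias`. A plausible motivation, inferred rather than stated in the diff, is keeping the buffer set and `state_dict` keys identical across configurations, with the placeholder recognizable by its single element (see the `bias.numel() > 1` check in the sharding code further below). A toy sketch of that behavior:

```python
import torch


# Toy module, not the fms-mo class: the 1-element placeholder keeps the "bias" buffer
# (and its state_dict key) present whether or not the layer actually has a bias, while
# self.has_bias records the real configuration.
class ToyLinear(torch.nn.Module):
    def __init__(self, out_features: int, bias: bool):
        super().__init__()
        self.has_bias = bias
        bias_size = out_features if bias else 1
        self.register_buffer("bias", torch.zeros(bias_size, dtype=torch.float16))


print(ToyLinear(4, bias=True).state_dict().keys())   # odict_keys(['bias'])
print(ToyLinear(4, bias=False).state_dict().keys())  # odict_keys(['bias'])
```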
@@ -188,11 +188,32 @@ def forward(self, x):
             self.smoothquant,
         )

+    def re_register_qdata(self) -> None:
+        """Remove existing self.qdata tensor and register it again as a buffer.
+        This method is used during TP, after other quantization metadata have been
+        updated.
+        """
+
+        del self.qdata
+        self.register_buffer(
+            "qdata",
+            torch.cat(
+                (
+                    self.w_clip_val,
+                    self.w_clip_valn,
+                    self.a_clip_val,
+                    self.a_clip_valn,
+                    self.zero_shift,
+                    self.smoothquant_scale,
+                )
+            ),
+        )
+
     def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}"
             f"(in={self.in_features}, out={self.out_features}, "
-            f"bias={self.bias is not None}, wq={self.weight_quant_type}, "
+            f"bias={self.has_bias}, wq={self.weight_quant_type}, "
             f"aq={self.activ_quant_type}, smoothq={self.smoothquant}, "
             f"op={self.aiu_op})"
         )
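`qdata` is the single concatenated buffer of quantization metadata handed to the AIU; once tensor-parallel sharding has resized some of its components, the old concatenation is stale and must be rebuilt, which is what `re_register_qdata()` does. A toy illustration of the size mismatch (all sizes below are made up):

```python
import torch

# Toy illustration (all sizes made up): qdata is the concatenation of six 1-D metadata
# tensors, so once w_clip_val / w_clip_valn shrink from out_feat to out_feat / world_size
# under per-channel sharding, the previously registered concatenation no longer matches
# the module's buffers.
out_feat, world_size = 8, 2
full = {
    "w_clip_val": torch.ones(out_feat),
    "w_clip_valn": -torch.ones(out_feat),
    "a_clip_val": torch.ones(1),
    "a_clip_valn": torch.zeros(1),
    "zero_shift": torch.zeros(1),
    "smoothquant_scale": torch.ones(1),
}
qdata_before = torch.cat(tuple(full.values()))

sharded = dict(full)
sharded["w_clip_val"] = full["w_clip_val"][: out_feat // world_size]
sharded["w_clip_valn"] = full["w_clip_valn"][: out_feat // world_size]
qdata_after = torch.cat(tuple(sharded.values()))

print(qdata_before.numel(), qdata_after.numel())  # 20 12
```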
@@ -222,7 +243,7 @@ def get_int8_aiu_linear(
     # Preprocess linear_config if its linear_type field is a callable
     # (which would not initialize correctly the dataclass parameters).
     # We don't want to alter the original linear_config though.
-    linear_config_for_dataclass: Optional[dict[Union[str, Callable], Any]] = None
+    linear_config_for_dataclass = None
     if callable(linear_config["linear_type"]):
         linear_config_for_dataclass = update_from_partial(linear_config)
         linear_config_for_dataclass["linear_type"] = linear_type
@@ -240,6 +261,36 @@
     return linear


+def is_w_clip_per_channel(
+    w_clip: torch.Tensor,
+) -> bool:
+    """Determine whether the weight clip value in use for INT8 quantization of the
+    provided linear module is:
+    - per-tensor (1 element, 1-dim tensor), or
+    - per-channel (out_feat elements, 1-dim tensor).
+    """
+
+    if w_clip.dim() != 1:
+        raise ValueError(
+            f"TP error: weight clip value dimensions {str(list(w_clip.size()))} are "
+            "incompatible with expected per-tensor or per-channel quantization."
+        )
+    return w_clip.numel() > 1
+
+
+def is_smoothquant_enabled(
+    smoothquant_scale: torch.Tensor,
+) -> bool:
+    """Determine whether smoothquant is enabled on a module."""
+
+    if smoothquant_scale.dim() != 1:
+        raise ValueError(
+            "TP error: smoothquant_scale array should always be 1-dimensional but "
+            f"has size {str(list(smoothquant_scale.size()))}"
+        )
+    return smoothquant_scale.numel() > 1
+
+
 def shard_int8_aiu_linear(
     tensor_values: dict[str, torch.Tensor],
     tp_module: TPModule,
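The two new helpers answer the question left open by the removed FIXME further down: per-tensor vs. per-channel weight clips, and smoothquant on vs. off, are distinguished purely by the number of elements in the corresponding 1-D tensor. A quick check, assuming `fms-mo` and its `fms` dependency are installed so the module imports:

```python
import torch

from fms_mo.aiu_addons.i8i8.i8i8_aiu_linear import (
    is_smoothquant_enabled,
    is_w_clip_per_channel,
)

out_feat = 4096
print(is_w_clip_per_channel(torch.zeros(1)))         # False -> per-tensor, CLONE
print(is_w_clip_per_channel(torch.zeros(out_feat)))  # True  -> per-channel, may be SHARDed
print(is_smoothquant_enabled(torch.zeros(1)))        # False -> smoothquant disabled
print(is_smoothquant_enabled(torch.zeros(512)))      # True  -> smoothquant enabled

# Anything that is not 1-dimensional is rejected outright:
# is_w_clip_per_channel(torch.zeros(4, 4))  -> ValueError("TP error: ...")
```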
@@ -259,49 +310,73 @@ def shard_int8_aiu_linear(
     | bias | 0 | - |
     | others* | N | - |

-    Other quantization parameters: w_clip_val, w_clip_valn,
-    a_clip_val, a_clip_valn, zero_shift, smoothquant_scale
-    No sharding on all these parameters, except w_clip_val and w_clip_valn when
-    per-channel quantization is used
+    Other quantization parameters: w_clip_val, w_clip_valn, a_clip_val, a_clip_valn,
+    zero_shift, smoothquant_scale
+
+    No sharding on any of these parameters (they are CLONED on each rank), with the
+    exception of:
+    - w_clip_val and w_clip_valn, only column-sharding and only when per-channel
+      quantization is used
+    - smoothquant_scale, only row-sharding and only if smoothquant in use
+
+    These parameters are 1-dimensional, so if sharding is needed, it is always applied
+    on dim=0.
     """
+
     param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
+    w_clip_linear_param = None
     for module_name, module_info in module_sharding_info.items():
-        int8_aiu_mod = module_info.linear_module
+        int8_aiu_module = module_info.linear_module
+
+        # check every module if per-channel in use (sharding depends on module)
+        if is_w_clip_per_channel(module_info.linear_module.w_clip_val):
+            w_clip_linear_param = LinearParameterShardingInfo(
+                0,
+                ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.CLONE,
+            )
+        else:
+            w_clip_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
+        # check for every linear module if smoothquant is enabled
+        if is_smoothquant_enabled(module_info.linear_module.smoothquant_scale):
+            smoothquant_linear_param = LinearParameterShardingInfo(
+                0, ShardType.SHARD if module_info.sharding_dim == 1 else ShardType.CLONE
+            )
+        else:
+            smoothquant_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)
+
         params: dict[str, LinearParameterShardingInfo] = {
             "weight": LinearParameterShardingInfo(
                 module_info.sharding_dim, ShardType.SHARD
             ),
-            # FIXME: with per-channel W, clips need to be sharded
-            # but if per-tensor w, there should be no sharding
-            # HOW CAN WE DISCRIMINATE THE TWO CASES?
-            "w_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "w_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
-            # "w_clip_val": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
-            # "w_clip_valn": LinearParameterShardingInfo(
-            #     module_info.sharding_dim,
-            #     ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
-            # ),
+            "w_clip_val": w_clip_linear_param,
+            "w_clip_valn": w_clip_linear_param,
             "a_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
             "a_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
             "zero_shift": LinearParameterShardingInfo(0, ShardType.CLONE),
-            "smooqthquant_scale": LinearParameterShardingInfo(0, ShardType.CLONE),
+            "smoothquant_scale": smoothquant_linear_param,
         }
-        if int8_aiu_mod.bias is not None:
+        if int8_aiu_module.bias is not None and int8_aiu_module.bias.numel() > 1:
             params["bias"] = LinearParameterShardingInfo(
-                module_info.sharding_dim,
+                0,
                 ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
             )
         param_sharding_info[module_name] = params

+    # trim qdata from dictionary of tensors to be copied on sharded modules.
+    # if not trimmed, qdata wouldn't be copied but the keys would be marked as unused
+    tensor_values = {k: v for k, v in tensor_values.items() if "qdata" not in k}
+
     unused_keys = shard_base_linear(
         tensor_values, tp_module, module_sharding_info, param_sharding_info
     )

-    raise NotImplementedError("TP not yet supported for INT8. Work in progress")
-    # return unused_keys
+    # qdata contains all quantization metadata to pass to the AIU and needs to be
+    # updated post-sharding, after metadata tensor have changed
+    for module_name, module_info in module_sharding_info.items():
+        module_info.linear_module.re_register_qdata()
+
+    return unused_keys


 register_linear_type_to_module_map(
@@ -320,4 +395,6 @@ def shard_int8_aiu_linear(
         use_smoothquant=True,
     ),
 )
+
+# int8 linear with and w/o smoothquant share a common sharding map
 register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)
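With the helpers in place, sharding decisions follow the weight: for a column-parallel module (`sharding_dim == 0`) a per-channel clip vector is split along its only dimension so it stays aligned with the local weight rows, while per-tensor values are cloned; for a row-parallel module only `smoothquant_scale`, the parameter tied to the input features, is sharded. The sketch below imitates that alignment by hand; the world size, shapes, and manual slicing are assumptions standing in for what `ShardType.SHARD` does:

```python
import torch

# Illustrative only: world size, shapes, and the manual slicing below are assumptions
# standing in for what ShardType.SHARD does. For a column-parallel module
# (sharding_dim == 0), each rank keeps a slice of the weight's output rows, so a
# per-channel w_clip_val of length out_feat must be sliced identically, while a
# per-tensor clip (1 element) is simply cloned on every rank.
world_size, rank = 2, 0
out_feat, in_feat = 8, 16

weight = torch.randn(out_feat, in_feat)
w_clip_per_channel = torch.rand(out_feat)
w_clip_per_tensor = torch.rand(1)

rows = out_feat // world_size
weight_shard = weight[rank * rows : (rank + 1) * rows]            # SHARD on dim 0
clip_shard = w_clip_per_channel[rank * rows : (rank + 1) * rows]  # SHARD on dim 0
clip_clone = w_clip_per_tensor                                    # CLONE

assert weight_shard.shape[0] == clip_shard.shape[0]  # clips stay aligned with local rows
print(weight_shard.shape, clip_shard.shape, clip_clone.shape)
```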
