Commit 45dd501
Merge branch 'main' into mx_impl
Signed-off-by: chichun-charlie-liu <[email protected]>
2 parents 764d7bc + 4705c75

File tree

18 files changed: +715 −201 lines

.github/workflows/labelpr.yaml

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@ jobs:
       github-token: ${{ secrets.GITHUB_TOKEN }}
       script: |
         // https://github.com/commitizen/conventional-commit-types
-        const valid_pr_types = ['feat', 'fix', 'docs', 'style', 'refactor', 'perf', 'test', 'build', 'ci', 'chore', 'revert'];
+        const valid_pr_types = ['feat', 'fix', 'docs', 'style', 'refactor', 'perf', 'test', 'build', 'ci', 'chore', 'revert', 'dependencies'];


         const title = context.payload.pull_request.title;
@@ -28,4 +28,4 @@ jobs:
         const labels = context.payload.pull_request.labels;
         const new_labels = labels.filter(label => !valid_pr_types.includes(label.name)); // keep all labels that are not in valid_pr_types
         new_labels.push({name: pr_type});
-        await github.rest.issues.update({ ...context.repo, issue_number: context.payload.number, labels: new_labels });
+        await github.rest.issues.update({ ...context.repo, issue_number: context.payload.number, labels: new_labels });

.github/workflows/pypi.yml

Lines changed: 3 additions & 3 deletions
@@ -44,7 +44,7 @@ jobs:
           # for setuptools-scm
           fetch-depth: 0

-      - uses: hynek/build-and-inspect-python-package@v2
+      - uses: hynek/build-and-inspect-python-package@b5076c307dc91924a82ad150cdd1533b444d3310 # v2.12.0

   # push to Test PyPI on
   # - a new GitHub release is published
@@ -77,7 +77,7 @@ jobs:
           path: dist

       - name: Upload to Test PyPI
-        uses: pypa/gh-action-pypi-publish@15c56dba361d8335944d31a2ecd17d700fc7bcbc # v1.12.2
+        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
         with:
           repository-url: https://test.pypi.org/legacy/

@@ -122,4 +122,4 @@ jobs:
         run: rm ./dist/*.sigstore.json

       - name: Upload to PyPI
-        uses: pypa/gh-action-pypi-publish@15c56dba361d8335944d31a2ecd17d700fc7bcbc # v1.12.2
+        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -40,9 +40,9 @@ jobs:
     strategy:
       matrix:
         python:
-          - "3.9"
           - "3.10"
           - "3.11"
+          - "3.12"
         platform:
           - "ubuntu-latest"

fms_mo/aiu_addons/gptq/gptq_aiu_op.py

Lines changed: 59 additions & 10 deletions
@@ -17,6 +17,7 @@
 import logging

 # Third Party
+from packaging.version import Version
 import torch

 # pylint: disable=unused-argument
@@ -25,6 +26,36 @@
 logger = logging.getLogger(__name__)


+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+
+
 def register_aiu_gptq_op():
     """Register AIU-specific op to enable torch compile without graph break.
     The op preserves I/O shapes of a `X @ W^T` matmul but performs no operation.
@@ -36,17 +67,33 @@ def register_aiu_gptq_op():
     ):
         logger.warning("AIU op has already been registered")
         return
-
     op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor qw, Tensor qzeros, Tensor scales, Tensor g_idx) -> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor qw, Tensor qzeros, "
+            "Tensor scales, Tensor g_idx) -> Tensor",
+        )

     # Add implementations for the operator
-    @torch.library.impl(op_namespace_id, "default")
-    def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
-        # on AIU, GPTQ qw is [out_feat, in_feat]
+    @implement_op_decorator(op_namespace_id)
+    def i4f16_fxinputs_aiu(
+        x: torch.Tensor,
+        qw: torch.Tensor,
+        qzeros: torch.Tensor,
+        scales: torch.Tensor,
+        g_idx: torch.Tensor,
+    ) -> torch.Tensor:
+        """Implement fake processing of GPTQ W4A16 matmul. The purpose is to create a
+        node on the computational graph to be captured during compiling for AIU.
+
+        Instead of computing the weight decompression and matmul, this function returns
+        a zero tensor with the expected shape.
+
+        NOTE: on AIU, GPTQ qw is [out_feat, in_feat], while AutoGPTQ saves the quantized
+        weights as [in_feat, out_feat]
+        """
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         x = x.view(-1, x.shape[-1])
         output = torch.zeros(
@@ -56,8 +103,10 @@ def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
         )
         return output.view(outshape)

-    @torch.library.impl_abstract(op_namespace_id)
-    def i4f16_fxinputs_aiu_abstract(x, qw, qzeros, scales, g_idx):
+    @register_op_decorator(op_namespace_id)
+    def _(x, qw, qzeros, scales, g_idx):
+        """OP template of I/O sizes"""
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         return torch.empty(
             outshape,
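
This diff gates two torch.library APIs on the installed PyTorch version: before 2.4 it uses torch.library.define/impl/impl_abstract, and from 2.4 on it uses torch.library.custom_op/register_fake. Below is a minimal, self-contained sketch of the same version-gating pattern; the "demo::fake_matmul" namespace and every name in it are hypothetical illustrations, not part of this commit.

# Sketch of the version-gated custom-op pattern (hypothetical "demo" namespace).
from packaging.version import Version
import torch

TORCH_VERSION = Version(torch.__version__.split("+", maxsplit=1)[0])
OP_ID = "demo::fake_matmul"

if TORCH_VERSION < Version("2.4"):
    # Pre-2.4 API: declare the schema, then attach real and abstract impls.
    torch.library.define(OP_ID, "(Tensor x, Tensor w) -> Tensor")
    impl_decorator = torch.library.impl(OP_ID, "default")
    fake_decorator = torch.library.impl_abstract(OP_ID)
else:
    # 2.4+ API: custom_op infers the schema from type hints;
    # register_fake supplies the shape-only impl used during tracing.
    impl_decorator = torch.library.custom_op(OP_ID, mutates_args=())
    fake_decorator = torch.library.register_fake(OP_ID)

@impl_decorator
def fake_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Preserve the output shape of `x @ w.T` but compute nothing,
    # mirroring the zero-returning AIU placeholder op above.
    return torch.zeros(x.shape[:-1] + (w.shape[0],), dtype=x.dtype, device=x.device)

@fake_decorator
def _(x, w):
    return torch.empty(x.shape[:-1] + (w.shape[0],), dtype=x.dtype, device=x.device)

out = torch.ops.demo.fake_matmul(torch.randn(2, 8), torch.randn(4, 8))
assert out.shape == (2, 4)

Either branch ends with a callable op that torch.compile can capture without a graph break, which is the point of the placeholder implementation in this file.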

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 17 additions & 36 deletions
@@ -97,41 +97,22 @@ def _add_defaults_and_concat(
     )


-# registration of new adapter steps for each architecture
-serialization.register_adapter_step("llama", "int8_qparams_aiu", _int8_qparams_aiu)
-serialization.register_adapter_step(
-    "gpt_bigcode", "int8_qparams_aiu", _int8_qparams_aiu
-)
-serialization.register_adapter_step("roberta", "int8_qparams_aiu", _int8_qparams_aiu)
-serialization.register_adapter_step(
-    "roberta_question_answering",
-    "int8_qparams_aiu",
-    _int8_qparams_aiu,
-)
-
-# registration of multi-step adapter for each architecture
-serialization.register_adapter(
+# registration of new adapter step and adapter for each architecture
+for arch in [
     "llama",
-    "fms_mo",
-    [
-        "hf_to_fms_names",
-        "hf_to_fms_rope",
-        "weight_fusion",
-        "int8_qparams_aiu",
-    ],
-)
-serialization.register_adapter(
-    "gpt_bigcode", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
-)
-serialization.register_adapter(
-    "roberta", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
-)
-serialization.register_adapter(
+    "gpt_bigcode",
+    "granite",
+    "roberta",
     "roberta_question_answering",
-    "fms_mo",
-    [
-        "hf_to_fms_names",
-        "weight_fusion",
-        "int8_qparams_aiu",
-    ],
-)
+]:
+    serialization.register_adapter_step(arch, "int8_qparams_aiu", _int8_qparams_aiu)
+    if arch in ["llama", "granite"]:
+        steps_to_register = [
+            "hf_to_fms_names",
+            "hf_to_fms_rope",
+            "weight_fusion",
+            "int8_qparams_aiu",
+        ]
+    else:
+        steps_to_register = ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
+    serialization.register_adapter(arch, "fms_mo", steps_to_register)
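
For reference, one iteration of the loop above expands to two plain calls; shown here for "granite", the architecture this commit adds, which takes the same step list as "llama" (including the RoPE step):

# What the loop registers for the newly added "granite" architecture
# (serialization and _int8_qparams_aiu come from this same file).
serialization.register_adapter_step("granite", "int8_qparams_aiu", _int8_qparams_aiu)
serialization.register_adapter(
    "granite",
    "fms_mo",
    ["hf_to_fms_names", "hf_to_fms_rope", "weight_fusion", "int8_qparams_aiu"],
)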
