
Commit 3f07692

Merge pull request #96 from andrea-fasoli/custom_ops_syntax
feat: Update syntax of custom torch ops
2 parents acbff2d + 0003bb1 commit 3f07692
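
This commit gates the custom-op plumbing on the installed PyTorch version: on torch < 2.4 the ops keep using torch.library.define / impl / impl_abstract, while on torch >= 2.4 they switch to torch.library.custom_op and torch.library.register_fake. A minimal sketch of the version check the two files introduce (illustrative only; the real helpers are the implement_op_decorator and register_op_decorator functions shown in the diffs below):

    # Sketch of the version gate added by this commit. Local build suffixes
    # such as "+cu121" are stripped so Version() can parse the string.
    from packaging.version import Version

    import torch

    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])

    if torch_version < Version("2.4"):
        print("legacy path: torch.library.define / impl / impl_abstract")
    else:
        print("new path: torch.library.custom_op / register_fake")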

File tree

fms_mo/aiu_addons/gptq/gptq_aiu_op.py
fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

2 files changed: 111 additions, 39 deletions


fms_mo/aiu_addons/gptq/gptq_aiu_op.py

Lines changed: 59 additions & 10 deletions
@@ -17,6 +17,7 @@
 import logging

 # Third Party
+from packaging.version import Version
 import torch

 # pylint: disable=unused-argument
@@ -25,6 +26,36 @@
 logger = logging.getLogger(__name__)


+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+
+
 def register_aiu_gptq_op():
     """Register AIU-specific op to enable torch compile without graph break.
     The op preserves I/O shapes of a `X @ W^T` matmul but performs no operation.
@@ -36,17 +67,33 @@ def register_aiu_gptq_op():
     ):
         logger.warning("AIU op has already been registered")
         return
-
     op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor qw, Tensor qzeros, Tensor scales, Tensor g_idx) -> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor qw, Tensor qzeros, "
+            "Tensor scales, Tensor g_idx) -> Tensor",
+        )

     # Add implementations for the operator
-    @torch.library.impl(op_namespace_id, "default")
-    def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
-        # on AIU, GPTQ qw is [out_feat, in_feat]
+    @implement_op_decorator(op_namespace_id)
+    def i4f16_fxinputs_aiu(
+        x: torch.Tensor,
+        qw: torch.Tensor,
+        qzeros: torch.Tensor,
+        scales: torch.Tensor,
+        g_idx: torch.Tensor,
+    ) -> torch.Tensor:
+        """Implement fake processing of GPTQ W4A16 matmul. The purpose is to create a
+        node on the computational graph to be captured during compiling for AIU.
+
+        Instead of computing the weight decompression and matmul, this function returns
+        a zero tensor with the expected shape.
+
+        NOTE: on AIU, GPTQ qw is [out_feat, in_feat], while AutoGPTQ saves the quantized
+        weights as [in_feat, out_feat]
+        """
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         x = x.view(-1, x.shape[-1])
         output = torch.zeros(
@@ -56,8 +103,10 @@ def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
         )
         return output.view(outshape)

-    @torch.library.impl_abstract(op_namespace_id)
-    def i4f16_fxinputs_aiu_abstract(x, qw, qzeros, scales, g_idx):
+    @register_op_decorator(op_namespace_id)
+    def _(x, qw, qzeros, scales, g_idx):
+        """OP template of I/O sizes"""
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         return torch.empty(
             outshape,
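
A minimal usage sketch (not part of the diff) of how the op becomes reachable once register_aiu_gptq_op() has run; the tensor shapes and int4 packing below are illustrative assumptions, not values prescribed by this commit:

    import torch

    from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

    register_aiu_gptq_op()

    # Hypothetical shapes: qw is laid out with out_feat first, so qw.shape[0]
    # drives the output feature dimension, as in the implementation above.
    x = torch.randn(4, 128, dtype=torch.float16)
    qw = torch.zeros(256, 16, dtype=torch.int32)      # packed int4 weights (assumed layout)
    qzeros = torch.zeros(4, 32, dtype=torch.int32)    # placeholder zero-points
    scales = torch.ones(4, 256, dtype=torch.float16)  # placeholder scales
    g_idx = torch.zeros(128, dtype=torch.int32)       # placeholder group indices

    y = torch.ops.gptq_gemm.i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx)
    print(y.shape)  # torch.Size([4, 256]); a zero tensor, since the op only mimics I/O shapes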

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 52 additions & 29 deletions
@@ -17,18 +17,49 @@
 import logging

 # Third Party
+from packaging.version import Version
 import torch
 import torch.nn.functional as F

-logger = logging.getLogger(__name__)
-
 # pylint: disable=unused-argument
 # i8i8 op must be registered with specific I/O, even if not in use by the op function

 # pylint: disable=not-callable
 # torch.nn.functional.linear not recognized as callable
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482

+logger = logging.getLogger(__name__)
+
+
+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+

 def register_aiu_i8i8_op():
     """Register AIU-specific op to enable torch compile without graph break.
@@ -41,26 +72,26 @@ def register_aiu_i8i8_op():
     if hasattr(torch.ops, "fms_mo") and hasattr(torch.ops.fms_mo, "i8i8_aiu"):
         logger.warning("AIU op has already been registered")
         return
-
     op_namespace_id = "fms_mo::i8i8_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
-        "str weight_quant_type, str activ_quant_type, "
-        "bool smoothquant) "
-        "-> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
+            "str weight_quant_type, str activ_quant_type, "
+            "bool smoothquant) "
+            "-> Tensor",
+        )

-    @torch.library.impl(op_namespace_id, "default")
+    @implement_op_decorator(op_namespace_id)
     def i8i8_aiu(
-        x,
-        weight,
-        bias,
-        qdata,
-        weight_quant_type,
-        activ_quant_type,
-        smoothquant,
-    ):
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        qdata: torch.Tensor,
+        weight_quant_type: str,
+        activ_quant_type: str,
+        smoothquant: bool,
+    ) -> torch.Tensor:
         """Implement addmm of X and W.
         Support various quantization options for weights and activations.
@@ -86,16 +117,8 @@ def i8i8_aiu(

         return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias.to(dtype))

-    @torch.library.impl_abstract(op_namespace_id)
-    def i8i8_aiu_abstract(
-        x,
-        weight,
-        bias,
-        qdata,
-        weight_quant_type,
-        activ_quant_type,
-        smoothquant,
-    ):
+    @register_op_decorator(op_namespace_id)
+    def _(x, weight, bias, qdata, weight_quant_type, activ_quant_type, smoothquant):
         """OP template of I/O sizes"""

         outshape = x.size()[:-1] + (weight.size(0),)
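
A small sketch (again an assumption, not part of the diff) showing that registration is guarded and effectively idempotent, whichever torch-version branch the decorators take:

    import torch

    from fms_mo.aiu_addons.i8i8.i8i8_aiu_op import register_aiu_i8i8_op

    register_aiu_i8i8_op()
    assert hasattr(torch.ops, "fms_mo") and hasattr(torch.ops.fms_mo, "i8i8_aiu")

    # A second call hits the hasattr guard at the top of register_aiu_i8i8_op()
    # and only logs "AIU op has already been registered" before returning.
    register_aiu_i8i8_op()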
