Commit 0003bb1

Incorporate torch version acquisition into decorator

Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 5cfc9bf

File tree

2 files changed (+34, -20 lines)

fms_mo/aiu_addons/gptq/gptq_aiu_op.py

Lines changed: 17 additions & 10 deletions
@@ -24,25 +24,32 @@
 # gptq op must be registered with specific I/O, even if not in use by the op function
 
 logger = logging.getLogger(__name__)
-torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
 
 
-def implement_op_decorator(pt_ver, op_namespace_id):
-    """Version-dependent decorator for custom op implementation."""
+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
 
     def decorator(func):
-        if pt_ver < Version("2.4"):
+        if torch_version < Version("2.4"):
             return torch.library.impl(op_namespace_id, "default")(func)
         return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
 
     return decorator
 
 
-def register_op_decorator(pt_ver, op_namespace_id):
-    """Version-dependent decorator for custom op registration."""
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
 
     def decorator(func):
-        if pt_ver < Version("2.4"):
+        if torch_version < Version("2.4"):
             return torch.library.impl_abstract(op_namespace_id)(func)
         return torch.library.register_fake(op_namespace_id)(func)
 
@@ -61,15 +68,15 @@ def register_aiu_gptq_op():
         logger.warning("AIU op has already been registered")
         return
     op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
-    if torch_version < Version("2.4"):
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
         torch.library.define(
             op_namespace_id,
             "(Tensor x, Tensor qw, Tensor qzeros, "
             "Tensor scales, Tensor g_idx) -> Tensor",
         )
 
     # Add implementations for the operator
-    @implement_op_decorator(torch_version, op_namespace_id)
+    @implement_op_decorator(op_namespace_id)
     def i4f16_fxinputs_aiu(
         x: torch.Tensor,
         qw: torch.Tensor,
@@ -96,7 +103,7 @@ def i4f16_fxinputs_aiu(
         )
         return output.view(outshape)
 
-    @register_op_decorator(torch_version, op_namespace_id)
+    @register_op_decorator(op_namespace_id)
     def _(x, qw, qzeros, scales, g_idx):
         """OP template of I/O sizes"""
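For reference, here is how the refactored decorators compose end to end. This is a minimal sketch, not code from the commit: it assumes implement_op_decorator and register_op_decorator are importable from the module above, and toy_ns::scale_by_two is a made-up op standing in for the GPTQ kernel.

import torch
from packaging.version import Version

op_id = "toy_ns::scale_by_two"  # hypothetical op, not part of this commit

# Pre-2.4 PyTorch still needs the schema declared up front, mirroring
# register_aiu_gptq_op above.
if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
    torch.library.define(op_id, "(Tensor x) -> Tensor")

@implement_op_decorator(op_id)  # eager implementation of the op
def scale_by_two(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@register_op_decorator(op_id)  # fake/meta kernel: only shapes and dtypes
def _(x):
    return torch.empty_like(x)

print(torch.ops.toy_ns.scale_by_two(torch.ones(3)))  # tensor([2., 2., 2.])

Note that the decorators now take only the namespace id; the PyTorch version is resolved inside them when registration runs.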

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 17 additions & 10 deletions
@@ -29,25 +29,32 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 logger = logging.getLogger(__name__)
-torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
 
 
-def implement_op_decorator(pt_ver, op_namespace_id):
-    """Version-dependent decorator for custom op implementation."""
+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
 
     def decorator(func):
-        if pt_ver < Version("2.4"):
+        if torch_version < Version("2.4"):
             return torch.library.impl(op_namespace_id, "default")(func)
         return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
 
     return decorator
 
 
-def register_op_decorator(pt_ver, op_namespace_id):
-    """Version-dependent decorator for custom op registration."""
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compare against pytorch version in current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
 
     def decorator(func):
-        if pt_ver < Version("2.4"):
+        if torch_version < Version("2.4"):
             return torch.library.impl_abstract(op_namespace_id)(func)
         return torch.library.register_fake(op_namespace_id)(func)
 
@@ -66,7 +73,7 @@ def register_aiu_i8i8_op():
         logger.warning("AIU op has already been registered")
         return
     op_namespace_id = "fms_mo::i8i8_aiu"
-    if torch_version < Version("2.4"):
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
         torch.library.define(
             op_namespace_id,
             "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
@@ -75,7 +82,7 @@ def register_aiu_i8i8_op():
             "-> Tensor",
         )
 
-    @implement_op_decorator(torch_version, op_namespace_id)
+    @implement_op_decorator(op_namespace_id)
     def i8i8_aiu(
         x: torch.Tensor,
         weight: torch.Tensor,
@@ -110,7 +117,7 @@ def i8i8_aiu(
 
         return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias.to(dtype))
 
-    @register_op_decorator(torch_version, op_namespace_id)
+    @register_op_decorator(op_namespace_id)
     def _(x, weight, bias, qdata, weight_quant_type, activ_quant_type, smoothquant):
         """OP template of I/O sizes"""
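A side note on the version expression that now appears inside each decorator: CUDA and ROCm wheels report a local version segment (for example a string like "2.3.1+cu121"), and the split("+", maxsplit=1)[0] keeps the comparison on the release portion. Below is a small worked sketch with a made-up wheel version, plus a hypothetical test pattern; neither is from the commit.

from unittest import mock

import torch
from packaging.version import Version

raw = "2.3.1+cu121"                      # illustrative CUDA wheel version
release = raw.split("+", maxsplit=1)[0]  # -> "2.3.1"
assert Version(release) < Version("2.4")

# Because the lookup now happens when the decorators run rather than at
# module import, a test could patch torch.__version__ before registration
# to exercise the pre-2.4 define/impl branch without reloading the module.
with mock.patch.object(torch, "__version__", "2.3.1"):
    pass  # e.g. call register_aiu_i8i8_op() here to take the legacy path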
