Skip to content

Commit 591fe36

Browse files
committed
make gptq op version-dependent too
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 41c853f commit 591fe36

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

fms_mo/aiu_addons/gptq/gptq_aiu_op.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,36 @@
1717
import logging
1818

1919
# Third Party
20+
from packaging.version import Version
2021
import torch
2122

2223
# pylint: disable=unused-argument
2324
# gptq op must be registered with specific I/O, even if not in use by the op function
2425

2526
logger = logging.getLogger(__name__)
27+
torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
28+
29+
30+
def implement_op_decorator(pt_ver, op_namespace_id):
    """Return a decorator that binds a function as the implementation of a
    custom torch op, selecting the registration API by PyTorch version.

    Before 2.4, the op is expected to have been declared separately (via
    ``torch.library.define``) and the function is attached as its "default"
    kernel with ``torch.library.impl``. From 2.4 on, the newer
    ``torch.library.custom_op`` declares and implements the op in one step.
    """

    def decorator(func):
        if pt_ver < Version("2.4"):
            # Legacy API: attach func as the "default" kernel of a
            # previously defined op.
            register = torch.library.impl(op_namespace_id, "default")
        else:
            # Modern API: custom_op creates the op itself; this op
            # mutates none of its arguments.
            register = torch.library.custom_op(op_namespace_id, mutates_args=())
        return register(func)

    return decorator
39+
40+
41+
def register_op_decorator(pt_ver, op_namespace_id):
    """Return a decorator that binds a function as the fake/abstract
    implementation (shape and dtype propagation only) of a custom torch op,
    selecting the registration API by PyTorch version.

    The version boundary is ``< 2.4`` (not ``<=``) so that it agrees with
    ``implement_op_decorator``: from 2.4 on the op is created with
    ``torch.library.custom_op``, whose documented companion for fake
    registration is ``torch.library.register_fake``;
    ``torch.library.impl_abstract`` is the deprecated pre-2.4 spelling that
    pairs with ``torch.library.define``/``impl``.
    """

    def decorator(func):
        if pt_ver < Version("2.4"):
            # Legacy API, paired with torch.library.define/impl.
            return torch.library.impl_abstract(op_namespace_id)(func)
        # Modern API, paired with torch.library.custom_op.
        return torch.library.register_fake(op_namespace_id)(func)

    return decorator
2650

2751

2852
def register_aiu_gptq_op():
@@ -37,9 +61,15 @@ def register_aiu_gptq_op():
3761
logger.warning("AIU op has already been registered")
3862
return
3963
op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
64+
if torch_version <= Version("2.4"):
65+
torch.library.define(
66+
op_namespace_id,
67+
"(Tensor x, Tensor qw, Tensor qzeros, "
68+
"Tensor scales, Tensor g_idx) -> Tensor",
69+
)
4070

4171
# Add implementations for the operator
42-
@torch.library.custom_op(op_namespace_id, mutates_args=())
72+
@implement_op_decorator(torch_version, op_namespace_id)
4373
def i4f16_fxinputs_aiu(
4474
x: torch.Tensor,
4575
qw: torch.Tensor,
@@ -66,7 +96,7 @@ def i4f16_fxinputs_aiu(
6696
)
6797
return output.view(outshape)
6898

69-
@torch.library.register_fake(op_namespace_id)
99+
@register_op_decorator(torch_version, op_namespace_id)
70100
def _(x, qw, qzeros, scales, g_idx):
71101
"""OP template of I/O sizes"""
72102

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def register_aiu_i8i8_op():
6666
logger.warning("AIU op has already been registered")
6767
return
6868
op_namespace_id = "fms_mo::i8i8_aiu"
69-
if torch_version <= Version("2.4"):
69+
if torch_version < Version("2.4"):
7070
torch.library.define(
7171
op_namespace_id,
7272
"(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "

0 commit comments

Comments
 (0)