Commit 4f5e08a

Make custom op syntax dependent on torch version
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 4f230b1 commit 4f5e08a

1 file changed: +36 -4 lines changed

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 36 additions & 4 deletions
@@ -17,18 +17,42 @@
 import logging
 
 # Third Party
+from packaging.version import Version
 import torch
 import torch.nn.functional as F
 
-logger = logging.getLogger(__name__)
-
 # pylint: disable=unused-argument
 # i8i8 op must be registered with specific I/O, even if not in use by the op function
 
 # pylint: disable=not-callable
 # torch.nn.functional.linear not recognized as callable
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
+logger = logging.getLogger(__name__)
+torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+
+def implement_op_decorator(pt_ver, op_namespace_id):
+    """Version-dependent decorator for custom op implementation."""
+
+    def decorator(func):
+        if pt_ver <= Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(pt_ver, op_namespace_id):
+    """Version-dependent decorator for custom op registration."""
+
+    def decorator(func):
+        if pt_ver <= Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+
 
 def register_aiu_i8i8_op():
     """Register AIU-specific op to enable torch compile without graph break.
@@ -42,8 +66,16 @@ def register_aiu_i8i8_op():
         logger.warning("AIU op has already been registered")
         return
     op_namespace_id = "fms_mo::i8i8_aiu"
+    if torch_version <= Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
+            "str weight_quant_type, str activ_quant_type, "
+            "bool smoothquant) "
+            "-> Tensor",
+        )
 
-    @torch.library.custom_op(op_namespace_id, mutates_args=())
+    @implement_op_decorator(torch_version, op_namespace_id)
     def i8i8_aiu(
         x: torch.Tensor,
         weight: torch.Tensor,
@@ -78,7 +110,7 @@ def i8i8_aiu(
 
         return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias.to(dtype))
 
-    @torch.library.register_fake(op_namespace_id)
+    @register_op_decorator(torch_version, op_namespace_id)
     def _(x, weight, bias, qdata, weight_quant_type, activ_quant_type, smoothquant):
         """OP template of I/O sizes"""

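For context, here is a minimal self-contained sketch of the same version-gated registration pattern applied to a toy op. Everything in it (the "demo::scale" namespace, the scale function, OP_ID) is hypothetical and not part of fms_mo; it only assumes torch and packaging are installed. The <= 2.4 branch uses the older define/impl/impl_abstract API, while newer versions use custom_op/register_fake, mirroring the two decorator helpers added in this commit:

    # Hypothetical toy op; "demo::scale" is illustrative, not part of fms_mo.
    from packaging.version import Version

    import torch

    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
    OP_ID = "demo::scale"

    if torch_version <= Version("2.4"):
        # Older API: declare the schema explicitly, then attach the real
        # implementation and an abstract (meta) function.
        torch.library.define(OP_ID, "(Tensor x, float k) -> Tensor")
        impl_decorator = torch.library.impl(OP_ID, "default")
        fake_decorator = torch.library.impl_abstract(OP_ID)
    else:
        # Newer API: custom_op infers the schema from type annotations;
        # register_fake supplies the shape/dtype template used during tracing.
        impl_decorator = torch.library.custom_op(OP_ID, mutates_args=())
        fake_decorator = torch.library.register_fake(OP_ID)

    @impl_decorator
    def scale(x: torch.Tensor, k: float) -> torch.Tensor:
        return x * k

    @fake_decorator
    def _(x, k):
        # Fake/abstract impl: only shapes and dtypes matter here.
        return torch.empty_like(x)

    print(torch.ops.demo.scale(torch.ones(2), 3.0))  # tensor([3., 3.])

Either branch leaves the call site unchanged (torch.ops.demo.scale), which is what lets register_aiu_i8i8_op keep a single body for i8i8_aiu and its fake template across torch versions.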