1717import logging
1818
1919# Third Party
20+ from packaging .version import Version
2021import torch
2122import torch .nn .functional as F
2223
23- logger = logging .getLogger (__name__ )
24-
2524# pylint: disable=unused-argument
2625# i8i8 op must be registered with specific I/O, even if not in use by the op function
2726
2827# pylint: disable=not-callable
2928# torch.nn.functional.linear not recognized as callable
3029# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
3130
31+ logger = logging .getLogger (__name__ )
32+ torch_version = Version (torch .__version__ .split ("+" , maxsplit = 1 )[0 ])
33+
34+
35+ def implement_op_decorator (pt_ver , op_namespace_id ):
36+ """Version-dependent decorator for custom op implementation."""
37+
38+ def decorator (func ):
39+ if pt_ver <= Version ("2.4" ):
40+ return torch .library .impl (op_namespace_id , "default" )(func )
41+ return torch .library .custom_op (op_namespace_id , mutates_args = ())(func )
42+
43+ return decorator
44+
45+
46+ def register_op_decorator (pt_ver , op_namespace_id ):
47+ """Version-dependent decorator for custom op registration."""
48+
49+ def decorator (func ):
50+ if pt_ver <= Version ("2.4" ):
51+ return torch .library .impl_abstract (op_namespace_id )(func )
52+ return torch .library .register_fake (op_namespace_id )(func )
53+
54+ return decorator
55+
3256
3357def register_aiu_i8i8_op ():
3458 """Register AIU-specific op to enable torch compile without graph break.
@@ -42,8 +66,16 @@ def register_aiu_i8i8_op():
4266 logger .warning ("AIU op has already been registered" )
4367 return
4468 op_namespace_id = "fms_mo::i8i8_aiu"
69+ if torch_version <= Version ("2.4" ):
70+ torch .library .define (
71+ op_namespace_id ,
72+ "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
73+ "str weight_quant_type, str activ_quant_type, "
74+ "bool smoothquant) "
75+ "-> Tensor" ,
76+ )
4577
46- @torch . library . custom_op ( op_namespace_id , mutates_args = () )
78+ @implement_op_decorator ( torch_version , op_namespace_id )
4779 def i8i8_aiu (
4880 x : torch .Tensor ,
4981 weight : torch .Tensor ,
@@ -78,7 +110,7 @@ def i8i8_aiu(
78110
79111 return F .linear (x_dq .to (dtype ), w_dq .to (dtype ), bias .to (dtype ))
80112
81- @torch . library . register_fake ( op_namespace_id )
113+ @register_op_decorator ( torch_version , op_namespace_id )
82114 def _ (x , weight , bias , qdata , weight_quant_type , activ_quant_type , smoothquant ):
83115 """OP template of I/O sizes"""
84116
0 commit comments