1717import logging
1818
1919# Third Party
20+ from packaging .version import Version
2021import torch
2122
2223# pylint: disable=unused-argument
2324# gptq op must be registered with specific I/O, even if not in use by the op function
2425
2526logger = logging .getLogger (__name__ )
27+ torch_version = Version (torch .__version__ .split ("+" , maxsplit = 1 )[0 ])
28+
29+
30+ def implement_op_decorator (pt_ver , op_namespace_id ):
31+ """Version-dependent decorator for custom op implementation."""
32+
33+ def decorator (func ):
34+ if pt_ver < Version ("2.4" ):
35+ return torch .library .impl (op_namespace_id , "default" )(func )
36+ return torch .library .custom_op (op_namespace_id , mutates_args = ())(func )
37+
38+ return decorator
39+
40+
41+ def register_op_decorator (pt_ver , op_namespace_id ):
42+ """Version-dependent decorator for custom op registration."""
43+
44+ def decorator (func ):
45+ if pt_ver <= Version ("2.4" ):
46+ return torch .library .impl_abstract (op_namespace_id )(func )
47+ return torch .library .register_fake (op_namespace_id )(func )
48+
49+ return decorator
2650
2751
2852def register_aiu_gptq_op ():
@@ -37,9 +61,15 @@ def register_aiu_gptq_op():
3761 logger .warning ("AIU op has already been registered" )
3862 return
3963 op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
64+ if torch_version <= Version ("2.4" ):
65+ torch .library .define (
66+ op_namespace_id ,
67+ "(Tensor x, Tensor qw, Tensor qzeros, "
68+ "Tensor scales, Tensor g_idx) -> Tensor" ,
69+ )
4070
4171 # Add implementations for the operator
42- @torch . library . custom_op ( op_namespace_id , mutates_args = () )
72+ @implement_op_decorator ( torch_version , op_namespace_id )
4373 def i4f16_fxinputs_aiu (
4474 x : torch .Tensor ,
4575 qw : torch .Tensor ,
@@ -66,7 +96,7 @@ def i4f16_fxinputs_aiu(
6696 )
6797 return output .view (outshape )
6898
69- @torch . library . register_fake ( op_namespace_id )
99+ @register_op_decorator ( torch_version , op_namespace_id )
70100 def _ (x , qw , qzeros , scales , g_idx ):
71101 """OP template of I/O sizes"""
72102
0 commit comments