2929# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
3030
3131logger = logging .getLogger (__name__ )
32- torch_version = Version (torch .__version__ .split ("+" , maxsplit = 1 )[0 ])
3332
3433
35- def implement_op_decorator (pt_ver , op_namespace_id ):
36- """Version-dependent decorator for custom op implementation."""
34+ def implement_op_decorator (op_namespace_id ):
35+ """Version-dependent decorator for custom op implementation.
36+ Always compare against pytorch version in current environment.
37+ """
38+
39+ torch_version = Version (torch .__version__ .split ("+" , maxsplit = 1 )[0 ])
3740
3841 def decorator (func ):
39- if pt_ver < Version ("2.4" ):
42+ if torch_version < Version ("2.4" ):
4043 return torch .library .impl (op_namespace_id , "default" )(func )
4144 return torch .library .custom_op (op_namespace_id , mutates_args = ())(func )
4245
4346 return decorator
4447
4548
46- def register_op_decorator (pt_ver , op_namespace_id ):
47- """Version-dependent decorator for custom op registration."""
49+ def register_op_decorator (op_namespace_id ):
50+ """Version-dependent decorator for custom op registration.
51+ Always compare against pytorch version in current environment.
52+ """
53+
54+ torch_version = Version (torch .__version__ .split ("+" , maxsplit = 1 )[0 ])
4855
4956 def decorator (func ):
50- if pt_ver < Version ("2.4" ):
57+ if torch_version < Version ("2.4" ):
5158 return torch .library .impl_abstract (op_namespace_id )(func )
5259 return torch .library .register_fake (op_namespace_id )(func )
5360
@@ -66,7 +73,7 @@ def register_aiu_i8i8_op():
6673 logger .warning ("AIU op has already been registered" )
6774 return
6875 op_namespace_id = "fms_mo::i8i8_aiu"
69- if torch_version < Version ("2.4" ):
76+ if Version ( torch . __version__ . split ( "+" , maxsplit = 1 )[ 0 ]) < Version ("2.4" ):
7077 torch .library .define (
7178 op_namespace_id ,
7279 "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
@@ -75,7 +82,7 @@ def register_aiu_i8i8_op():
7582 "-> Tensor" ,
7683 )
7784
78- @implement_op_decorator (torch_version , op_namespace_id )
85+ @implement_op_decorator (op_namespace_id )
7986 def i8i8_aiu (
8087 x : torch .Tensor ,
8188 weight : torch .Tensor ,
@@ -110,7 +117,7 @@ def i8i8_aiu(
110117
111118 return F .linear (x_dq .to (dtype ), w_dq .to (dtype ), bias .to (dtype ))
112119
113- @register_op_decorator (torch_version , op_namespace_id )
120+ @register_op_decorator (op_namespace_id )
114121 def _ (x , weight , bias , qdata , weight_quant_type , activ_quant_type , smoothquant ):
115122 """OP template of I/O sizes"""
116123
0 commit comments