import logging

# Third Party
+from packaging.version import Version
import torch

# pylint: disable=unused-argument
logger = logging.getLogger(__name__)


+def implement_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op implementation.
+    Always compares against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl(op_namespace_id, "default")(func)
+        return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
+
+    return decorator
+
+
+def register_op_decorator(op_namespace_id):
+    """Version-dependent decorator for custom op registration.
+    Always compares against the PyTorch version in the current environment.
+    """
+
+    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
+
+    def decorator(func):
+        if torch_version < Version("2.4"):
+            return torch.library.impl_abstract(op_namespace_id)(func)
+        return torch.library.register_fake(op_namespace_id)(func)
+
+    return decorator
+
+
def register_aiu_gptq_op():
    """Register AIU-specific op to enable torch compile without graph break.
    The op preserves I/O shapes of a `X @ W^T` matmul but performs no operation.
@@ -36,17 +67,33 @@ def register_aiu_gptq_op():
    ):
        logger.warning("AIU op has already been registered")
        return
-
    op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor qw, Tensor qzeros, Tensor scales, Tensor g_idx) -> Tensor",
-    )
+    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+        torch.library.define(
+            op_namespace_id,
+            "(Tensor x, Tensor qw, Tensor qzeros, "
+            "Tensor scales, Tensor g_idx) -> Tensor",
+        )

    # Add implementations for the operator
47- @torch .library .impl (op_namespace_id , "default" )
48- def i4f16_fxinputs_aiu (x , qw , qzeros , scales , g_idx ):
49- # on AIU, GPTQ qw is [out_feat, in_feat]
79+ @implement_op_decorator (op_namespace_id )
80+ def i4f16_fxinputs_aiu (
81+ x : torch .Tensor ,
82+ qw : torch .Tensor ,
83+ qzeros : torch .Tensor ,
84+ scales : torch .Tensor ,
85+ g_idx : torch .Tensor ,
86+ ) -> torch .Tensor :
87+ """Implement fake processing of GPTQ W4A16 matmul. The purpose is to create a
88+ node on the computational graph to be captured during compiling for AIU.
89+
90+ Instead of computing the weight decompression and matmul, this function returns
91+ a zero tensor with the expected shape.
92+
93+ NOTE: on AIU, GPTQ qw is [out_feat, in_feat], while AutoGPTQ saves the quantized
94+ weights as [in_feat, out_feat]
95+ """
96+
5097 outshape = x .shape [:- 1 ] + (qw .shape [0 ],)
5198 x = x .view (- 1 , x .shape [- 1 ])
5299 output = torch .zeros (
@@ -56,8 +103,10 @@ def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
        )
        return output.view(outshape)

-    @torch.library.impl_abstract(op_namespace_id)
-    def i4f16_fxinputs_aiu_abstract(x, qw, qzeros, scales, g_idx):
+    @register_op_decorator(op_namespace_id)
+    def _(x, qw, qzeros, scales, g_idx):
108+ """OP template of I/O sizes"""
109+
        outshape = x.shape[:-1] + (qw.shape[0],)
        return torch.empty(
            outshape,
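
For context, here is a minimal usage sketch (not part of the diff; tensor shapes and dtypes are illustrative assumptions). Once `register_aiu_gptq_op()` has run, the op is reachable through PyTorch's standard dispatcher namespace as `torch.ops.gptq_gemm.i4f16_fxinputs_aiu`, so a call like the one below should trace as a single graph node under `torch.compile` and return a zero tensor shaped like `x @ qw.T`:

    import torch

    register_aiu_gptq_op()  # module function shown in the diff above

    # Illustrative GPTQ-style shapes; the fake op only reads the shapes of x and qw.
    batch, in_feat, out_feat, group_size = 2, 64, 128, 32
    x = torch.randn(batch, in_feat, dtype=torch.float16)
    qw = torch.zeros(out_feat, in_feat, dtype=torch.int32)  # AIU layout: [out_feat, in_feat]
    qzeros = torch.zeros(in_feat // group_size, out_feat // 8, dtype=torch.int32)
    scales = torch.ones(in_feat // group_size, out_feat, dtype=torch.float16)
    g_idx = torch.arange(in_feat, dtype=torch.int32) // group_size

    out = torch.ops.gptq_gemm.i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx)
    assert out.shape == (batch, out_feat)  # placeholder zeros with the matmul's output shape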