
Commit 4f230b1

Update syntax of custom torch ops
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent 2c75fcc
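This commit migrates both AIU custom ops from the torch.library.define / torch.library.impl / torch.library.impl_abstract registration pattern to the decorator-based torch.library.custom_op and torch.library.register_fake APIs, which infer each op's schema from Python type annotations; the torch dependency floor is raised to 2.4 to match.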

File tree

3 files changed: +34 -38 lines changed

fms_mo/aiu_addons/gptq/gptq_aiu_op.py

Lines changed: 22 additions & 10 deletions
@@ -36,17 +36,27 @@ def register_aiu_gptq_op():
     ):
         logger.warning("AIU op has already been registered")
         return
-
     op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor qw, Tensor qzeros, Tensor scales, Tensor g_idx) -> Tensor",
-    )

     # Add implementations for the operator
-    @torch.library.impl(op_namespace_id, "default")
-    def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
-        # on AIU, GPTQ qw is [out_feat, in_feat]
+    @torch.library.custom_op(op_namespace_id, mutates_args=())
+    def i4f16_fxinputs_aiu(
+        x: torch.Tensor,
+        qw: torch.Tensor,
+        qzeros: torch.Tensor,
+        scales: torch.Tensor,
+        g_idx: torch.Tensor,
+    ) -> torch.Tensor:
+        """Implement fake processing of GPTQ W4A16 matmul. The purpose is to create a
+        node on the computational graph to be captured during compiling for AIU.
+
+        Instead of computing the weight decompression and matmul, this function returns
+        a zero tensor with the expected shape.
+
+        NOTE: on AIU, GPTQ qw is [out_feat, in_feat], while AutoGPTQ saves the quantized
+        weights as [in_feat, out_feat]
+        """
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         x = x.view(-1, x.shape[-1])
         output = torch.zeros(
@@ -56,8 +66,10 @@ def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
         )
         return output.view(outshape)

-    @torch.library.impl_abstract(op_namespace_id)
-    def i4f16_fxinputs_aiu_abstract(x, qw, qzeros, scales, g_idx):
+    @torch.library.register_fake(op_namespace_id)
+    def _(x, qw, qzeros, scales, g_idx):
+        """OP template of I/O sizes"""
+
         outshape = x.shape[:-1] + (qw.shape[0],)
         return torch.empty(
             outshape,
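
For readers unfamiliar with the new API, the following is a minimal, self-contained sketch of the registration pattern adopted above. It assumes torch>=2.4; the mylib::zeros_mm namespace and its placeholder semantics are illustrative only and not part of fms_mo.

import torch

# The op schema is inferred from the type annotations, so the separate
# torch.library.define() call from the old pattern is no longer needed.
@torch.library.custom_op("mylib::zeros_mm", mutates_args=())
def zeros_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Eager implementation: return a zero tensor with the matmul output
    # shape, mirroring the placeholder behavior of the AIU op above.
    outshape = x.shape[:-1] + (w.shape[0],)
    return torch.zeros(outshape, dtype=x.dtype, device=x.device)

# Replaces the deprecated @torch.library.impl_abstract: describes output
# shape/dtype to the tracer without running the eager kernel.
@torch.library.register_fake("mylib::zeros_mm")
def _(x, w):
    outshape = x.shape[:-1] + (w.shape[0],)
    return torch.empty(outshape, dtype=x.dtype, device=x.device)

Once registered, the op is reachable as torch.ops.mylib.zeros_mm(x, w), just as the ops in this commit are exposed under torch.ops.gptq_gemm and torch.ops.fms_mo.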

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 11 additions & 27 deletions
@@ -41,26 +41,18 @@ def register_aiu_i8i8_op():
     if hasattr(torch.ops, "fms_mo") and hasattr(torch.ops.fms_mo, "i8i8_aiu"):
         logger.warning("AIU op has already been registered")
         return
-
     op_namespace_id = "fms_mo::i8i8_aiu"
-    torch.library.define(
-        op_namespace_id,
-        "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
-        "str weight_quant_type, str activ_quant_type, "
-        "bool smoothquant) "
-        "-> Tensor",
-    )

-    @torch.library.impl(op_namespace_id, "default")
+    @torch.library.custom_op(op_namespace_id, mutates_args=())
     def i8i8_aiu(
-        x,
-        weight,
-        bias,
-        qdata,
-        weight_quant_type,
-        activ_quant_type,
-        smoothquant,
-    ):
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        qdata: torch.Tensor,
+        weight_quant_type: str,
+        activ_quant_type: str,
+        smoothquant: bool,
+    ) -> torch.Tensor:
         """Implement addmm of X and W.
         Support various quantization options for weights and activations.

@@ -86,16 +78,8 @@ def i8i8_aiu(

         return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias.to(dtype))

-    @torch.library.impl_abstract(op_namespace_id)
-    def i8i8_aiu_abstract(
-        x,
-        weight,
-        bias,
-        qdata,
-        weight_quant_type,
-        activ_quant_type,
-        smoothquant,
-    ):
+    @torch.library.register_fake(op_namespace_id)
+    def _(x, weight, bias, qdata, weight_quant_type, activ_quant_type, smoothquant):
         """OP template of I/O sizes"""

         outshape = x.size()[:-1] + (weight.size(0),)
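
A quick way to see what register_fake buys: under fake-tensor tracing, the mechanism torch.compile uses to capture graphs, only the fake implementation runs, so output shapes propagate without executing the eager kernel. A sketch reusing the hypothetical mylib::zeros_mm op from above; note that FakeTensorMode lives in a private torch module:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode  # private API

x = torch.randn(2, 8)
w = torch.randn(4, 8)
out = torch.ops.mylib.zeros_mm(x, w)  # eager path: zeros of shape (2, 4)

# Under fake-tensor tracing only the register_fake implementation runs,
# propagating shapes and dtypes through the graph.
with FakeTensorMode():
    fx = torch.empty(2, 8)
    fw = torch.empty(4, 8)
    fout = torch.ops.mylib.zeros_mm(fx, fw)
    assert fout.shape == (2, 4)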

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ dependencies = [
     "numpy>=1.26.4,<2.3.0",
     "accelerate>=0.20.3,!=0.34,<1.7",
     "transformers>=4.45,<4.51",
-    "torch>=2.2.0,<2.6",
+    "torch>=2.4,<2.6",
     "triton>=3.0,<3.2",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",

0 commit comments
