Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 7dd461a

Browse files
[mlir][linalg] Add Grouped Convolution Ops: conv_2d_nhwgc_gfhwc and conv_2d_nhwgc_gfhwc_q (#108192)
This patch adds two new ops: linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert tosa group conv2d Ops. - Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp. - Updated the conversion process to use these new ops for tosa group conv2d operations.
1 parent d8f8d99 commit 7dd461a

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,67 @@ def conv_2d_ngchw_gfchw(
964964
) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
965965

966966

967+
@linalg_structured_op
968+
def conv_2d_nhwgc_gfhwc(
969+
I=TensorDef(
970+
T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
971+
),
972+
K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
973+
O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
974+
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
975+
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
976+
):
977+
"""Performs 2-D grouped convolution.
978+
979+
Layout:
980+
* Input: NHWGC.
981+
* Kernel: GFHWC.
982+
983+
Numeric casting is performed on the operands to the inner multiply, promoting
984+
them to the same data type as the accumulator/output.
985+
"""
986+
implements(ConvolutionOpInterface)
987+
domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
988+
O[D.n, D.oh, D.ow, D.g, D.fg] += TypeFn.cast_signed(
989+
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
990+
) * TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
991+
992+
993+
@linalg_structured_op
994+
def conv_2d_nhwgc_gfhwc_q(
995+
I=TensorDef(
996+
T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
997+
),
998+
K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
999+
IZp=ScalarDef(I32),
1000+
KZp=ScalarDef(I32),
1001+
O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
1002+
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
1003+
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
1004+
):
1005+
"""Performs 2-D grouped convolution with zero point offsets.
1006+
1007+
Layout:
1008+
* Input: NHWGC.
1009+
* Kernel: GFHWC.
1010+
1011+
Numeric casting is performed on the operands to the inner multiply, promoting
1012+
them to the same data type as the accumulator/output. This includes the zero
1013+
point offsets common to quantized operations.
1014+
"""
1015+
implements(ConvolutionOpInterface)
1016+
domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
1017+
O[D.n, D.oh, D.ow, D.g, D.fg] += (
1018+
TypeFn.cast_signed(
1019+
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
1020+
)
1021+
- TypeFn.cast_signed(U, IZp)
1022+
) * (
1023+
TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
1024+
- TypeFn.cast_signed(U, KZp)
1025+
)
1026+
1027+
9671028
@linalg_structured_op
9681029
def conv_2d_ngchw_gfchw_q(
9691030
I=TensorDef(

0 commit comments

Comments
 (0)