Skip to content

Commit 4aed023

Browse files
This commit introduces two new Linalg operations:
`conv_2d_nhwgc_gfhwc` and `conv_2d_nhwgc_gfhwc_q`. These operations perform 2-D grouped convolutions without and with zero point offsets, respectively. The input layout is NHWGC, and the kernel layout is GFHWC. These additions enhance support for grouped convolution operations in MLIR.
1 parent e61a7dc commit 4aed023

File tree

3 files changed

+330
-0
lines changed

3 files changed

+330
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3547,6 +3547,243 @@ structured_op: !LinalgStructuredOpConfig
35473547
- !ScalarExpression
35483548
scalar_arg: K
35493549
--- !LinalgOpConfig
3550+
metadata: !LinalgOpMetadata
3551+
name: conv_2d_nhwgc_gfhwc
3552+
cpp_class_name: Conv2DNhwgcGfhwcOp
3553+
doc: |-
3554+
Performs 2-D grouped convolution.
3555+
3556+
Layout:
3557+
* Input: NHWGC.
3558+
* Kernel: GFHWC.
3559+
3560+
Numeric casting is performed on the operands to the inner multiply, promoting
3561+
them to the same data type as the accumulator/output.
3562+
implements:
3563+
- LinalgConvolutionOpInterface
3564+
structured_op: !LinalgStructuredOpConfig
3565+
args:
3566+
- !LinalgOperandDefConfig
3567+
name: I
3568+
kind: input_tensor
3569+
type_var: T1
3570+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3571+
(s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
3572+
- !LinalgOperandDefConfig
3573+
name: K
3574+
kind: input_tensor
3575+
type_var: T2
3576+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3577+
(s9, s11, s3, s7, s10)>
3578+
- !LinalgOperandDefConfig
3579+
name: O
3580+
kind: output_tensor
3581+
type_var: U
3582+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3583+
(s0, s1, s5, s9, s11)>
3584+
- !LinalgOperandDefConfig
3585+
name: strides
3586+
kind: index_attr
3587+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3588+
-> (s2, s6)>
3589+
default_indices:
3590+
- 1
3591+
- 1
3592+
- !LinalgOperandDefConfig
3593+
name: dilations
3594+
kind: index_attr
3595+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3596+
-> (s4, s8)>
3597+
default_indices:
3598+
- 1
3599+
- 1
3600+
indexing_maps: !LinalgIndexingMapsConfig
3601+
static_indexing_maps:
3602+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3603+
s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
3604+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3605+
s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
3606+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3607+
s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
3608+
iterator_types:
3609+
- parallel
3610+
- parallel
3611+
- parallel
3612+
- parallel
3613+
- parallel
3614+
- reduction
3615+
- reduction
3616+
- reduction
3617+
assignments:
3618+
- !ScalarAssign
3619+
arg: O
3620+
value: !ScalarExpression
3621+
scalar_fn:
3622+
kind: binary
3623+
fn_name: add
3624+
operands:
3625+
- !ScalarExpression
3626+
scalar_arg: O
3627+
- !ScalarExpression
3628+
scalar_fn:
3629+
kind: binary
3630+
fn_name: mul
3631+
operands:
3632+
- !ScalarExpression
3633+
scalar_fn:
3634+
kind: type
3635+
fn_name: cast_signed
3636+
type_var: U
3637+
operands:
3638+
- !ScalarExpression
3639+
scalar_arg: I
3640+
- !ScalarExpression
3641+
scalar_fn:
3642+
kind: type
3643+
fn_name: cast_signed
3644+
type_var: U
3645+
operands:
3646+
- !ScalarExpression
3647+
scalar_arg: K
3648+
--- !LinalgOpConfig
3649+
metadata: !LinalgOpMetadata
3650+
name: conv_2d_nhwgc_gfhwc_q
3651+
cpp_class_name: Conv2DNhwgcGfhwcQOp
3652+
doc: |-
3653+
Performs 2-D grouped convolution with zero point offsets.
3654+
3655+
Layout:
3656+
* Input: NHWGC.
3657+
* Kernel: GFHWC.
3658+
3659+
Numeric casting is performed on the operands to the inner multiply, promoting
3660+
them to the same data type as the accumulator/output. This includes the zero
3661+
point offsets common to quantized operations.
3662+
implements:
3663+
- LinalgConvolutionOpInterface
3664+
structured_op: !LinalgStructuredOpConfig
3665+
args:
3666+
- !LinalgOperandDefConfig
3667+
name: I
3668+
kind: input_tensor
3669+
type_var: T1
3670+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3671+
(s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
3672+
- !LinalgOperandDefConfig
3673+
name: K
3674+
kind: input_tensor
3675+
type_var: T2
3676+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3677+
(s9, s11, s3, s7, s10)>
3678+
- !LinalgOperandDefConfig
3679+
name: IZp
3680+
kind: scalar
3681+
type_var: I32
3682+
- !LinalgOperandDefConfig
3683+
name: KZp
3684+
kind: scalar
3685+
type_var: I32
3686+
- !LinalgOperandDefConfig
3687+
name: O
3688+
kind: output_tensor
3689+
type_var: U
3690+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3691+
(s0, s1, s5, s9, s11)>
3692+
- !LinalgOperandDefConfig
3693+
name: strides
3694+
kind: index_attr
3695+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3696+
-> (s2, s6)>
3697+
default_indices:
3698+
- 1
3699+
- 1
3700+
- !LinalgOperandDefConfig
3701+
name: dilations
3702+
kind: index_attr
3703+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3704+
-> (s4, s8)>
3705+
default_indices:
3706+
- 1
3707+
- 1
3708+
indexing_maps: !LinalgIndexingMapsConfig
3709+
static_indexing_maps:
3710+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3711+
s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
3712+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3713+
s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
3714+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3715+
s8, s9, s10, s11] -> ()>
3716+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3717+
s8, s9, s10, s11] -> ()>
3718+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3719+
s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
3720+
iterator_types:
3721+
- parallel
3722+
- parallel
3723+
- parallel
3724+
- parallel
3725+
- parallel
3726+
- reduction
3727+
- reduction
3728+
- reduction
3729+
assignments:
3730+
- !ScalarAssign
3731+
arg: O
3732+
value: !ScalarExpression
3733+
scalar_fn:
3734+
kind: binary
3735+
fn_name: add
3736+
operands:
3737+
- !ScalarExpression
3738+
scalar_arg: O
3739+
- !ScalarExpression
3740+
scalar_fn:
3741+
kind: binary
3742+
fn_name: mul
3743+
operands:
3744+
- !ScalarExpression
3745+
scalar_fn:
3746+
kind: binary
3747+
fn_name: sub
3748+
operands:
3749+
- !ScalarExpression
3750+
scalar_fn:
3751+
kind: type
3752+
fn_name: cast_signed
3753+
type_var: U
3754+
operands:
3755+
- !ScalarExpression
3756+
scalar_arg: I
3757+
- !ScalarExpression
3758+
scalar_fn:
3759+
kind: type
3760+
fn_name: cast_signed
3761+
type_var: U
3762+
operands:
3763+
- !ScalarExpression
3764+
scalar_arg: IZp
3765+
- !ScalarExpression
3766+
scalar_fn:
3767+
kind: binary
3768+
fn_name: sub
3769+
operands:
3770+
- !ScalarExpression
3771+
scalar_fn:
3772+
kind: type
3773+
fn_name: cast_signed
3774+
type_var: U
3775+
operands:
3776+
- !ScalarExpression
3777+
scalar_arg: K
3778+
- !ScalarExpression
3779+
scalar_fn:
3780+
kind: type
3781+
fn_name: cast_signed
3782+
type_var: U
3783+
operands:
3784+
- !ScalarExpression
3785+
scalar_arg: KZp
3786+
--- !LinalgOpConfig
35503787
metadata: !LinalgOpMetadata
35513788
name: conv_2d_ngchw_gfchw_q
35523789
cpp_class_name: Conv2DNgchwGfchwQOp

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,67 @@ def conv_2d_ngchw_gfchw(
981981
) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
982982

983983

984+
@linalg_structured_op
985+
def conv_2d_nhwgc_gfhwc(
986+
I=TensorDef(
987+
T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
988+
),
989+
K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
990+
O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
991+
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
992+
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
993+
):
994+
"""Performs 2-D grouped convolution.
995+
996+
Layout:
997+
* Input: NHWGC.
998+
* Kernel: GFHWC.
999+
1000+
Numeric casting is performed on the operands to the inner multiply, promoting
1001+
them to the same data type as the accumulator/output.
1002+
"""
1003+
implements(ConvolutionOpInterface)
1004+
domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
1005+
O[D.n, D.oh, D.ow, D.g, D.fg] += TypeFn.cast_signed(
1006+
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
1007+
) * TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
1008+
1009+
1010+
@linalg_structured_op
1011+
def conv_2d_nhwgc_gfhwc_q(
1012+
I=TensorDef(
1013+
T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
1014+
),
1015+
K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
1016+
IZp=ScalarDef(I32),
1017+
KZp=ScalarDef(I32),
1018+
O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
1019+
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
1020+
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
1021+
):
1022+
"""Performs 2-D grouped convolution with zero point offsets.
1023+
1024+
Layout:
1025+
* Input: NHWGC.
1026+
* Kernel: GFHWC.
1027+
1028+
Numeric casting is performed on the operands to the inner multiply, promoting
1029+
them to the same data type as the accumulator/output. This includes the zero
1030+
point offsets common to quantized operations.
1031+
"""
1032+
implements(ConvolutionOpInterface)
1033+
domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
1034+
O[D.n, D.oh, D.ow, D.g, D.fg] += (
1035+
TypeFn.cast_signed(
1036+
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
1037+
)
1038+
- TypeFn.cast_signed(U, IZp)
1039+
) * (
1040+
TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
1041+
- TypeFn.cast_signed(U, KZp)
1042+
)
1043+
1044+
9841045
@linalg_structured_op
9851046
def conv_2d_ngchw_gfchw_q(
9861047
I=TensorDef(

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,38 @@ func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x
409409

410410
// -----
411411

412+
// CHECK-LABEL: func @conv_2d_nhwgc_gfhwc
413+
func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
414+
// CHECK: linalg.conv_2d_nhwgc_gfhwc
415+
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
416+
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
417+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
418+
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xf32>)
419+
linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
420+
strides = dense<1> : tensor<2xi64>}
421+
ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
422+
outs (%output: memref<?x?x?x?x?xf32>)
423+
return
424+
}
425+
426+
// -----
427+
428+
// CHECK-LABEL: func @conv_2d_nhwgc_gfhwc_tensor
429+
func.func @conv_2d_nhwgc_gfhwc_tensor(%input: tensor<1x28x28x2x3xf32>, %filter: tensor<2x8x3x3x3xf32>, %output: tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32> {
430+
// CHECK: linalg.conv_2d_nhwgc_gfhwc
431+
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
432+
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
433+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x28x28x2x3xf32>, tensor<2x8x3x3x3xf32>)
434+
// CHECK-SAME: outs(%{{.+}} : tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32>
435+
%0 = linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
436+
strides = dense<1> : tensor<2xi64>}
437+
ins (%input, %filter: tensor<1x28x28x2x3xf32>, tensor<2x8x3x3x3xf32>)
438+
outs (%output: tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32>
439+
return %0 : tensor<1x26x26x2x8xf32>
440+
}
441+
442+
// -----
443+
412444
// CHECK-LABEL: func @conv_2d_ngchw_fgchw_dimensions
413445
func.func @conv_2d_ngchw_fgchw_dimensions(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<2x5x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
414446
// CHECK: linalg.conv_2d_ngchw_fgchw

0 commit comments

Comments (0)