Commit 109f31a

This commit introduces two new Linalg operations: `conv_2d_nhwgc_gfhwc` and `conv_2d_nhwgc_gfhwc_q`. These operations perform 2-D grouped convolution without and with zero point offsets, respectively. The input layout is NHWGC and the kernel layout is GFHWC. These additions extend MLIR's support for grouped convolution operations.
1 parent 520ddf2 commit 109f31a
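As a quick illustration of the layouts (not part of this commit; the function name and tensor sizes below are made up), a use of the non-quantized op with static shapes and the default unit strides and dilations might look like this, where the input is N×H×W×G×C, the filter G×F×KH×KW×C, and the result N×OH×OW×G×F:

```mlir
// Hypothetical usage sketch: N=1, H=W=6, G=2, C=4, F=8, KH=KW=3,
// so OH = OW = 6 - 3 + 1 = 4 with unit strides and dilations.
func.func @grouped_conv_example(%input: tensor<1x6x6x2x4xf32>,
                                %filter: tensor<2x8x3x3x4xf32>,
                                %init: tensor<1x4x4x2x8xf32>) -> tensor<1x4x4x2x8xf32> {
  %0 = linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
                                   strides = dense<1> : tensor<2xi64>}
         ins(%input, %filter : tensor<1x6x6x2x4xf32>, tensor<2x8x3x3x4xf32>)
         outs(%init : tensor<1x4x4x2x8xf32>) -> tensor<1x4x4x2x8xf32>
  return %0 : tensor<1x4x4x2x8xf32>
}
```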

File tree (3 files changed: 314 additions, 0 deletions)


mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 237 additions & 0 deletions
@@ -3410,6 +3410,243 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwgc_gfhwc
+  cpp_class_name: Conv2DNhwgcGfhwcOp
+  doc: |-
+    Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NHWGC.
+      * Kernel: GFHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s9, s11, s3, s7, s10)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s5, s9, s11)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwgc_gfhwc_q
+  cpp_class_name: Conv2DNhwgcGfhwcQOp
+  doc: |-
+    Performs 2-D grouped convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWGC.
+      * Kernel: GFHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s9, s11, s3, s7, s10)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s5, s9, s11)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_ngchw_gfchw_q
   cpp_class_name: Conv2DNgchwGfchwQOp
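Reading the YAML: with the iteration domain (d0..d7) = (n, oh, ow, g, fg, kh, kw, c) and symbols (s0..s11) = (N, OH, SH, KH, DH, OW, SW, KW, DW, G, C, FG), the three static_indexing_maps state that output element (n, oh, ow, g, fg) accumulates input (n, oh*SH + kh*DH, ow*SW + kw*DW, g, c) times kernel (g, fg, kh, kw, c) over kh, kw and c. A roughly equivalent linalg.generic, written out by hand as a sketch (it is not produced by this commit, and it fixes strides and dilations to 1 over f32 memrefs), would be:

```mlir
// Maps mirror the static_indexing_maps above with SH = SW = DH = DW = 1.
#input_map  = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
#filter_map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
#output_map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>

func.func @conv_2d_nhwgc_gfhwc_as_generic(%input: memref<?x?x?x?x?xf32>,
                                          %filter: memref<?x?x?x?x?xf32>,
                                          %output: memref<?x?x?x?x?xf32>) {
  linalg.generic {indexing_maps = [#input_map, #filter_map, #output_map],
                  iterator_types = ["parallel", "parallel", "parallel", "parallel",
                                    "parallel", "reduction", "reduction", "reduction"]}
      ins(%input, %filter : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
      outs(%output : memref<?x?x?x?x?xf32>) {
    ^bb0(%in: f32, %k: f32, %acc: f32):
      // cast_signed is the identity here because all operands are already f32.
      %prod = arith.mulf %in, %k : f32
      %sum = arith.addf %acc, %prod : f32
      linalg.yield %sum : f32
  }
  return
}
```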

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 61 additions & 0 deletions
@@ -952,6 +952,67 @@ def conv_2d_ngchw_gfchw(
     ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
 
 
+@linalg_structured_op
+def conv_2d_nhwgc_gfhwc(
+    I=TensorDef(
+        T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
+    ),
+    K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NHWGC.
+      * Kernel: GFHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.g, D.fg] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
+    ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
+
+
+@linalg_structured_op
+def conv_2d_nhwgc_gfhwc_q(
+    I=TensorDef(
+        T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
+    ),
+    K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D grouped convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWGC.
+      * Kernel: GFHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.g, D.fg] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (
+        TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
+        - TypeFn.cast_signed(U, KZp)
+    )
+
+
 @linalg_structured_op
 def conv_2d_ngchw_gfchw_q(
     I=TensorDef(
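The quantized variant defined above takes the two zero points as extra scalar operands between the filter and the output. A hypothetical invocation (illustrative only; this commit does not add a test for it, and the names and shapes below are made up) could look like:

```mlir
// Hypothetical sketch: i8 data, i32 zero points, i32 accumulator.
func.func @grouped_conv_q_example(%input: tensor<1x6x6x2x4xi8>,
                                  %filter: tensor<2x8x3x3x4xi8>,
                                  %izp: i32, %kzp: i32,
                                  %init: tensor<1x4x4x2x8xi32>) -> tensor<1x4x4x2x8xi32> {
  %0 = linalg.conv_2d_nhwgc_gfhwc_q {dilations = dense<1> : tensor<2xi64>,
                                     strides = dense<1> : tensor<2xi64>}
         ins(%input, %filter, %izp, %kzp
             : tensor<1x6x6x2x4xi8>, tensor<2x8x3x3x4xi8>, i32, i32)
         outs(%init : tensor<1x4x4x2x8xi32>) -> tensor<1x4x4x2x8xi32>
  return %0 : tensor<1x4x4x2x8xi32>
}
```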

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 16 additions & 0 deletions
@@ -409,6 +409,22 @@ func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_nhwgc_gfhwc
+func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  // CHECK: linalg.conv_2d_nhwgc_gfhwc
+  // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME: strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xf32>)
+  linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
+                              strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?x?xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @conv_2d_ngchw_fgchw_dimensions
 func.func @conv_2d_ngchw_fgchw_dimensions(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<2x5x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
 // CHECK: linalg.conv_2d_ngchw_fgchw
