Skip to content

Commit 7ee8a07

Browse files
[mlir][tosa] Convert group tosa::Conv2DOp to linalg conv
This patch adds two new ops, linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert grouped tosa.conv2d ops.
- Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp.
- Updated the conversion process to use these new ops for grouped tosa.conv2d operations.
1 parent 1e3a24d commit 7ee8a07

File tree

9 files changed

+448
-46
lines changed

9 files changed

+448
-46
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3410,6 +3410,243 @@ structured_op: !LinalgStructuredOpConfig
34103410
- !ScalarExpression
34113411
scalar_arg: K
34123412
--- !LinalgOpConfig
# Reading guide (derived from the maps in this entry):
#   s2, s6 = strides; s4, s8 = dilations (see the two index_attr operands).
#   Kernel K is laid out GFHWC, so its shape map gives G = s9, F = s11,
#   kernel H = s3, kernel W = s7, C = s10.
3413+
metadata: !LinalgOpMetadata
3414+
name: conv_2d_nhwgc_gfhwc
3415+
cpp_class_name: Conv2DNhwgcGfhwcOp
3416+
doc: |-
3417+
Performs 2-D grouped convolution.
3418+
3419+
Layout:
3420+
* Input: NHWGC.
3421+
* Kernel: GFHWC.
3422+
3423+
Numeric casting is performed on the operands to the inner multiply, promoting
3424+
them to the same data type as the accumulator/output.
3425+
implements:
3426+
- LinalgConvolutionOpInterface
3427+
structured_op: !LinalgStructuredOpConfig
3428+
args:
3429+
- !LinalgOperandDefConfig
3430+
name: I
3431+
kind: input_tensor
3432+
type_var: T1
3433+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3434+
(s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
3435+
- !LinalgOperandDefConfig
3436+
name: K
3437+
kind: input_tensor
3438+
type_var: T2
3439+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3440+
(s9, s11, s3, s7, s10)>
3441+
- !LinalgOperandDefConfig
3442+
name: O
3443+
kind: output_tensor
3444+
type_var: U
3445+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3446+
(s0, s1, s5, s9, s11)>
3447+
- !LinalgOperandDefConfig
3448+
name: strides
3449+
kind: index_attr
3450+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3451+
-> (s2, s6)>
3452+
default_indices:
3453+
- 1
3454+
- 1
3455+
- !LinalgOperandDefConfig
3456+
name: dilations
3457+
kind: index_attr
3458+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3459+
-> (s4, s8)>
3460+
default_indices:
3461+
- 1
3462+
- 1
# Indexing maps are listed in operand order: I, K, O.
# Loop dims d0..d4 index the output O; d5, d6, d7 appear only in the I and K
# maps and correspond to the three reduction iterators listed below.
3463+
indexing_maps: !LinalgIndexingMapsConfig
3464+
static_indexing_maps:
3465+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3466+
s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
3467+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3468+
s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
3469+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3470+
s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
3471+
iterator_types:
3472+
- parallel
3473+
- parallel
3474+
- parallel
3475+
- parallel
3476+
- parallel
3477+
- reduction
3478+
- reduction
3479+
- reduction
# Body: O = O + cast_signed<U>(I) * cast_signed<U>(K)
3480+
assignments:
3481+
- !ScalarAssign
3482+
arg: O
3483+
value: !ScalarExpression
3484+
scalar_fn:
3485+
kind: binary
3486+
fn_name: add
3487+
operands:
3488+
- !ScalarExpression
3489+
scalar_arg: O
3490+
- !ScalarExpression
3491+
scalar_fn:
3492+
kind: binary
3493+
fn_name: mul
3494+
operands:
3495+
- !ScalarExpression
3496+
scalar_fn:
3497+
kind: type
3498+
fn_name: cast_signed
3499+
type_var: U
3500+
operands:
3501+
- !ScalarExpression
3502+
scalar_arg: I
3503+
- !ScalarExpression
3504+
scalar_fn:
3505+
kind: type
3506+
fn_name: cast_signed
3507+
type_var: U
3508+
operands:
3509+
- !ScalarExpression
3510+
scalar_arg: K
3511+
--- !LinalgOpConfig
# Reading guide (derived from the maps in this entry):
#   s2, s6 = strides; s4, s8 = dilations (see the two index_attr operands).
#   Kernel K is laid out GFHWC, so its shape map gives G = s9, F = s11,
#   kernel H = s3, kernel W = s7, C = s10.
#   IZp / KZp are the I32 scalar zero-point operands for I and K.
3512+
metadata: !LinalgOpMetadata
3513+
name: conv_2d_nhwgc_gfhwc_q
3514+
cpp_class_name: Conv2DNhwgcGfhwcQOp
3515+
doc: |-
3516+
Performs 2-D grouped convolution with zero point offsets.
3517+
3518+
Layout:
3519+
* Input: NHWGC.
3520+
* Kernel: GFHWC.
3521+
3522+
Numeric casting is performed on the operands to the inner multiply, promoting
3523+
them to the same data type as the accumulator/output. This includes the zero
3524+
point offsets common to quantized operations.
3525+
implements:
3526+
- LinalgConvolutionOpInterface
3527+
structured_op: !LinalgStructuredOpConfig
3528+
args:
3529+
- !LinalgOperandDefConfig
3530+
name: I
3531+
kind: input_tensor
3532+
type_var: T1
3533+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3534+
(s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
3535+
- !LinalgOperandDefConfig
3536+
name: K
3537+
kind: input_tensor
3538+
type_var: T2
3539+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3540+
(s9, s11, s3, s7, s10)>
3541+
- !LinalgOperandDefConfig
3542+
name: IZp
3543+
kind: scalar
3544+
type_var: I32
3545+
- !LinalgOperandDefConfig
3546+
name: KZp
3547+
kind: scalar
3548+
type_var: I32
3549+
- !LinalgOperandDefConfig
3550+
name: O
3551+
kind: output_tensor
3552+
type_var: U
3553+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3554+
(s0, s1, s5, s9, s11)>
3555+
- !LinalgOperandDefConfig
3556+
name: strides
3557+
kind: index_attr
3558+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3559+
-> (s2, s6)>
3560+
default_indices:
3561+
- 1
3562+
- 1
3563+
- !LinalgOperandDefConfig
3564+
name: dilations
3565+
kind: index_attr
3566+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3567+
-> (s4, s8)>
3568+
default_indices:
3569+
- 1
3570+
- 1
# Indexing maps are listed in operand order: I, K, IZp, KZp, O.
# IZp and KZp are `kind: scalar` operands, so their maps are empty (`-> ()`).
# Loop dims d0..d4 index the output O; d5, d6, d7 appear only in the I and K
# maps and correspond to the three reduction iterators listed below.
3571+
indexing_maps: !LinalgIndexingMapsConfig
3572+
static_indexing_maps:
3573+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3574+
s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
3575+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3576+
s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
3577+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3578+
s8, s9, s10, s11] -> ()>
3579+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3580+
s8, s9, s10, s11] -> ()>
3581+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3582+
s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
3583+
iterator_types:
3584+
- parallel
3585+
- parallel
3586+
- parallel
3587+
- parallel
3588+
- parallel
3589+
- reduction
3590+
- reduction
3591+
- reduction
# Body: O = O + (cast_signed<U>(I) - cast_signed<U>(IZp))
#               * (cast_signed<U>(K) - cast_signed<U>(KZp))
3592+
assignments:
3593+
- !ScalarAssign
3594+
arg: O
3595+
value: !ScalarExpression
3596+
scalar_fn:
3597+
kind: binary
3598+
fn_name: add
3599+
operands:
3600+
- !ScalarExpression
3601+
scalar_arg: O
3602+
- !ScalarExpression
3603+
scalar_fn:
3604+
kind: binary
3605+
fn_name: mul
3606+
operands:
3607+
- !ScalarExpression
3608+
scalar_fn:
3609+
kind: binary
3610+
fn_name: sub
3611+
operands:
3612+
- !ScalarExpression
3613+
scalar_fn:
3614+
kind: type
3615+
fn_name: cast_signed
3616+
type_var: U
3617+
operands:
3618+
- !ScalarExpression
3619+
scalar_arg: I
3620+
- !ScalarExpression
3621+
scalar_fn:
3622+
kind: type
3623+
fn_name: cast_signed
3624+
type_var: U
3625+
operands:
3626+
- !ScalarExpression
3627+
scalar_arg: IZp
3628+
- !ScalarExpression
3629+
scalar_fn:
3630+
kind: binary
3631+
fn_name: sub
3632+
operands:
3633+
- !ScalarExpression
3634+
scalar_fn:
3635+
kind: type
3636+
fn_name: cast_signed
3637+
type_var: U
3638+
operands:
3639+
- !ScalarExpression
3640+
scalar_arg: K
3641+
- !ScalarExpression
3642+
scalar_fn:
3643+
kind: type
3644+
fn_name: cast_signed
3645+
type_var: U
3646+
operands:
3647+
- !ScalarExpression
3648+
scalar_arg: KZp
3649+
--- !LinalgOpConfig
34133650
metadata: !LinalgOpMetadata
34143651
name: conv_2d_ngchw_gfchw_q
34153652
cpp_class_name: Conv2DNgchwGfchwQOp

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,18 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
133133
pad, stride, dilation);
134134
}]>;
135135

136+
// Handles grouped convolution: takes the output type, the input/weight/bias
// values and the pad/stride/dilation attributes plus an IntegerAttr `group`,
// and forwards all of them to buildConvOpWithQuantInfo.
137+
def Tosa_ConvOpGroupQuantBuilder : OpBuilder<
138+
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
139+
"::mlir::Value":$weight, "::mlir::Value":$bias,
140+
"::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
141+
"::mlir::DenseI64ArrayAttr":$dilation, "::mlir::IntegerAttr":$group),
142+
[{
143+
buildConvOpWithQuantInfo($_builder, $_state, outputType,
144+
input, weight, bias,
145+
pad, stride, dilation, group);
146+
}]>;
147+
136148
// Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
137149
def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
138150
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
108108
Tosa_IntArrayAttr4:$pad,
109109
Tosa_IntArrayAttr2:$stride,
110110
Tosa_IntArrayAttr2:$dilation,
111+
OptionalAttr<I64Attr>:$group,
111112
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
112113
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
113114
);
@@ -116,7 +117,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
116117
Tosa_Tensor4D:$output
117118
);
118119

119-
let builders = [Tosa_ConvOpQuantInfoBuilder];
120+
let builders = [Tosa_ConvOpQuantInfoBuilder, Tosa_ConvOpGroupQuantBuilder];
120121
let hasVerifier = 1;
121122
}
122123

0 commit comments

Comments
 (0)