Skip to content

Commit 8ff731c

Browse files
committed
[linalg] Add quantized version of conv_3d_ncdhw_fcdhw
This patch adds the quantized 3-D convolution operator `conv_3d_ncdhw_fcdhw_q`. The operator uses the "channel-first" dimension ordering used by PyTorch and others.
1 parent f0b3b6d commit 8ff731c

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4024,6 +4024,145 @@ structured_op: !LinalgStructuredOpConfig
40244024
- !ScalarExpression
40254025
scalar_arg: K
40264026
--- !LinalgOpConfig
4027+
metadata: !LinalgOpMetadata
4028+
name: conv_3d_ncdhw_fcdhw_q
4029+
cpp_class_name: Conv3DNcdhwFcdhwQOp
4030+
doc: |-
4031+
Performs 3-D convolution with zero point offsets.
4032+
4033+
Numeric casting is performed on the operands to the inner multiply, promoting
4034+
them to the same data type as the accumulator/output. This includes the zero
4035+
point offsets common to quantized operations.
4036+
implements:
4037+
- LinalgConvolutionOpInterface
4038+
structured_op: !LinalgStructuredOpConfig
4039+
args:
4040+
- !LinalgOperandDefConfig
4041+
name: I
4042+
kind: input_tensor
4043+
type_var: T1
4044+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
4045+
s13, s14] -> (s0, s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9, s10 * s11 + s12
4046+
* s13)>
4047+
- !LinalgOperandDefConfig
4048+
name: K
4049+
kind: input_tensor
4050+
type_var: T2
4051+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
4052+
s13, s14] -> (s14, s1, s4, s8, s12)>
4053+
- !LinalgOperandDefConfig
4054+
name: IZp
4055+
kind: scalar
4056+
type_var: I32
4057+
- !LinalgOperandDefConfig
4058+
name: KZp
4059+
kind: scalar
4060+
type_var: I32
4061+
- !LinalgOperandDefConfig
4062+
name: O
4063+
kind: output_tensor
4064+
type_var: U
4065+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
4066+
s13, s14] -> (s0, s14, s2, s6, s10)>
4067+
- !LinalgOperandDefConfig
4068+
name: strides
4069+
kind: index_attr
4070+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
4071+
s12, s13, s14] -> (s3, s7, s11)>
4072+
default_indices:
4073+
- 1
4074+
- 1
4075+
- 1
4076+
- !LinalgOperandDefConfig
4077+
name: dilations
4078+
kind: index_attr
4079+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
4080+
s12, s13, s14] -> (s5, s9, s13)>
4081+
default_indices:
4082+
- 1
4083+
- 1
4084+
- 1
4085+
indexing_maps: !LinalgIndexingMapsConfig
4086+
static_indexing_maps:
4087+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
4088+
s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d8, d1 * s3 + d5 * s5, d2 * s7
4089+
+ d6 * s9, d3 * s11 + d7 * s13)>
4090+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
4091+
s7, s8, s9, s10, s11, s12, s13, s14] -> (d4, d8, d5, d6, d7)>
4092+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
4093+
s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
4094+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
4095+
s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
4096+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
4097+
s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d4, d1, d2, d3)>
4098+
iterator_types:
4099+
- parallel
4100+
- parallel
4101+
- parallel
4102+
- parallel
4103+
- parallel
4104+
- reduction
4105+
- reduction
4106+
- reduction
4107+
- reduction
4108+
assignments:
4109+
- !ScalarAssign
4110+
arg: O
4111+
value: !ScalarExpression
4112+
scalar_fn:
4113+
kind: binary
4114+
fn_name: add
4115+
operands:
4116+
- !ScalarExpression
4117+
scalar_arg: O
4118+
- !ScalarExpression
4119+
scalar_fn:
4120+
kind: binary
4121+
fn_name: mul
4122+
operands:
4123+
- !ScalarExpression
4124+
scalar_fn:
4125+
kind: binary
4126+
fn_name: sub
4127+
operands:
4128+
- !ScalarExpression
4129+
scalar_fn:
4130+
kind: type
4131+
fn_name: cast_signed
4132+
type_var: U
4133+
operands:
4134+
- !ScalarExpression
4135+
scalar_arg: I
4136+
- !ScalarExpression
4137+
scalar_fn:
4138+
kind: type
4139+
fn_name: cast_signed
4140+
type_var: U
4141+
operands:
4142+
- !ScalarExpression
4143+
scalar_arg: IZp
4144+
- !ScalarExpression
4145+
scalar_fn:
4146+
kind: binary
4147+
fn_name: sub
4148+
operands:
4149+
- !ScalarExpression
4150+
scalar_fn:
4151+
kind: type
4152+
fn_name: cast_signed
4153+
type_var: U
4154+
operands:
4155+
- !ScalarExpression
4156+
scalar_arg: K
4157+
- !ScalarExpression
4158+
scalar_fn:
4159+
kind: type
4160+
fn_name: cast_signed
4161+
type_var: U
4162+
operands:
4163+
- !ScalarExpression
4164+
scalar_arg: KZp
4165+
--- !LinalgOpConfig
40274166
metadata: !LinalgOpMetadata
40284167
name: depthwise_conv_1d_nwc_wc
40294168
cpp_class_name: DepthwiseConv1DNwcWcOp

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,46 @@ def conv_3d_ncdhw_fcdhw(
11261126
],
11271127
) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
11281128

1129+
@linalg_structured_op
1130+
def conv_3d_ncdhw_fcdhw_q(
1131+
I=TensorDef(
1132+
T1,
1133+
S.N,
1134+
S.C,
1135+
S.OD * S.SD + S.KD * S.DD,
1136+
S.OH * S.SH + S.KH * S.DH,
1137+
S.OW * S.SW + S.KW * S.DW,
1138+
),
1139+
K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
1140+
IZp=ScalarDef(I32),
1141+
KZp=ScalarDef(I32),
1142+
O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
1143+
strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
1144+
dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
1145+
):
1146+
"""Performs 3-D convolution with zero point offsets.
1147+
1148+
Numeric casting is performed on the operands to the inner multiply, promoting
1149+
them to the same data type as the accumulator/output. This includes the zero
1150+
point offsets common to quantized operations.
1151+
"""
1152+
implements(ConvolutionOpInterface)
1153+
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
1154+
O[D.n, D.f, D.od, D.oh, D.ow] += (
1155+
TypeFn.cast_signed(
1156+
U,
1157+
I[
1158+
D.n,
1159+
D.c,
1160+
D.od * S.SD + D.kd * S.DD,
1161+
D.oh * S.SH + D.kh * S.DH,
1162+
D.ow * S.SW + D.kw * S.DW,
1163+
],
1164+
) - TypeFn.cast_signed(U, IZp)
1165+
) * (
1166+
TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
1167+
- TypeFn.cast_signed(U, KZp)
1168+
)
11291169

11301170
@linalg_structured_op
11311171
def depthwise_conv_1d_nwc_wc(

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,3 +694,18 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
694694
// CHECK-LABEL: func @conv2d_channel_first_q_promote(
695695
// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
696696
// CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
697+
698+
// -----
699+
700+
func.func @conv3d_channel_first_q(%img: tensor<1x27x49x48x47xi8>, %filt: tensor<28x27x3x4x5xi8>, %a: i32, %b: i32) -> tensor<1x28x47x45x43xi32> {
701+
%init = arith.constant dense<0> : tensor<1x28x47x45x43xi32>
702+
%1 = linalg.conv_3d_ncdhw_fcdhw_q {dilations = dense<1> : tensor<3xi64>,
703+
strides = dense<1> : tensor<3xi64>}
704+
ins(%img, %filt, %a, %b : tensor<1x27x49x48x47xi8>, tensor<28x27x3x4x5xi8>, i32, i32)
705+
outs(%init : tensor<1x28x47x45x43xi32>) -> tensor<1x28x47x45x43xi32>
706+
return %1 : tensor<1x28x47x45x43xi32>
707+
}
708+
709+
// CHECK-LABEL: func @conv3d_channel_first_q(
710+
// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<1x27x49x48x47xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<28x27x3x4x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i32, %[[arg3:[a-zA-z0-9]*]]: i32)
711+
// CHECK: linalg.conv_3d_ncdhw_fcdhw_q {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<1x27x49x48x47xi8>, tensor<28x27x3x4x5xi8>, i32, i32) outs(%{{.*}} : tensor<1x28x47x45x43xi32>) -> tensor<1x28x47x45x43xi32>

0 commit comments

Comments
 (0)