Skip to content

Commit aca33f1

Browse files
authored
[TorchToLinalg] Use Op with native channel order for quantized conv2d (#3807)
I've upstreamed the necessary quantized linalg Op with the "channel-first" ordering used by torch (llvm/llvm-project#107740) for 2d convolution. This patch changes the lowering for the quantized 2d case of `aten.convolution` accordingly, which saves three transpositions per convolution (input, weights, result) and therefore eliminates the need to optimize these transpositions away in downstream passes.
1 parent 42ba541 commit aca33f1

File tree

2 files changed

+33
-34
lines changed

2 files changed

+33
-34
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,54 +1125,57 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
11251125
}
11261126

11271127
if (numGroups == 1 && inputZp) {
1128-
// The quantized version uses a different channel ordering so we need to
1129-
// permute the tensors in order to use the existing path. We should
1130-
// eventually directly support this channel ordering.
1131-
llvm::SmallVector<int64_t> inPerms, weightPerms;
1132-
inPerms.push_back(0); // N stays at the front for input.
1133-
// Then we expect the spatial dimensions
1134-
for (size_t i = 0; i < numSpatialDims; ++i) {
1135-
inPerms.push_back(i + 2);
1136-
weightPerms.push_back(i + 2);
1137-
}
1138-
inPerms.push_back(1);
1139-
weightPerms.append({1, 0});
1140-
1141-
paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
1142-
weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
1143-
outputTensor =
1144-
transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
1145-
11461128
switch (numSpatialDims) {
11471129
case 2:
11481130
conv = rewriter
1149-
.create<linalg::Conv2DNhwcHwcfQOp>(
1131+
.create<linalg::Conv2DNchwFchwQOp>(
11501132
loc, outputTensor.getType(),
11511133
ValueRange{paddedInput, weight, inputZp, weightZp},
11521134
outputTensor, stridesAttr, dilationAttr)
11531135
.getResult(0);
11541136
break;
1155-
case 3:
1137+
case 3: {
1138+
// The quantized version uses a different channel ordering so we need to
1139+
// permute the tensors in order to use the existing path. We should
1140+
// eventually directly support this channel ordering.
1141+
llvm::SmallVector<int64_t> inPerms, weightPerms;
1142+
inPerms.push_back(0); // N stays at the front for input.
1143+
// Then we expect the spatial dimensions
1144+
for (size_t i = 0; i < numSpatialDims; ++i) {
1145+
inPerms.push_back(i + 2);
1146+
weightPerms.push_back(i + 2);
1147+
}
1148+
inPerms.push_back(1);
1149+
weightPerms.append({1, 0});
1150+
1151+
paddedInput =
1152+
transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
1153+
weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
1154+
outputTensor =
1155+
transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
1156+
11561157
conv = rewriter
11571158
.create<linalg::Conv3DNdhwcDhwcfQOp>(
11581159
loc, outputTensor.getType(),
11591160
ValueRange{paddedInput, weight, inputZp, weightZp},
11601161
outputTensor, stridesAttr, dilationAttr)
11611162
.getResult(0);
1163+
1164+
llvm::SmallVector<int64_t> outPerms;
1165+
outPerms.push_back(0);
1166+
outPerms.push_back(inPerms.size() - 1);
1167+
for (size_t i = 0; i < numSpatialDims; ++i) {
1168+
outPerms.push_back(i + 1);
1169+
}
1170+
conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
1171+
11621172
break;
1173+
}
11631174
default:
11641175
return rewriter.notifyMatchFailure(
11651176
op, "unimplemented: only 1D, 2D, and 3D convolution supported");
11661177
};
11671178

1168-
llvm::SmallVector<int64_t> outPerms;
1169-
outPerms.push_back(0);
1170-
outPerms.push_back(inPerms.size() - 1);
1171-
for (size_t i = 0; i < numSpatialDims; ++i) {
1172-
outPerms.push_back(i + 1);
1173-
}
1174-
conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
1175-
11761179
Type newResultType = getTypeConverter()->convertType(op.getType());
11771180
if (accumulatorDType != resultDTy) {
11781181
Type resultElementType =

test/Conversion/TorchToLinalg/convolution.mlir

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,8 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128]
2424
// CHECK: %[[c7:.*]] = arith.constant 7 : i32
2525
// CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
2626
// CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
27-
// CHECK: %[[TransInput:.*]] = linalg.transpose ins(%[[input]] : tensor<?x?x?x?xi8>)
28-
// CHECK-SAME: permutation = [0, 2, 3, 1]
29-
// CHECK: %[[TransWeight:.*]] = linalg.transpose ins(%[[weight]] : tensor<?x?x?x?xi8>)
30-
// CHECK-SAME: permutation = [2, 3, 1, 0]
31-
// CHECK: %[[conv:.*]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
32-
// CHECK-SAME: ins(%[[TransInput]], %[[TransWeight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
27+
// CHECK: %[[conv:.*]] = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
28+
// CHECK-SAME: ins(%[[input]], %[[weight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
3329
// CHECK-SAME: outs(%[[convout:.*]] : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
3430
func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
3531
%false = torch.constant.bool false

0 commit comments

Comments (0)