@@ -1125,54 +1125,57 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     }
 
     if (numGroups == 1 && inputZp) {
-      // The quantized version uses a different channel ordering so we need to
-      // permute the tensors in order to use the existing path. We should
-      // eventually directly support this channel ordering.
-      llvm::SmallVector<int64_t> inPerms, weightPerms;
-      inPerms.push_back(0); // N stays at the front for input.
-      // Then we expect the spatial dimensions
-      for (size_t i = 0; i < numSpatialDims; ++i) {
-        inPerms.push_back(i + 2);
-        weightPerms.push_back(i + 2);
-      }
-      inPerms.push_back(1);
-      weightPerms.append({1, 0});
-
-      paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
-      weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
-      outputTensor =
-          transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
-
       switch (numSpatialDims) {
       case 2:
         conv = rewriter
-                   .create<linalg::Conv2DNhwcHwcfQOp>(
+                   .create<linalg::Conv2DNchwFchwQOp>(
                        loc, outputTensor.getType(),
                        ValueRange{paddedInput, weight, inputZp, weightZp},
                        outputTensor, stridesAttr, dilationAttr)
                    .getResult(0);
         break;
-      case 3:
+      case 3: {
+        // The quantized version uses a different channel ordering so we need to
+        // permute the tensors in order to use the existing path. We should
+        // eventually directly support this channel ordering.
+        llvm::SmallVector<int64_t> inPerms, weightPerms;
+        inPerms.push_back(0); // N stays at the front for input.
+        // Then we expect the spatial dimensions
+        for (size_t i = 0; i < numSpatialDims; ++i) {
+          inPerms.push_back(i + 2);
+          weightPerms.push_back(i + 2);
+        }
+        inPerms.push_back(1);
+        weightPerms.append({1, 0});
+
+        paddedInput =
+            transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
+        weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
+        outputTensor =
+            transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
+
         conv = rewriter
                    .create<linalg::Conv3DNdhwcDhwcfQOp>(
                        loc, outputTensor.getType(),
                        ValueRange{paddedInput, weight, inputZp, weightZp},
                        outputTensor, stridesAttr, dilationAttr)
                    .getResult(0);
+
+        llvm::SmallVector<int64_t> outPerms;
+        outPerms.push_back(0);
+        outPerms.push_back(inPerms.size() - 1);
+        for (size_t i = 0; i < numSpatialDims; ++i) {
+          outPerms.push_back(i + 1);
+        }
+        conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
+
         break;
+      }
       default:
         return rewriter.notifyMatchFailure(
             op, "unimplemented: only 1D, 2D, and 3D convolution supported");
       };
 
-      llvm::SmallVector<int64_t> outPerms;
-      outPerms.push_back(0);
-      outPerms.push_back(inPerms.size() - 1);
-      for (size_t i = 0; i < numSpatialDims; ++i) {
-        outPerms.push_back(i + 1);
-      }
-      conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
-
       Type newResultType = getTypeConverter()->convertType(op.getType());
       if (accumulatorDType != resultDTy) {
         Type resultElementType =
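
To sanity-check the permutation logic this hunk moves into the 3-D case, the vectors it builds can be reproduced in isolation. The following is a minimal standalone sketch (plain C++, no MLIR dependencies); fixing numSpatialDims to 3 and the layout names in the comments are assumptions for illustration, not part of the commit:

// Standalone sketch: compute the same permutation vectors the case-3
// (Conv3DNdhwcDhwcfQOp) path builds for NCDHW input and FCDHW weights.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const size_t numSpatialDims = 3; // assumed: the 3-D convolution path

  // Input/output permutation: NCDHW -> NDHWC (channels moved last).
  std::vector<int64_t> inPerms{0}; // N stays at the front.
  std::vector<int64_t> weightPerms;
  for (size_t i = 0; i < numSpatialDims; ++i) {
    inPerms.push_back(i + 2); // spatial dims D, H, W
    weightPerms.push_back(i + 2);
  }
  inPerms.push_back(1);                          // channels last
  weightPerms.insert(weightPerms.end(), {1, 0}); // FCDHW -> DHWCF

  // Result permutation: NDHWC -> NCDHW (undoes the input permutation).
  std::vector<int64_t> outPerms{0, static_cast<int64_t>(inPerms.size()) - 1};
  for (size_t i = 0; i < numSpatialDims; ++i)
    outPerms.push_back(i + 1);

  auto print = [](const char *name, const std::vector<int64_t> &v) {
    std::printf("%s:", name);
    for (int64_t d : v)
      std::printf(" %lld", static_cast<long long>(d));
    std::printf("\n");
  };
  print("inPerms", inPerms);         // 0 2 3 4 1
  print("weightPerms", weightPerms); // 2 3 4 1 0
  print("outPerms", outPerms);       // 0 4 1 2 3
}

Since outPerms inverts inPerms, applying it to the convolution result restores the original NCDHW layout, which is why the commit can drop the transposes from the 2-D path entirely (Conv2DNchwFchwQOp consumes NCHW directly) while keeping them, now localized, in the 3-D case.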