Commit 1358852
Add backwards data convolution op to MIGraphX dialect (#1946)
* Add initial lowering from MIGraphX -> TOSA
* Initial changes to adding multiple kernels
* Partial implementation of TosaToRock lowering
* Remove old check disabling split k and bwd convs
* Remove TosaToRock conversions from this pass
* Remove unnecessary code changes
* Minor fix
* Add checks to LIT test
* Clang-format
* Fix Copilot review comments
* Add back splitK logic
* Update comment
* Remove additional line
* Attend to review comments
* Add new line to new LIT test
* Fix newline
* Update out_pad and add some more tests
* Remove bwd_weight
* Attend to minor review comments
* Change wording of op definition
* Clang format MIGraphX.td

---------

Co-authored-by: Umang Yadav <[email protected]>
1 parent 7d0d249 · commit 1358852

4 files changed (+141, -25 lines)

mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td

Lines changed: 12 additions & 0 deletions
@@ -340,6 +340,18 @@ def MIGraphX_ConvolutionOp :
   }];
 }
 
+def MIGraphX_ConvolutionBwdDataOp
+    : MIGraphX_ConvOpBase<
+          "backwards_data_convolution", [F32, F16, BF16], [F32, F16, BF16]> {
+  let summary = "Backwards data convolution";
+  let description = [{
+    The `migraphx.backwards_data_convolution` op computes a transposed
+    convolution op which effectively reverses a standard convolution's
+    spatial transformation. It is an upsampling technique that increases
+    the height and width of an input feature map.
+  }];
+}
+
 def MIGraphX_BatchNormOp :
     MIGraphX_Op<"batch_norm_inference">,
     Arguments<(ins AnyMIXRShaped:$input,
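
For reference, here is the new op in use, lifted from the @bwd_data_conv LIT test added to migraphx-to-tosa.mlir below; under rocmlir-opt --migraphx-to-tosa it lowers to tosa.transpose_conv2d:

  func.func @bwd_data_conv(%arg0: !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>,
                           %arg1: !migraphx.shaped<16x16x1x1xf32, 16x1x1x1>,
                           %arg2: !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>)
      -> !migraphx.shaped<1x16x4x4xf32, 256x16x4x1> {
    %0 = migraphx.backwards_data_convolution %arg1, %arg0 {
      dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0],
      padding_mode = 0 : i64, stride = [1, 1], kernelId = 0 : i64
    } : <16x16x1x1xf32, 16x1x1x1>, <1x16x4x4xf32, 256x16x4x1> -> <1x16x4x4xf32, 256x16x4x1>
    return %0 : !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>
  }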

mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp

Lines changed: 72 additions & 25 deletions
@@ -245,6 +245,7 @@ LogicalResult ConvConverter<ConvType>::matchAndRewrite(
   auto outputTy = cast<MIXRShapedType>(results[0].getType());
   Type outElementTy = outputTy.getElementType();
   Type newOutElementTy = getTypeConverter()->convertType(outElementTy);
+  bool isBwdDataConvOp = isa<migraphx::ConvolutionBwdDataOp>(op);
 
   if (outElementTy.isUnsignedInteger())
     return op.emitError("No support for unsigned convolution.\n");

@@ -270,8 +271,8 @@
   newShape.push_back(outShape[1]);
   Type newOutTy = RankedTensorType::get(newShape, newOutElementTy);
 
-  // There is no tosa.conv1d, so instead we'll add a dummy x1 dimension
-  // to the input tensors, and make a tosa.conv2d.
+  // There is no tosa.conv1d or tosa.transpose_conv1d, so instead we'll add a
+  // dummy x1 dimension to the input tensors, and make a tosa.conv2d.
   auto expandTo2D = [&rewriter, loc](mlir::Value value) {
     ArrayRef<int64_t> origShape = cast<ShapedType>(value.getType()).getShape();
     SmallVector<int64_t> expShape(origShape.drop_back());
@@ -283,13 +284,14 @@
     return reshaped;
   };
 
-  // Construct a new Conv2DOp.
+  // Construct a new Conv2DOp/TransposeConv2DOp.
   Operation *cop;
   Type new1DOutTy;
   Value inputZp, weightZp;
   switch (dims) {
   case 1:
-    // Expand to do a conv2d, because there's no conv1d op.
+    // Expand to do a conv2d/transpose_conv2d, because there's no 1d version of
+    // the ops.
     newShape.insert(std::prev(newShape.end()), 1);
     new1DOutTy = RankedTensorType::get(newShape, newOutElementTy);
     input = expandTo2D(input);
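
To make the 1-D expansion concrete, here is a trimmed sketch of the @bwd_data_conv1d test added below (the function name is hypothetical and the trailing broadcast add from the full test is omitted): the 1x3x224 operand reaches tosa.transpose_conv2d as tensor<1x224x1x3xf32>, with the dummy unit dimension inserted just before the channel dimension.

  func.func @bwd_data_conv1d_sketch(%data: !migraphx.shaped<1x3x224xf32, 672x224x1>,
                                    %filter: !migraphx.shaped<64x3x1xf32, 3x1x1>)
      -> !migraphx.shaped<1x64x224xf32, 14336x224x1> {
    // Lowered form (per the CHECK lines below): tosa.transpose_conv2d on
    // (tensor<1x224x1x3xf32>, tensor<64x1x1x3xf32>, tensor<64xf32>,
    //  tensor<1xf32>, tensor<1xf32>) -> tensor<1x224x1x64xf32>
    %0 = migraphx.backwards_data_convolution %data, %filter {
      dilation = [1], group = 1 : i64, padding = [3, 3],
      padding_mode = 0 : i64, stride = [1]
    } : <1x3x224xf32, 672x224x1>, <64x3x1xf32, 3x1x1> -> <1x64x224xf32, 14336x224x1>
    return %0 : !migraphx.shaped<1x64x224xf32, 14336x224x1>
  }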
@@ -299,31 +301,57 @@
     weightZp =
         tosa::createZeroPointTensor(rewriter, loc, filter.getType(), 0).value();
 
-    cop = rewriter.create<tosa::Conv2DOp>(
-        loc, new1DOutTy,
-        ValueRange{
-            input, filter,
-            getZeroTensor(loc, newOutElementTy,
-                          cast<ShapedType>(filter.getType()).getShape()[0],
-                          rewriter),
-            inputZp, weightZp});
+    if (isBwdDataConvOp) {
+      cop = rewriter.create<tosa::TransposeConv2DOp>(
+          loc, new1DOutTy,
+          ValueRange{
+              input, filter,
+              getZeroTensor(loc, newOutElementTy,
+                            cast<ShapedType>(filter.getType()).getShape()[0],
+                            rewriter),
+              inputZp, weightZp});
+    } else {
+      cop = rewriter.create<tosa::Conv2DOp>(
+          loc, new1DOutTy,
+          ValueRange{
+              input, filter,
+              getZeroTensor(loc, newOutElementTy,
+                            cast<ShapedType>(filter.getType()).getShape()[0],
+                            rewriter),
+              inputZp, weightZp});
+    }
     break;
 
   case 2:
     inputZp =
         tosa::createZeroPointTensor(rewriter, loc, input.getType(), 0).value();
     weightZp =
         tosa::createZeroPointTensor(rewriter, loc, filter.getType(), 0).value();
-    cop = rewriter.create<tosa::Conv2DOp>(
-        loc, newOutTy,
-        ValueRange{
-            input, filter,
-            getZeroTensor(loc, newOutElementTy,
-                          cast<ShapedType>(filter.getType()).getShape()[0],
-                          rewriter),
-            inputZp, weightZp});
+    if (isBwdDataConvOp) {
+      cop = rewriter.create<tosa::TransposeConv2DOp>(
+          loc, newOutTy,
+          ValueRange{
+              input, filter,
+              getZeroTensor(loc, newOutElementTy,
+                            cast<ShapedType>(filter.getType()).getShape()[0],
+                            rewriter),
+              inputZp, weightZp});
+    } else {
+      cop = rewriter.create<tosa::Conv2DOp>(
+          loc, newOutTy,
+          ValueRange{
+              input, filter,
+              getZeroTensor(loc, newOutElementTy,
+                            cast<ShapedType>(filter.getType()).getShape()[0],
+                            rewriter),
+              inputZp, weightZp});
+    }
     break;
   case 3:
+    if (isBwdDataConvOp)
+      return op->emitError("Only 1-D and 2-D backwards convolution ops are "
+                           "supported");
+
     inputZp =
         tosa::createZeroPointTensor(rewriter, loc, input.getType(), 0).value();
     weightZp =
@@ -361,8 +389,6 @@
     dilations.push_back(dyn_cast<IntegerAttr>(dilationAttr[i]).getInt());
   }
 
-  int64_t group = op.getGroup();
-
   // Determine the accumulation type based on the output type.
   Type accType;
   if (isa<FloatType>(elementTy)) {
@@ -386,11 +412,31 @@
     pads.push_back(0);
   }
 
+  // Set attributes common to both forwards and backwards conv
   cop->setAttr("dilation", rewriter.getDenseI64ArrayAttr(dilations));
   cop->setAttr("stride", rewriter.getDenseI64ArrayAttr(strides));
-  cop->setAttr("pad", rewriter.getDenseI64ArrayAttr(pads));
-  cop->setAttr("group", rewriter.getI64IntegerAttr(group));
   cop->setAttr("acc_type", TypeAttr::get(accType));
+  int64_t group = op.getGroup();
+  cop->setAttr("group", rewriter.getI64IntegerAttr(group));
+
+  // Set padding for forwards and backwards convolution. Note: the padding here
+  // applies to input padding (which transpose.conv2D does not inherently
+  // support). TransposeConv2D will still require an output pad attribute, so we
+  // can just set that to zeros.
+  if (isBwdDataConvOp) {
+    SmallVector<int64_t> zeroPads(pads.size(), 0);
+    cop->setAttr("out_pad", rewriter.getDenseI64ArrayAttr(zeroPads));
+  }
+  cop->setAttr("pad", rewriter.getDenseI64ArrayAttr(pads));
+
+  // For both types of backwards convolution, we will be using
+  // tosa.transpose_conv2d, so we are going to add a conv_kind attribute so
+  // that we can distinguish between the two types in TosaToRock.
+  // TODO: We will need to add conv_kind = "bwd_weight" when we eventually
+  // add support for bwd_weight ops in MIGraphX.
+  if (isa<migraphx::ConvolutionBwdDataOp>(op)) {
+    cop->setAttr("conv_kind", rewriter.getStringAttr("bwd_data"));
+  }
 
   // Convert optional attributes
   if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
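
Putting the attribute handling together, the op emitted for the 2-D case looks like the following (shapes and the attribute dictionary are taken from the @bwd_data_conv CHECK lines below; the SSA operand names are illustrative):

  %1 = tosa.transpose_conv2d %in, %w, %bias, %izp, %wzp
         {acc_type = f32, conv_kind = "bwd_data", dilation = array<i64: 1, 1>,
          group = 1 : i64, out_pad = array<i64: 0, 0, 0, 0>,
          pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
         : (tensor<16x1x1x16xf32>, tensor<1x4x4x16xf32>, tensor<1xf32>,
            tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x16xf32>

Per the comments above, dilation, group, pad, and conv_kind go beyond the standard TOSA transpose_conv2d attribute set; they ride along as extra attributes so the later TosaToRock conversion can consume them.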
@@ -1499,7 +1545,8 @@
 
 void migraphx::populateMIGraphXToTosaConversionPatterns(
     RewritePatternSet &patterns, TypeConverter &typeConverter) {
-  patterns.add<ConvConverter<ConvolutionOp>, ConvConverter<QuantConvolutionOp>,
+  patterns.add<ConvConverter<ConvolutionBwdDataOp>,
+               ConvConverter<ConvolutionOp>, ConvConverter<QuantConvolutionOp>,
               DotConverter<DotOp>, DotConverter<QuantDotOp>,
               BroadcastConverter, MultiBroadcastConverter, TransposeConverter,
               ReshapeConverter, SliceConverter, ReduceMeanConverter,
Lines changed: 24 additions & 0 deletions (new file)

@@ -0,0 +1,24 @@
+// This test checks that we emit an error when trying to convert 3-D
+// backwards convolution ops.
+
+// RUN: not rocmlir-opt -split-input-file --migraphx-to-tosa %s 2>&1 | FileCheck %s
+
+module {
+  func.func @bwd_data_conv_3d(
+      %arg0: !migraphx.shaped<1x16x4x4x4xf32, 1024x64x16x4x1>,
+      %arg1: !migraphx.shaped<16x16x1x1x1xf32, 16x1x1x1x1>,
+      %arg2: !migraphx.shaped<1x16x4x4x4xf32, 1024x64x16x4x1>
+  ) -> !migraphx.shaped<1x16x4x4x4xf32, 1024x64x16x4x1> {
+    // CHECK: Only 1-D and 2-D backwards convolution ops are supported
+    %0 = migraphx.backwards_data_convolution %arg1, %arg0 {
+      dilation = [1, 1, 1],
+      group = 1 : i64,
+      padding = [0, 0, 0, 0, 0, 0],
+      padding_mode = 0 : i64,
+      stride = [1, 1, 1],
+      kernelId = 0 : i64
+    } : <16x16x1x1x1xf32, 16x1x1x1x1>, <1x16x4x4x4xf32, 1024x64x16x4x1> -> <1x16x4x4x4xf32, 1024x64x16x4x1>
+    return %0 : !migraphx.shaped<1x16x4x4x4xf32, 1024x64x16x4x1>
+  }
+}

mlir/test/Conversion/MIGraphXToTosa/migraphx-to-tosa.mlir

Lines changed: 33 additions & 0 deletions
@@ -206,6 +206,39 @@ func.func @quant_conv2d_float8(%arg0: !migraphx.shaped<1x16x4x4xf8E5M2, 256x16x4
   return %0 : !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>
 }
 
+// CHECK-LABEL: @bwd_data_conv
+func.func @bwd_data_conv(%arg0: !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>, %arg1: !migraphx.shaped<16x16x1x1xf32, 16x1x1x1>, %arg2: !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>) -> !migraphx.shaped<1x16x4x4xf32, 256x16x4x1> {
+  // CHECK: tosa.transpose_conv2d
+  // CHECK-SAME: {acc_type = f32, conv_kind = "bwd_data", dilation = array<i64: 1, 1>, group = 1 : i64, out_pad = array<i64: 0, 0, 0, 0>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<16x1x1x16xf32>, tensor<1x4x4x16xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x16xf32>
+  %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], kernelId = 0 : i64} : <16x16x1x1xf32, 16x1x1x1>, <1x16x4x4xf32, 256x16x4x1> -> <1x16x4x4xf32, 256x16x4x1>
+  return %0 : !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>
+}
+
+// CHECK-LABEL: @bwd_data_conv_attributes
+func.func @bwd_data_conv_attributes(%arg0: !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>, %arg1: !migraphx.shaped<16x16x1x1xf32, 16x1x1x1>, %arg2: !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>) -> !migraphx.shaped<1x16x4x4xf32, 256x16x4x1> {
+  // CHECK: tosa.transpose_conv2d
+  // CHECK-SAME: {acc_type = f32, conv_kind = "bwd_data", dilation = array<i64: 2, 2>, group = 2 : i64, out_pad = array<i64: 0, 0, 0, 0>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<16x1x1x16xf32>, tensor<1x4x4x16xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x16xf32>
+  %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [2, 2], group = 2 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], kernelId = 0 : i64} : <16x16x1x1xf32, 16x1x1x1>, <1x16x4x4xf32, 256x16x4x1> -> <1x16x4x4xf32, 256x16x4x1>
+  return %0 : !migraphx.shaped<1x16x4x4xf32, 256x16x4x1>
+}
+
+// CHECK-LABEL: @bwd_data_conv_stride
+func.func @bwd_data_conv_stride(%arg0: !migraphx.shaped<1x32x3x3xf32, 288x9x3x1>, %arg1: !migraphx.shaped<32x16x4x4xf32, 256x16x4x1>, %arg2: !migraphx.shaped<1x32x9x9xf32, 2592x81x9x1>) -> !migraphx.shaped<1x32x9x9xf32, 2592x81x9x1> {
+  // CHECK: tosa.transpose_conv2d
+  // CHECK-SAME: {acc_type = f32, conv_kind = "bwd_data", dilation = array<i64: 1, 1>, group = 1 : i64, out_pad = array<i64: 0, 0, 0, 0>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>}
+  %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [2, 2], kernelId = 0 : i64} : <32x16x4x4xf32, 256x16x4x1>, <1x32x3x3xf32, 288x9x3x1> -> <1x32x9x9xf32, 2592x81x9x1>
+  return %0 : !migraphx.shaped<1x32x9x9xf32, 2592x81x9x1>
+}
+
+// CHECK-LABEL: @bwd_data_conv1d
+func.func @bwd_data_conv1d(%arg0: !migraphx.shaped<1x64x224xf32, 0x1x0>, %arg1: !migraphx.shaped<1x3x224xf32, 672x224x1>, %arg2: !migraphx.shaped<64x3x1xf32, 3x1x1>) -> !migraphx.shaped<1x64x224xf32, 14336x224x1> {
+  // CHECK: tosa.transpose_conv2d
+  // CHECK-SAME: {acc_type = f32, conv_kind = "bwd_data", dilation = array<i64: 1, 1>, group = 1 : i64, out_pad = array<i64: 0, 0, 0, 0>, pad = array<i64: 3, 3, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x224x1x3xf32>, tensor<64x1x1x3xf32>, tensor<64xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x224x1x64xf32>
+  %0 = migraphx.backwards_data_convolution %arg1, %arg2 {dilation = [1], group = 1 : i64, padding = [3, 3], padding_mode = 0 : i64, stride = [1]} : <1x3x224xf32, 672x224x1>, <64x3x1xf32, 3x1x1> -> <1x64x224xf32, 14336x224x1>
+  %1 = migraphx.add %0, %arg0 : <1x64x224xf32, 14336x224x1>, <1x64x224xf32, 0x1x0> -> <1x64x224xf32, 14336x224x1>
+  return %1 : !migraphx.shaped<1x64x224xf32, 14336x224x1>
+}
+
 // -----
 
 // CHECK-LABEL: @dot_f16
