@@ -245,6 +245,7 @@ LogicalResult ConvConverter<ConvType>::matchAndRewrite(
   auto outputTy = cast<MIXRShapedType>(results[0].getType());
   Type outElementTy = outputTy.getElementType();
   Type newOutElementTy = getTypeConverter()->convertType(outElementTy);
+  bool isBwdDataConvOp = isa<migraphx::ConvolutionBwdDataOp>(op);
 
   if (outElementTy.isUnsignedInteger())
     return op.emitError("No support for unsigned convolution.\n");
@@ -270,8 +271,8 @@ LogicalResult ConvConverter<ConvType>::matchAndRewrite(
   newShape.push_back(outShape[1]);
   Type newOutTy = RankedTensorType::get(newShape, newOutElementTy);
 
-  // There is no tosa.conv1d, so instead we'll add a dummy x1 dimension
-  // to the input tensors, and make a tosa.conv2d.
+  // There is no tosa.conv1d or tosa.transpose_conv1d, so instead we'll add a
+  // dummy x1 dimension to the input tensors and make a tosa.conv2d or
+  // tosa.transpose_conv2d.
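+  // For example, a rank-3 value of shape [N, W, C] becomes [N, 1, W, C]: the
+  // dummy x1 dimension is inserted just before the innermost dimension (the
+  // layout shown here is illustrative).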
   auto expandTo2D = [&rewriter, loc](mlir::Value value) {
     ArrayRef<int64_t> origShape = cast<ShapedType>(value.getType()).getShape();
     SmallVector<int64_t> expShape(origShape.drop_back());
@@ -283,13 +284,14 @@ LogicalResult ConvConverter<ConvType>::matchAndRewrite(
     return reshaped;
   };
 
-  // Construct a new Conv2DOp.
+  // Construct a new Conv2DOp/TransposeConv2DOp.
   Operation *cop;
   Type new1DOutTy;
   Value inputZp, weightZp;
   switch (dims) {
   case 1:
-    // Expand to do a conv2d, because there's no conv1d op.
+    // Expand to do a conv2d/transpose_conv2d, because there's no 1-D version
+    // of these ops.
     newShape.insert(std::prev(newShape.end()), 1);
     new1DOutTy = RankedTensorType::get(newShape, newOutElementTy);
     input = expandTo2D(input);
@@ -299,31 +301,57 @@ LogicalResult ConvConverter<ConvType>::matchAndRewrite(
     weightZp =
         tosa::createZeroPointTensor(rewriter, loc, filter.getType(), 0).value();
 
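+    // The third conv operand is a bias, which MIGraphX convolutions don't
+    // carry, so pass an all-zero tensor sized to the filter's leading
+    // (output-channel) dimension.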
-    cop = rewriter.create<tosa::Conv2DOp>(
-        loc, new1DOutTy,
-        ValueRange{
-            input, filter,
-            getZeroTensor(loc, newOutElementTy,
-                          cast<ShapedType>(filter.getType()).getShape()[0],
-                          rewriter),
-            inputZp, weightZp});
+    if (isBwdDataConvOp) {
+      cop = rewriter.create<tosa::TransposeConv2DOp>(
+          loc, new1DOutTy,
+          ValueRange{
+              input, filter,
+              getZeroTensor(loc, newOutElementTy,
+                            cast<ShapedType>(filter.getType()).getShape()[0],
+                            rewriter),
+              inputZp, weightZp});
+    } else {
+      cop = rewriter.create<tosa::Conv2DOp>(
+          loc, new1DOutTy,
+          ValueRange{
+              input, filter,
+              getZeroTensor(loc, newOutElementTy,
+                            cast<ShapedType>(filter.getType()).getShape()[0],
+                            rewriter),
+              inputZp, weightZp});
+    }
     break;
 
   case 2:
     inputZp =
         tosa::createZeroPointTensor(rewriter, loc, input.getType(), 0).value();
     weightZp =
         tosa::createZeroPointTensor(rewriter, loc, filter.getType(), 0).value();
-    cop = rewriter.create<tosa::Conv2DOp>(
-        loc, newOutTy,
-        ValueRange{
-            input, filter,
-            getZeroTensor(loc, newOutElementTy,
-                          cast<ShapedType>(filter.getType()).getShape()[0],
-                          rewriter),
-            inputZp, weightZp});
+    if (isBwdDataConvOp) {
+      cop = rewriter.create<tosa::TransposeConv2DOp>(
+          loc, newOutTy,
+          ValueRange{
+              input, filter,
+              getZeroTensor(loc, newOutElementTy,
+                            cast<ShapedType>(filter.getType()).getShape()[0],
+                            rewriter),
+              inputZp, weightZp});
+    } else {
+      cop = rewriter.create<tosa::Conv2DOp>(
+          loc, newOutTy,
+          ValueRange{
+              input, filter,
+              getZeroTensor(loc, newOutElementTy,
+                            cast<ShapedType>(filter.getType()).getShape()[0],
+                            rewriter),
+              inputZp, weightZp});
+    }
     break;
   case 3:
+    if (isBwdDataConvOp)
+      return op->emitError("Only 1-D and 2-D backwards convolution ops are "
+                           "supported");
+
     inputZp =
         tosa::createZeroPointTensor(rewriter, loc, input.getType(), 0).value();
     weightZp =
@@ -361,8 +389,6 @@ LogicalResult ConvConverter<ConvType>::matchAndRewrite(
     dilations.push_back(dyn_cast<IntegerAttr>(dilationAttr[i]).getInt());
   }
 
-  int64_t group = op.getGroup();
-
   // Determine the accumulation type based on the output type.
   Type accType;
   if (isa<FloatType>(elementTy)) {
@@ -386,11 +412,31 @@ LogicalResult ConvConverter<ConvType>::matchAndRewrite(
     pads.push_back(0);
   }
 
+  // Set attributes common to both forwards and backwards conv.
   cop->setAttr("dilation", rewriter.getDenseI64ArrayAttr(dilations));
   cop->setAttr("stride", rewriter.getDenseI64ArrayAttr(strides));
-  cop->setAttr("pad", rewriter.getDenseI64ArrayAttr(pads));
-  cop->setAttr("group", rewriter.getI64IntegerAttr(group));
   cop->setAttr("acc_type", TypeAttr::get(accType));
+  int64_t group = op.getGroup();
+  cop->setAttr("group", rewriter.getI64IntegerAttr(group));
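+  // Note: "group" (like "conv_kind" below) is not a standard TOSA attribute;
+  // it is carried through for the downstream TosaToRock lowering.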
+
+  // Set padding for forwards and backwards convolution. Note: the padding
+  // here is input padding, which tosa.transpose_conv2d does not inherently
+  // support. TransposeConv2DOp will still require an out_pad attribute, so we
+  // can just set that to zeros.
+  if (isBwdDataConvOp) {
+    SmallVector<int64_t> zeroPads(pads.size(), 0);
+    cop->setAttr("out_pad", rewriter.getDenseI64ArrayAttr(zeroPads));
+  }
+  cop->setAttr("pad", rewriter.getDenseI64ArrayAttr(pads));
+
+  // Both kinds of backwards convolution will lower to tosa.transpose_conv2d,
+  // so we add a conv_kind attribute to distinguish between the two in
+  // TosaToRock.
+  // TODO: We will need to add conv_kind = "bwd_weight" when we eventually
+  // add support for bwd_weight ops in MIGraphX.
+  if (isBwdDataConvOp) {
+    cop->setAttr("conv_kind", rewriter.getStringAttr("bwd_data"));
+  }
 
   // Convert optional attributes
   if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
@@ -1499,7 +1545,8 @@ LogicalResult MHALLaunchConverter::matchAndRewrite(
 
 void migraphx::populateMIGraphXToTosaConversionPatterns(
     RewritePatternSet &patterns, TypeConverter &typeConverter) {
-  patterns.add<ConvConverter<ConvolutionOp>, ConvConverter<QuantConvolutionOp>,
+  patterns.add<ConvConverter<ConvolutionBwdDataOp>,
+               ConvConverter<ConvolutionOp>, ConvConverter<QuantConvolutionOp>,
                DotConverter<DotOp>, DotConverter<QuantDotOp>,
                BroadcastConverter, MultiBroadcastConverter, TransposeConverter,
                ReshapeConverter, SliceConverter, ReduceMeanConverter,
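
For context, here is a minimal sketch of how these conversion patterns are typically driven. This is standard MLIR dialect-conversion boilerplate, not code from this commit; the function name runMIGraphXToTosa, the dialect name migraphx::MIGraphXDialect, and the type-converter setup are assumptions for illustration.

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
// plus the project's MIGraphXToTosa header (path elided)

using namespace mlir;

// Sketch: apply the MIGraphX -> TOSA patterns (including the new
// ConvConverter<ConvolutionBwdDataOp>) over a module.
void runMIGraphXToTosa(ModuleOp module, MLIRContext *ctx) {
  TypeConverter typeConverter; // MIXR-shaped-type -> tensor rules elided
  RewritePatternSet patterns(ctx);
  migraphx::populateMIGraphXToTosaConversionPatterns(patterns, typeConverter);

  ConversionTarget target(*ctx);
  target.addLegalDialect<tosa::TosaDialect>();
  target.addIllegalDialect<migraphx::MIGraphXDialect>(); // name assumed

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    module.emitError("MIGraphX to TOSA conversion failed");
}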