@@ -442,9 +442,6 @@ class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
442442 if (!matchPattern (op.dilation (), m_TorchConstantIntList (dilationInts)))
443443 return rewriter.notifyMatchFailure (op,
444444 " only support constant int dilations" );
445- if (!op.bias ().getType ().isa <Torch::NoneType>())
446- return rewriter.notifyMatchFailure (op, " only support None bias" );
447-
448445 Value c1 =
449446 rewriter.create <arith::ConstantOp>(loc, IntegerAttr::get (intType, 1 ));
450447 Value groupEqual1 = rewriter.create <arith::CmpIOp>(
@@ -473,22 +470,47 @@ class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
473470 rewriter, loc, Win, paddingIntValues[1 ], dilationIntValues[1 ],
474471 castIndexToInt (weightW), strideIntValues[1 ]);
475472
476- Value c0float = rewriter.create <arith::ConstantOp>(
477- loc,
478- FloatAttr::get (
479- input.getType ().cast <RankedTensorType>().getElementType (), 0.0 ));
480473 Value initTensor = rewriter.create <linalg::InitTensorOp>(
481474 loc, ValueRange{N, F, Hout, Wout}, elementType);
482- Value initTensor0 =
483- rewriter.create <linalg::FillOp>(loc, c0float, initTensor).getResult (0 );
475+
476+ Value bias = adaptor.bias ();
477+ Value biasInitTensor;
478+ if (bias.getType ().isa <Torch::NoneType>()) {
479+ Value c0float = rewriter.create <arith::ConstantOp>(
480+ loc, FloatAttr::get (elementType, 0.0 ));
481+ biasInitTensor = rewriter.create <linalg::FillOp>(loc, c0float, initTensor)
482+ .getResult (0 );
483+ } else {
484+ auto biasType = bias.getType ().cast <RankedTensorType>();
485+ if (biasType.getRank () != 1 )
486+ return rewriter.notifyMatchFailure (op, " expect bias to be rank 1" );
487+ if (elementType != biasType.getElementType ())
488+ return rewriter.notifyMatchFailure (op, " unimplemented: type promotion" );
489+
490+ auto resultRank = initTensor.getType ().cast <RankedTensorType>().getRank ();
491+ SmallVector<AffineMap> indexingMaps = {
492+ // bias is used to initialize the channels - dimension 1 of output
493+ AffineMap::get (/* dimCount=*/ resultRank, /* symbolCount=*/ 0 ,
494+ rewriter.getAffineDimExpr (1 ), context),
495+ rewriter.getMultiDimIdentityMap (resultRank)};
496+ SmallVector<StringRef> iteratorTypes (resultRank, " parallel" );
497+ biasInitTensor = rewriter
498+ .create <linalg::GenericOp>(
499+ loc, initTensor.getType (), bias, initTensor,
500+ indexingMaps, iteratorTypes,
501+ [](OpBuilder &b, Location loc, ValueRange args) {
502+ b.create <linalg::YieldOp>(loc, args[0 ]);
503+ })
504+ .getResult (0 );
505+ }
484506
485507 auto stridesAttr = rewriter.getI64VectorAttr (strideInts);
486508 auto dilationAttr = rewriter.getI64VectorAttr (dilationInts);
487509 Value conv2d =
488510 rewriter
489511 .create <linalg::Conv2DNchwFchwOp>(
490- loc, initTensor0 .getType (), ValueRange{paddedInput, weight},
491- initTensor0 , stridesAttr, dilationAttr)
512+ loc, biasInitTensor .getType (), ValueRange{paddedInput, weight},
513+ biasInitTensor , stridesAttr, dilationAttr)
492514 .getResult (0 );
493515 Type newResultType = getTypeConverter ()->convertType (op.getType ());
494516 rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, conv2d);
0 commit comments