Skip to content

Commit 2414bdb

Browse files
ljfitzsilvasean
authored andcommitted
Linalg lowering for aten.conv2d(bias=True)
Previously aten.conv2d was only lowered if there was no bias. Here lowering is extended to support bias.
1 parent c598e01 commit 2414bdb

File tree

2 files changed

+55
-11
lines changed

2 files changed

+55
-11
lines changed

e2e_testing/torchscript/conv.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,28 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
3333
module.forward(t)
3434

3535

36+
class Conv2dBiasNoPaddingModule(torch.nn.Module):
37+
def __init__(self):
38+
super().__init__()
39+
torch.manual_seed(0)
40+
self.conv = torch.nn.Conv2d(2, 10, 3, bias=True)
41+
self.train(False)
42+
43+
@export
44+
@annotate_args([
45+
None,
46+
([-1, -1, -1, -1], torch.float32, True),
47+
])
48+
def forward(self, x):
49+
return self.conv(x)
50+
51+
52+
@register_test_case(module_factory=lambda: Conv2dBiasNoPaddingModule())
53+
def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils):
54+
t = tu.rand(5, 2, 10, 20)
55+
module.forward(t)
56+
57+
3658
class Conv2dWithPaddingModule(torch.nn.Module):
3759
def __init__(self):
3860
super().__init__()

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,6 @@ class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
442442
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
443443
return rewriter.notifyMatchFailure(op,
444444
"only support constant int dilations");
445-
if (!op.bias().getType().isa<Torch::NoneType>())
446-
return rewriter.notifyMatchFailure(op, "only support None bias");
447-
448445
Value c1 =
449446
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
450447
Value groupEqual1 = rewriter.create<arith::CmpIOp>(
@@ -473,22 +470,47 @@ class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
473470
rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
474471
castIndexToInt(weightW), strideIntValues[1]);
475472

476-
Value c0float = rewriter.create<arith::ConstantOp>(
477-
loc,
478-
FloatAttr::get(
479-
input.getType().cast<RankedTensorType>().getElementType(), 0.0));
480473
Value initTensor = rewriter.create<linalg::InitTensorOp>(
481474
loc, ValueRange{N, F, Hout, Wout}, elementType);
482-
Value initTensor0 =
483-
rewriter.create<linalg::FillOp>(loc, c0float, initTensor).getResult(0);
475+
476+
Value bias = adaptor.bias();
477+
Value biasInitTensor;
478+
if (bias.getType().isa<Torch::NoneType>()) {
479+
Value c0float = rewriter.create<arith::ConstantOp>(
480+
loc, FloatAttr::get(elementType, 0.0));
481+
biasInitTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
482+
.getResult(0);
483+
} else {
484+
auto biasType = bias.getType().cast<RankedTensorType>();
485+
if (biasType.getRank() != 1)
486+
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
487+
if (elementType != biasType.getElementType())
488+
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
489+
490+
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
491+
SmallVector<AffineMap> indexingMaps = {
492+
// bias is used to initialize the channels - dimension 1 of output
493+
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
494+
rewriter.getAffineDimExpr(1), context),
495+
rewriter.getMultiDimIdentityMap(resultRank)};
496+
SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
497+
biasInitTensor = rewriter
498+
.create<linalg::GenericOp>(
499+
loc, initTensor.getType(), bias, initTensor,
500+
indexingMaps, iteratorTypes,
501+
[](OpBuilder &b, Location loc, ValueRange args) {
502+
b.create<linalg::YieldOp>(loc, args[0]);
503+
})
504+
.getResult(0);
505+
}
484506

485507
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
486508
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
487509
Value conv2d =
488510
rewriter
489511
.create<linalg::Conv2DNchwFchwOp>(
490-
loc, initTensor0.getType(), ValueRange{paddedInput, weight},
491-
initTensor0, stridesAttr, dilationAttr)
512+
loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
513+
biasInitTensor, stridesAttr, dilationAttr)
492514
.getResult(0);
493515
Type newResultType = getTypeConverter()->convertType(op.getType());
494516
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);

0 commit comments

Comments
 (0)