Commit b0cd290

canonicalize aten.convolution

1 parent 244f4b6 commit b0cd290

5 files changed: +153 -1 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 1 addition & 0 deletions
@@ -7119,6 +7119,7 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
       printDefaultTorchOp(printer, *this, 9, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
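
(In MLIR ODS, setting hasCanonicalizer = 1 declares a getCanonicalizationPatterns hook on the generated op class; the TorchOps.cpp change below supplies its definition and registers the new pattern.)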

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 117 additions & 0 deletions
@@ -6,6 +6,7 @@
 // Also available under a BSD-style license. See LICENSE.
 //
 //===----------------------------------------------------------------------===//
+#include "llvm/ADT/SmallVector.h"
 #define DEBUG_TYPE "torch-mlir-torch-dialect"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -4721,6 +4722,122 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
   return DenseElementsAttr::get(attrty, attrs);
 }
 
+namespace {
+class CanonicalizeConvolutionWithSingleIntTuple
+    : public OpRewritePattern<AtenConvolutionOp> {
+public:
+  using OpRewritePattern<AtenConvolutionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenConvolutionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto weight = op.getWeight();
+    auto weightType = dyn_cast<ValueTensorType>(weight.getType());
+
+    if (!weightType) {
+      return rewriter.notifyMatchFailure(op, "weight is not a vtensor");
+    }
+    auto optionalSizes = weightType.getOptionalSizes();
+    if (!optionalSizes.has_value()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unranked weight tensor unsupported!");
+    }
+
+    // The rank is the size of the dimensions array.
+    int64_t weightRank = optionalSizes.value().size();
+
+    // We canonicalize rank 4 (2D conv) or rank 5 (3D conv).
+    if (weightRank < 4 || weightRank > 5) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported weight rank (must be 4 or 5)");
+    }
+    int64_t requiredSpatialDims = weightRank - 2;
+
+    // Validate stride, padding, output_padding, and dilation are constant
+    // lists.
+    SmallVector<int64_t> strideInts;
+    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int stride unsupported!");
+    }
+    SmallVector<int64_t> paddingInts;
+    if (!matchPattern(op.getPadding(),
+                      m_TorchListOfConstantInts(paddingInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int padding unsupported!");
+    }
+    SmallVector<int64_t> outputPaddingInts;
+    if (!matchPattern(op.getOutputPadding(),
+                      m_TorchListOfConstantInts(outputPaddingInts))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const int output_padding unsupported!");
+    }
+    SmallVector<int64_t> dilationInts;
+    if (!matchPattern(op.getDilation(),
+                      m_TorchListOfConstantInts(dilationInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int dilation unsupported!");
+    }
+
+    // Canonicalization logic: only rewrite if the padding provided is 1
+    // element but the convolution requires 2 or 3 elements.
+    if (strideInts.size() == static_cast<size_t>(requiredSpatialDims)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "stride is already fully specified");
+    }
+    if (paddingInts.size() == static_cast<size_t>(requiredSpatialDims)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "padding is already fully specified");
+    }
+    if (outputPaddingInts.size() == static_cast<size_t>(requiredSpatialDims)) {
+      return rewriter.notifyMatchFailure(
+          op, "output_padding is already fully specified");
+    }
+    if (dilationInts.size() == static_cast<size_t>(requiredSpatialDims)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "dilation is already fully specified");
+    }
+
+    // Construct the new padding list.
+    // If the user provided padding=[1] and we need 2 or 3 dims, we create
+    // padding=[1, 1] or padding=[1, 1, 1].
+    int64_t padVal = paddingInts[0];
+    Location loc = op.getLoc();
+
+    SmallVector<Value> newPaddingValues;
+    Value paddingConst = ConstantIntOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(padVal));
+
+    for (int i = 0; i < requiredSpatialDims; ++i) {
+      newPaddingValues.push_back(paddingConst);
+    }
+
+    // Create the list construct op.
+    auto newListOp = PrimListConstructOp::create(
+        rewriter, loc,
+        Torch::ListType::get(rewriter.getType<Torch::IntType>()),
+        newPaddingValues);
+
+    // Replace the op: create a new convolution op, keeping all operands the
+    // same except padding.
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op.getType(), op.getInput(), op.getWeight(), op.getBias(),
+        op.getStride(), newListOp.getResult(), op.getDilation(),
+        op.getTransposed(), op.getOutputPadding(), op.getGroups());
+
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// AtenConvolutionOp Registration
+//===----------------------------------------------------------------------===//
+void AtenConvolutionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<CanonicalizeConvolutionWithSingleIntTuple>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//
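
For intuition, the rewrite relies on aten.convolution treating a single-element parameter list as shorthand that is broadcast across all spatial dimensions. A minimal eager-mode sketch of that equivalence for the 2D case (not part of the commit; names are illustrative):

import torch

inp = torch.rand(3, 3, 10, 10)
weight = torch.rand(3, 3, 2, 2)

# Single-element padding list, as in the new test module below.
single = torch.ops.aten.convolution(
    inp, weight, None, [4], [1], [1], False, [0, 0], 1
)
# Fully specified padding list, as the canonicalizer produces.
full = torch.ops.aten.convolution(
    inp, weight, None, [4], [1, 1], [1], False, [0, 0], 1
)
assert torch.equal(single, full)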

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
@@ -1131,6 +1131,7 @@
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
     "Convolution2DStaticModule_basic",
+    "Convolution2DSingleIntTuplePaddingModule_basic",
     "ConvolutionBackwardModule2DStatic_basic",
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "Conv_Transpose1dStaticModule_basic",
@@ -2166,6 +2167,7 @@
     "Conv2dWithValidPaddingModule_basic",
     "Conv2dWithSamePaddingModule_basic",
     "Convolution2DStaticModule_basic",
+    "Convolution2DSingleIntTuplePaddingModule_basic",
     "CosineSimilarityStaticModule_basic",
     "DetachModule_basic",
     "DropoutEvalFloatModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 1 deletion
@@ -612,7 +612,8 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)"
     )
     emit(
-        "aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)"
+        "aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)",
+        has_canonicalizer=True,
    )
    emit(
        "aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)"

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 31 additions & 0 deletions
@@ -304,6 +304,37 @@ def Convolution2DStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
 
 
+class Convolution2DSingleIntTuplePaddingModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 3, 10, 10], torch.float32, True),
+            ([3, 3, 2, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=None,
+            stride=(4,),
+            padding=(0,),
+            dilation=(1,),
+            transposed=False,
+            output_padding=[0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Convolution2DSingleIntTuplePaddingModule())
+def Convolution2DSingleIntTuplePaddingModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
+
+
 class Convolution2DStridedModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
