1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7119,6 +7119,7 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
155 changes: 155 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6,6 +6,7 @@
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/SmallVector.h"
#define DEBUG_TYPE "torch-mlir-torch-dialect"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -5898,6 +5899,160 @@ void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
}

namespace {
class CanonicalizeConvolutionWithSingleIntTuple
: public OpRewritePattern<AtenConvolutionOp> {
public:
using OpRewritePattern<AtenConvolutionOp>::OpRewritePattern;

LogicalResult matchAndRewrite(AtenConvolutionOp op,
PatternRewriter &rewriter) const override {

auto weight = op.getWeight();
auto weightType = dyn_cast<ValueTensorType>(weight.getType());

if (!weightType) {
return rewriter.notifyMatchFailure(op, "weight is not a vtensor");
}
auto optionalSizes = weightType.getOptionalSizes();
if (!optionalSizes.has_value()) {
return rewriter.notifyMatchFailure(op,
"unranked weight tensor unsupported!");
}

// The rank is the size of the dimensions array
int64_t weightRank = optionalSizes.value().size();

// We canonicalize Rank 4 (2D Conv) or Rank 5 (3D Conv).
if (weightRank < 4 || weightRank > 5) {
return rewriter.notifyMatchFailure(
op, "unsupported weight rank (must be 4 or 5)");
}
int requiredSpatialDims = weightRank - 2;

// Validate stride, padding, output_padding, and dilation are constant
// lists.
SmallVector<int64_t, 3> strideInts;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) {
return rewriter.notifyMatchFailure(op,
"non-const int stride unsupported!");
}
SmallVector<int64_t, 3> paddingInts;
if (!matchPattern(op.getPadding(),
m_TorchListOfConstantInts(paddingInts))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}

SmallVector<int64_t, 3> dilationInts;
if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationInts))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}

bool transposed;
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) {
return rewriter.notifyMatchFailure(
op, "non-const int tranposed unsupported!");
}

SmallVector<int64_t, 3> outputPaddingInts;
if (!matchPattern(op.getOutputPadding(),
m_TorchListOfConstantInts(outputPaddingInts))) {
return rewriter.notifyMatchFailure(
op, "non-const int output_padding unsupported!");
}

// Canonicalization logic: only rewrite when a parameter was supplied as a
// single element but the convolution requires 2 or 3 spatial dims.
auto isCanonical = [requiredSpatialDims](ArrayRef<int64_t> param) {
return param.size() == static_cast<size_t>(requiredSpatialDims);
};

// Bail out only when every relevant parameter is already fully specified;
// checking output_padding here keeps a transposed conv whose only
// single-element parameter is output_padding eligible for expansion.
if (isCanonical(strideInts) && isCanonical(paddingInts) &&
isCanonical(dilationInts) &&
(!transposed || isCanonical(outputPaddingInts))) {
return rewriter.notifyMatchFailure(
op, "stride, padding, dilation, and output_padding are already fully "
"specified");
}

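// `expand` broadcasts a single-element list in place to
// `requiredSpatialDims` entries (e.g. [1] -> [1, 1] for a 2-D conv).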
expand(strideInts, requiredSpatialDims);
expand(paddingInts, requiredSpatialDims);
expand(dilationInts, requiredSpatialDims);

if (transposed)
expand(outputPaddingInts, requiredSpatialDims);

// Construct the new lists.
// For example, if the user provided padding=[1] and we need 2 or 3 dims,
// we create padding=[1, 1] or padding=[1, 1, 1].
Location loc = op.getLoc();
SmallVector<Value> cstPadding, cstStrides, cstDilation, cstOutputPadding;

for (auto dim : llvm::seq<int>(0, requiredSpatialDims)) {

cstStrides.push_back(Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(strideInts[dim])));

cstPadding.push_back(Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(paddingInts[dim])));

cstDilation.push_back(Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(dilationInts[dim])));

if (transposed)
cstOutputPadding.push_back(Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(outputPaddingInts[dim])));
}

auto targetListType =
Torch::ListType::get(Torch::IntType::get(op->getContext()));

// Create the list construct op
auto stridesList = Torch::PrimListConstructOp::create(
rewriter, loc, targetListType, cstStrides);
auto paddingList = Torch::PrimListConstructOp::create(
rewriter, loc, targetListType, cstPadding);
auto dilationsList = Torch::PrimListConstructOp::create(
rewriter, loc, targetListType, cstDilation);

Value outputPaddingList;
if (transposed) {
outputPaddingList = Torch::PrimListConstructOp::create(
rewriter, loc, targetListType, cstOutputPadding);
} else {
outputPaddingList = op.getOutputPadding();
}

// Replace the op. We create a new convolution, keeping all operands the
// same except stride, padding, dilation, and output_padding, which are
// now fully specified.
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op.getType(), op.getInput(), op.getWeight(), op.getBias(),
stridesList.getResult(), paddingList.getResult(),
dilationsList.getResult(), op.getTransposed(), outputPaddingList,
op.getGroups());

return success();
}
};
} // namespace

//===----------------------------------------------------------------------===//
// AtenConvolutionOp Registration
//===----------------------------------------------------------------------===//
void AtenConvolutionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeConvolutionWithSingleIntTuple>(context);
}
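// Note: as with the other hooks in this file, the pattern registered here
// runs under MLIR's greedy canonicalization driver (e.g. the standard
// --canonicalize pass), so the expansion happens before any lowering.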

//===----------------------------------------------------------------------===//
// AtenLinalgCrossOp
//===----------------------------------------------------------------------===//
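For context, PyTorch's eager implementation already expands single-element parameter lists across all spatial dimensions (via its internal `expand_param_if_needed` helper), and the canonicalizer above reproduces that expansion at the IR level so downstream lowerings can assume fully specified lists. A minimal eager-mode sketch of the behavior being mirrored — not part of this patch, shapes borrowed from the `Convolution2DSingleIntTupleModule` test below:

```python
import torch

x = torch.rand(3, 3, 10, 10)  # N, C_in, H, W
w = torch.rand(3, 3, 2, 2)    # C_out, C_in, kH, kW

# Single-element tuples are broadcast over both spatial dims, so
# stride=(1,) behaves exactly like stride=(1, 1).
y = torch.ops.aten.convolution(
    x,
    w,
    None,    # bias
    (1,),    # stride
    (0,),    # padding
    (1,),    # dilation
    False,   # transposed
    [0, 0],  # output_padding (ignored when transposed=False)
    1,       # groups
)
assert y.shape == (3, 3, 9, 9)  # H_out = (10 - 2) // 1 + 1 = 9
```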
10 changes: 10 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1133,8 +1133,10 @@
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Convolution2DStaticModule_basic",
"Convolution2DSingleIntTupleModule_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ConvolutionModule2DTransposeScalarTupleParams_basic",
"Conv_Transpose1dStaticModule_basic",
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dStaticModule_basic",
@@ -2168,6 +2170,7 @@
"Conv2dWithValidPaddingModule_basic",
"Conv2dWithSamePaddingModule_basic",
"Convolution2DStaticModule_basic",
"Convolution2DSingleIntTupleModule_basic",
"CosineSimilarityStaticModule_basic",
"DetachModule_basic",
"DropoutEvalFloatModule_basic",
@@ -2908,6 +2911,7 @@
"Conv2dWithSamePaddingModule_basic",
"Conv2dWithValidPaddingModule_basic",
"Conv3dModule_basic",
"Conv3dModuleScalarTupleParams_basic",
"Conv3dWithSamePaddingModule_basic",
"Conv3dWithValidPaddingModule_basic",
"ConvolutionModule3DGroups_basic",
@@ -2923,7 +2927,9 @@
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
"ConvolutionModule2DGroups_basic",
"Convolution2DSingleIntTupleModule_basic",
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
"ConvolutionModule2DTransposeScalarTupleParams_basic",
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
# Error: onnx lowering,
@@ -3694,6 +3700,7 @@
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Conv3dModule_basic",
"Conv3dModuleScalarTupleParams_basic",
"Conv3dWithSamePaddingModule_basic",
"Conv3dWithValidPaddingModule_basic",
"ConvTbcModule_basic",
@@ -4333,20 +4340,23 @@
"Conv2dWithSamePaddingModule_basic",
"Conv2dWithValidPaddingModule_basic",
"Conv3dModule_basic",
"Conv3dModuleScalarTupleParams_basic",
"Conv3dWithSamePaddingModule_basic",
"Conv3dWithValidPaddingModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
"Convolution2DModule_basic",
"Convolution2DStridedModule_basic",
"Convolution2DSingleIntTupleModule_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
"ConvolutionModule2DGroups_basic",
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ConvolutionModule2DTransposeScalarTupleParams_basic",
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
"ConvolutionModule2DGroupedTranspose_basic",
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -612,7 +612,8 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)"
"aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)",
has_canonicalizer=True,
)
emit(
"aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)"
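Passing `has_canonicalizer=True` here is what makes the ODS generator emit the `let hasCanonicalizer = 1;` line in GeneratedTorchOps.td (first hunk above), which in turn declares the `getCanonicalizationPatterns` hook implemented in TorchOps.cpp.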
97 changes: 97 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -304,6 +304,37 @@ def Convolution2DStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))


class Convolution2DSingleIntTupleModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([3, 3, 10, 10], torch.float32, True),
([3, 3, 2, 2], torch.float32, True),
]
)
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(
inputVec,
weight,
bias=None,
stride=(1,),
padding=(0,),
dilation=(1,),
transposed=False,
output_padding=[0, 0],
groups=1,
)


@register_test_case(module_factory=lambda: Convolution2DSingleIntTupleModule())
def Convolution2DSingleIntTupleModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))


class Convolution2DStridedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -901,6 +932,39 @@ def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils
module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3))


class ConvolutionModule2DTransposeScalarTupleParams(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([5, 2, 5, 6], torch.float32, True),
([2, 5, 2, 2], torch.float32, True),
]
)
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(
inputVec,
weight,
bias=None,
stride=(1,),
padding=(1,),
dilation=(1,),
transposed=True,
output_padding=(0,),
groups=1,
)


@register_test_case(
module_factory=lambda: ConvolutionModule2DTransposeScalarTupleParams()
)
def ConvolutionModule2DTransposeScalarTupleParams_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))


class Conv_Transpose1dModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1569,6 +1633,39 @@ def Conv3dWithValidPaddingModule_basic(module, tu: TestUtils):
module.forward(inputVec, weight, bias)


class Conv3dModuleScalarTupleParams(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
]
)
def forward(self, inputVec, weight, bias):
return torch.ops.aten.conv3d(
inputVec,
weight,
bias=bias,
stride=(1,),
padding=(0,),
dilation=(1,),
groups=1,
)


@register_test_case(module_factory=lambda: Conv3dModuleScalarTupleParams())
def Conv3dModuleScalarTupleParams_basic(module, tu: TestUtils):
inputVec = tu.rand(2, 2, 6, 6, 6)
weight = tu.rand(8, 2, 3, 3, 3)
bias = tu.rand(8)
module.forward(inputVec, weight, bias)


class ConvTbcModule(torch.nn.Module):
def __init__(self):
super().__init__()