Skip to content

Commit fe49a45

Browse files
Generalize canonicalization for padding, stride, dilation and output_padding
1 parent b0cd290 commit fe49a45

File tree

4 files changed

+240
-122
lines changed

4 files changed

+240
-122
lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 132 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -4722,122 +4722,6 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
47224722
return DenseElementsAttr::get(attrty, attrs);
47234723
}
47244724

4725-
namespace {
4726-
class CanonicalizeConvolutionWithSingleIntTuple
4727-
: public OpRewritePattern<AtenConvolutionOp> {
4728-
public:
4729-
using OpRewritePattern<AtenConvolutionOp>::OpRewritePattern;
4730-
4731-
LogicalResult matchAndRewrite(AtenConvolutionOp op,
4732-
PatternRewriter &rewriter) const override {
4733-
4734-
auto weight = op.getWeight();
4735-
auto weightType = dyn_cast<ValueTensorType>(weight.getType());
4736-
4737-
if (!weightType) {
4738-
return rewriter.notifyMatchFailure(op, "weight is not a vtensor");
4739-
}
4740-
auto optionalSizes = weightType.getOptionalSizes();
4741-
if (!optionalSizes.has_value()) {
4742-
return rewriter.notifyMatchFailure(op,
4743-
"unranked weight tensor unsupported!");
4744-
}
4745-
4746-
// The rank is the size of the dimensions array
4747-
int64_t weightRank = optionalSizes.value().size();
4748-
4749-
// We canonicalize Rank 4 (2D Conv) or Rank 5 (3D Conv).
4750-
if (weightRank < 4 || weightRank > 5) {
4751-
return rewriter.notifyMatchFailure(
4752-
op, "unsupported weight rank (must be 4 or 5)");
4753-
}
4754-
int64_t requiredSpatialDims = weightRank - 2;
4755-
4756-
// Validate stride, padding, output_padding, and dilation are constant
4757-
// lists.
4758-
SmallVector<int64_t> strideInts;
4759-
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) {
4760-
return rewriter.notifyMatchFailure(op,
4761-
"non-const int stride unsupported!");
4762-
}
4763-
SmallVector<int64_t> paddingInts;
4764-
if (!matchPattern(op.getPadding(),
4765-
m_TorchListOfConstantInts(paddingInts))) {
4766-
return rewriter.notifyMatchFailure(op,
4767-
"non-const int padding unsupported!");
4768-
}
4769-
SmallVector<int64_t> outputPaddingInts;
4770-
if (!matchPattern(op.getOutputPadding(),
4771-
m_TorchListOfConstantInts(outputPaddingInts))) {
4772-
return rewriter.notifyMatchFailure(
4773-
op, "non-const int output_padding unsupported!");
4774-
}
4775-
SmallVector<int64_t> dilationInts;
4776-
if (!matchPattern(op.getDilation(),
4777-
m_TorchListOfConstantInts(dilationInts))) {
4778-
return rewriter.notifyMatchFailure(op,
4779-
"non-const int dilation unsupported!");
4780-
}
4781-
4782-
// Canonicalization Logic: Only rewrite if padding provided is 1 element
4783-
// but the convolution requires 2 or 3 elements.
4784-
if (strideInts.size() == static_cast<size_t>(requiredSpatialDims)) {
4785-
return rewriter.notifyMatchFailure(op,
4786-
"stride is already fully specified");
4787-
}
4788-
if (paddingInts.size() == static_cast<size_t>(requiredSpatialDims)) {
4789-
return rewriter.notifyMatchFailure(op,
4790-
"padding is already fully specified");
4791-
}
4792-
if (outputPaddingInts.size() == static_cast<size_t>(requiredSpatialDims)) {
4793-
return rewriter.notifyMatchFailure(
4794-
op, "output_padding is already fully specified");
4795-
}
4796-
if (dilationInts.size() == static_cast<size_t>(requiredSpatialDims)) {
4797-
return rewriter.notifyMatchFailure(op,
4798-
"dialtion is already fully specified");
4799-
}
4800-
4801-
// Construct the new Padding List
4802-
// If user provided padding=[1], and we need 2 or 3 dims, we create
4803-
// padding=[1, 1] or padding = [1,1,1]
4804-
int64_t padVal = paddingInts[0];
4805-
Location loc = op.getLoc();
4806-
4807-
SmallVector<Value> newPaddingValues;
4808-
Value paddingConst = ConstantIntOp::create(
4809-
rewriter, loc, rewriter.getI64IntegerAttr(padVal));
4810-
4811-
for (int i = 0; i < requiredSpatialDims; ++i) {
4812-
newPaddingValues.push_back(paddingConst);
4813-
}
4814-
4815-
// Create the list construct op
4816-
auto newListOp = PrimListConstructOp::create(
4817-
rewriter, loc, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
4818-
newPaddingValues);
4819-
4820-
// Replace the Op
4821-
// We create a new convolution op, keeping all operands the same except
4822-
// padding
4823-
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
4824-
op, op.getType(), op.getInput(), op.getWeight(), op.getBias(),
4825-
op.getStride(), newListOp.getResult(), op.getDilation(),
4826-
op.getTransposed(), op.getOutputPadding(), op.getGroups());
4827-
4828-
return success();
4829-
}
4830-
};
4831-
} // namespace
4832-
4833-
//===----------------------------------------------------------------------===//
4834-
// AtenConvolutionOp Registration
4835-
//===----------------------------------------------------------------------===//
4836-
void AtenConvolutionOp::getCanonicalizationPatterns(RewritePatternSet &results,
4837-
MLIRContext *context) {
4838-
results.add<CanonicalizeConvolutionWithSingleIntTuple>(context);
4839-
}
4840-
48414725
//===----------------------------------------------------------------------===//
48424726
// AtenIntTensorOp
48434727
//===----------------------------------------------------------------------===//
@@ -6015,6 +5899,138 @@ void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
60155899
patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
60165900
}
60175901

5902+
namespace {
5903+
class CanonicalizeConvolutionWithSingleIntTuple
5904+
: public OpRewritePattern<AtenConvolutionOp> {
5905+
public:
5906+
using OpRewritePattern<AtenConvolutionOp>::OpRewritePattern;
5907+
5908+
LogicalResult matchAndRewrite(AtenConvolutionOp op,
5909+
PatternRewriter &rewriter) const override {
5910+
5911+
auto weight = op.getWeight();
5912+
auto weightType = dyn_cast<ValueTensorType>(weight.getType());
5913+
5914+
if (!weightType) {
5915+
return rewriter.notifyMatchFailure(op, "weight is not a vtensor");
5916+
}
5917+
auto optionalSizes = weightType.getOptionalSizes();
5918+
if (!optionalSizes.has_value()) {
5919+
return rewriter.notifyMatchFailure(op,
5920+
"unranked weight tensor unsupported!");
5921+
}
5922+
5923+
// The rank is the size of the dimensions array
5924+
int64_t weightRank = optionalSizes.value().size();
5925+
5926+
// We canonicalize Rank 4 (2D Conv) or Rank 5 (3D Conv).
5927+
if (weightRank < 4 || weightRank > 5) {
5928+
return rewriter.notifyMatchFailure(
5929+
op, "unsupported weight rank (must be 4 or 5)");
5930+
}
5931+
int requiredSpatialDims = weightRank - 2;
5932+
5933+
// Validate stride, padding, output_padding, and dilation are constant
5934+
// lists.
5935+
SmallVector<int64_t, 3> strideInts;
5936+
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) {
5937+
return rewriter.notifyMatchFailure(op,
5938+
"non-const int stride unsupported!");
5939+
}
5940+
SmallVector<int64_t, 3> paddingInts;
5941+
if (!matchPattern(op.getPadding(),
5942+
m_TorchListOfConstantInts(paddingInts))) {
5943+
return rewriter.notifyMatchFailure(op,
5944+
"non-const int padding unsupported!");
5945+
}
5946+
SmallVector<int64_t, 3> outputPaddingInts;
5947+
if (!matchPattern(op.getOutputPadding(),
5948+
m_TorchListOfConstantInts(outputPaddingInts))) {
5949+
return rewriter.notifyMatchFailure(
5950+
op, "non-const int output_padding unsupported!");
5951+
}
5952+
SmallVector<int64_t, 3> dilationInts;
5953+
if (!matchPattern(op.getDilation(),
5954+
m_TorchListOfConstantInts(dilationInts))) {
5955+
return rewriter.notifyMatchFailure(op,
5956+
"non-const int dilation unsupported!");
5957+
}
5958+
5959+
// Canonicalization Logic: Only rewrite if padding provided is 1 element
5960+
// but the convolution requires 2 or 3 elements.
5961+
auto isCanonical = [requiredSpatialDims](ArrayRef<int64_t> param) {
5962+
return param.size() == static_cast<size_t>(requiredSpatialDims);
5963+
};
5964+
5965+
if (isCanonical(strideInts) && isCanonical(paddingInts) &&
5966+
isCanonical(dilationInts) && isCanonical(outputPaddingInts)) {
5967+
return rewriter.notifyMatchFailure(
5968+
op, "stride, padding, dilation and outputPadding are already fully "
5969+
"specified");
5970+
}
5971+
5972+
expand(strideInts, requiredSpatialDims);
5973+
expand(paddingInts, requiredSpatialDims);
5974+
expand(dilationInts, requiredSpatialDims);
5975+
expand(outputPaddingInts, requiredSpatialDims);
5976+
5977+
// Construct the new List
5978+
// For example: If user provided padding=[1], and we need 2 or 3 dims, we
5979+
// create padding=[1, 1] or padding=[1, 1, 1]
5980+
Location loc = op.getLoc();
5981+
SmallVector<Value> cstPadding, cstStrides, cstDilation, cstOutputPadding;
5982+
5983+
for (auto dim : llvm::seq<int>(0, requiredSpatialDims)) {
5984+
5985+
cstStrides.push_back(Torch::ConstantIntOp::create(
5986+
rewriter, loc, rewriter.getI64IntegerAttr(strideInts[dim])));
5987+
5988+
cstPadding.push_back(Torch::ConstantIntOp::create(
5989+
rewriter, loc, rewriter.getI64IntegerAttr(paddingInts[dim])));
5990+
5991+
cstDilation.push_back(Torch::ConstantIntOp::create(
5992+
rewriter, loc, rewriter.getI64IntegerAttr(dilationInts[dim])));
5993+
5994+
cstOutputPadding.push_back(Torch::ConstantIntOp::create(
5995+
rewriter, loc, rewriter.getI64IntegerAttr(outputPaddingInts[dim])));
5996+
}
5997+
5998+
auto targetListType =
5999+
Torch::ListType::get(Torch::IntType::get(op->getContext()));
6000+
6001+
// Create the list construct op
6002+
auto stridesList = Torch::PrimListConstructOp::create(
6003+
rewriter, loc, targetListType, cstStrides);
6004+
auto paddingList = Torch::PrimListConstructOp::create(
6005+
rewriter, loc, targetListType, cstPadding);
6006+
auto dilationsList = Torch::PrimListConstructOp::create(
6007+
rewriter, loc, targetListType, cstDilation);
6008+
auto outputPaddingList = Torch::PrimListConstructOp::create(
6009+
rewriter, loc, targetListType, cstOutputPadding);
6010+
6011+
// Replace the Op
6012+
// We create a new convolution op, keeping all operands the same except
6013+
// stride, padding, dilation, and output_padding which are now fully
6014+
// specified
6015+
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
6016+
op, op.getType(), op.getInput(), op.getWeight(), op.getBias(),
6017+
stridesList.getResult(), paddingList.getResult(),
6018+
dilationsList.getResult(), op.getTransposed(),
6019+
outputPaddingList.getResult(), op.getGroups());
6020+
6021+
return success();
6022+
}
6023+
};
6024+
} // namespace
6025+
6026+
//===----------------------------------------------------------------------===//
6027+
// AtenConvolutionOp Registration
6028+
//===----------------------------------------------------------------------===//
6029+
void AtenConvolutionOp::getCanonicalizationPatterns(RewritePatternSet &results,
6030+
MLIRContext *context) {
6031+
results.add<CanonicalizeConvolutionWithSingleIntTuple>(context);
6032+
}
6033+
60186034
//===----------------------------------------------------------------------===//
60196035
// AtenLinalgCrossOp
60206036
//===----------------------------------------------------------------------===//

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,9 +1131,10 @@
11311131
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
11321132
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
11331133
"Convolution2DStaticModule_basic",
1134-
"Convolution2DSingleIntTuplePaddingModule_basic",
1134+
"Convolution2DSingleIntTupleModule_basic",
11351135
"ConvolutionBackwardModule2DStatic_basic",
11361136
"ConvolutionModule2DTransposeStridedStatic_basic",
1137+
"ConvolutionModule2DTransposeScalarTupleParams_basic",
11371138
"Conv_Transpose1dStaticModule_basic",
11381139
"Conv_Transpose2dStaticModule_basic",
11391140
"Conv_Transpose3dStaticModule_basic",
@@ -2167,7 +2168,7 @@
21672168
"Conv2dWithValidPaddingModule_basic",
21682169
"Conv2dWithSamePaddingModule_basic",
21692170
"Convolution2DStaticModule_basic",
2170-
"Convolution2DSingleIntTuplePaddingModule_basic",
2171+
"Convolution2DSingleIntTupleModule_basic",
21712172
"CosineSimilarityStaticModule_basic",
21722173
"DetachModule_basic",
21732174
"DropoutEvalFloatModule_basic",
@@ -2906,6 +2907,7 @@
29062907
"Conv2dWithSamePaddingModule_basic",
29072908
"Conv2dWithValidPaddingModule_basic",
29082909
"Conv3dModule_basic",
2910+
"Conv3dModuleScalarTupleParams_basic",
29092911
"Conv3dWithSamePaddingModule_basic",
29102912
"Conv3dWithValidPaddingModule_basic",
29112913
"ConvolutionModule3DGroups_basic",
@@ -2921,7 +2923,9 @@
29212923
"ConvolutionBackwardModule2DStrided_basic",
29222924
"ConvolutionBackwardModule2D_basic",
29232925
"ConvolutionModule2DGroups_basic",
2926+
"Convolution2DSingleIntTupleModule_basic",
29242927
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
2928+
"ConvolutionModule2DTransposeScalarTupleParams_basic",
29252929
"ConvolutionModule2DTransposeStrided_basic",
29262930
"ConvolutionModule2DTranspose_basic",
29272931
# Error: onnx lowering,
@@ -3691,6 +3695,7 @@
36913695
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
36923696
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
36933697
"Conv3dModule_basic",
3698+
"Conv3dModuleScalarTupleParams_basic",
36943699
"Conv3dWithSamePaddingModule_basic",
36953700
"Conv3dWithValidPaddingModule_basic",
36963701
"ConvTbcModule_basic",
@@ -4332,20 +4337,23 @@
43324337
"Conv2dWithSamePaddingModule_basic",
43334338
"Conv2dWithValidPaddingModule_basic",
43344339
"Conv3dModule_basic",
4340+
"Conv3dModuleScalarTupleParams_basic",
43354341
"Conv3dWithSamePaddingModule_basic",
43364342
"Conv3dWithValidPaddingModule_basic",
43374343
"ConvTbcModule_basic",
43384344
"ConvTranspose2DQInt8_basic",
43394345
"Conv_Transpose2dModule_basic",
43404346
"Convolution2DModule_basic",
43414347
"Convolution2DStridedModule_basic",
4348+
"Convolution2DSingleIntTupleModule_basic",
43424349
"ConvolutionBackwardModule2DPadded_basic",
43434350
"ConvolutionBackwardModule2DStatic_basic",
43444351
"ConvolutionBackwardModule2DStrided_basic",
43454352
"ConvolutionBackwardModule2D_basic",
43464353
"ConvolutionModule2DGroups_basic",
43474354
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
43484355
"ConvolutionModule2DTransposeStridedStatic_basic",
4356+
"ConvolutionModule2DTransposeScalarTupleParams_basic",
43494357
"ConvolutionModule2DTransposeStrided_basic",
43504358
"ConvolutionModule2DTranspose_basic",
43514359
"ConvolutionModule2DGroupedTranspose_basic",

0 commit comments

Comments
 (0)