Skip to content

Commit 328c6e4

Browse files
Fix failure due to output_padding's empty list
1 parent 40c0c86 commit 328c6e4

File tree

1 file changed

+36
-14
lines changed

1 file changed

+36
-14
lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5943,36 +5943,51 @@ class CanonicalizeConvolutionWithSingleIntTuple
59435943
return rewriter.notifyMatchFailure(op,
59445944
"non-const int padding unsupported!");
59455945
}
5946-
SmallVector<int64_t, 3> outputPaddingInts;
5947-
if (!matchPattern(op.getOutputPadding(),
5948-
m_TorchListOfConstantInts(outputPaddingInts))) {
5949-
return rewriter.notifyMatchFailure(
5950-
op, "non-const int output_padding unsupported!");
5951-
}
5946+
59525947
SmallVector<int64_t, 3> dilationInts;
59535948
if (!matchPattern(op.getDilation(),
59545949
m_TorchListOfConstantInts(dilationInts))) {
59555950
return rewriter.notifyMatchFailure(op,
59565951
"non-const int dilation unsupported!");
59575952
}
59585953

5954+
bool transposed;
5955+
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) {
5956+
return rewriter.notifyMatchFailure(
5957+
op, "non-const int tranposed unsupported!");
5958+
}
5959+
5960+
SmallVector<int64_t, 3> outputPaddingInts;
5961+
if (!matchPattern(op.getOutputPadding(),
5962+
m_TorchListOfConstantInts(outputPaddingInts))) {
5963+
return rewriter.notifyMatchFailure(
5964+
op, "non-const int output_padding unsupported!");
5965+
}
5966+
59595967
// Canonicalization Logic: Only rewrite if padding provided is 1 element
59605968
// but the convolution requires 2 or 3 elements.
59615969
auto isCanonical = [requiredSpatialDims](ArrayRef<int64_t> param) {
59625970
return param.size() == static_cast<size_t>(requiredSpatialDims);
59635971
};
59645972

59655973
if (isCanonical(strideInts) && isCanonical(paddingInts) &&
5966-
isCanonical(dilationInts) && isCanonical(outputPaddingInts)) {
5974+
isCanonical(dilationInts)) {
59675975
return rewriter.notifyMatchFailure(
59685976
op, "stride, padding, dialtion and outputPadding is already fully "
59695977
"specified");
59705978
}
59715979

5980+
if (transposed && isCanonical(outputPaddingInts)) {
5981+
return rewriter.notifyMatchFailure(
5982+
op, "output_padding is already fully specified");
5983+
}
5984+
59725985
expand(strideInts, requiredSpatialDims);
59735986
expand(paddingInts, requiredSpatialDims);
59745987
expand(dilationInts, requiredSpatialDims);
5975-
expand(outputPaddingInts, requiredSpatialDims);
5988+
5989+
if (transposed)
5990+
expand(outputPaddingInts, requiredSpatialDims);
59765991

59775992
// Construct the new List
59785993
// For example: If user provided padding=[1], and we need 2 or 3 dims, we
@@ -5991,8 +6006,9 @@ class CanonicalizeConvolutionWithSingleIntTuple
59916006
cstDilation.push_back(Torch::ConstantIntOp::create(
59926007
rewriter, loc, rewriter.getI64IntegerAttr(dilationInts[dim])));
59936008

5994-
cstOutputPadding.push_back(Torch::ConstantIntOp::create(
5995-
rewriter, loc, rewriter.getI64IntegerAttr(outputPaddingInts[dim])));
6009+
if (transposed)
6010+
cstOutputPadding.push_back(Torch::ConstantIntOp::create(
6011+
rewriter, loc, rewriter.getI64IntegerAttr(outputPaddingInts[dim])));
59966012
}
59976013

59986014
auto targetListType =
@@ -6005,8 +6021,14 @@ class CanonicalizeConvolutionWithSingleIntTuple
60056021
rewriter, loc, targetListType, cstPadding);
60066022
auto dilationsList = Torch::PrimListConstructOp::create(
60076023
rewriter, loc, targetListType, cstDilation);
6008-
auto outputPaddingList = Torch::PrimListConstructOp::create(
6009-
rewriter, loc, targetListType, cstOutputPadding);
6024+
6025+
Value outputPaddingList;
6026+
if (transposed) {
6027+
outputPaddingList = Torch::PrimListConstructOp::create(
6028+
rewriter, loc, targetListType, cstOutputPadding);
6029+
} else {
6030+
outputPaddingList = op.getOutputPadding();
6031+
}
60106032

60116033
// Replace the Op
60126034
// We create a new convolution op, keeping all operands the same except
@@ -6015,8 +6037,8 @@ class CanonicalizeConvolutionWithSingleIntTuple
60156037
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
60166038
op, op.getType(), op.getInput(), op.getWeight(), op.getBias(),
60176039
stridesList.getResult(), paddingList.getResult(),
6018-
dilationsList.getResult(), op.getTransposed(),
6019-
outputPaddingList.getResult(), op.getGroups());
6040+
dilationsList.getResult(), op.getTransposed(), outputPaddingList,
6041+
op.getGroups());
60206042

60216043
return success();
60226044
}

0 commit comments

Comments
 (0)