@@ -5943,36 +5943,51 @@ class CanonicalizeConvolutionWithSingleIntTuple
59435943 return rewriter.notifyMatchFailure (op,
59445944 " non-const int padding unsupported!" );
59455945 }
5946- SmallVector<int64_t , 3 > outputPaddingInts;
5947- if (!matchPattern (op.getOutputPadding (),
5948- m_TorchListOfConstantInts (outputPaddingInts))) {
5949- return rewriter.notifyMatchFailure (
5950- op, " non-const int output_padding unsupported!" );
5951- }
5946+
59525947 SmallVector<int64_t , 3 > dilationInts;
59535948 if (!matchPattern (op.getDilation (),
59545949 m_TorchListOfConstantInts (dilationInts))) {
59555950 return rewriter.notifyMatchFailure (op,
59565951 " non-const int dilation unsupported!" );
59575952 }
59585953
5954+ bool transposed;
5955+ if (!matchPattern (op.getTransposed (), m_TorchConstantBool (&transposed))) {
5956+ return rewriter.notifyMatchFailure (
5957+ op, " non-const int tranposed unsupported!" );
5958+ }
5959+
5960+ SmallVector<int64_t , 3 > outputPaddingInts;
5961+ if (!matchPattern (op.getOutputPadding (),
5962+ m_TorchListOfConstantInts (outputPaddingInts))) {
5963+ return rewriter.notifyMatchFailure (
5964+ op, " non-const int output_padding unsupported!" );
5965+ }
5966+
59595967 // Canonicalization Logic: Only rewrite if padding provided is 1 element
59605968 // but the convolution requires 2 or 3 elements.
59615969 auto isCanonical = [requiredSpatialDims](ArrayRef<int64_t > param) {
59625970 return param.size () == static_cast <size_t >(requiredSpatialDims);
59635971 };
59645972
59655973 if (isCanonical (strideInts) && isCanonical (paddingInts) &&
5966- isCanonical (dilationInts) && isCanonical (outputPaddingInts) ) {
5974+ isCanonical (dilationInts)) {
59675975 return rewriter.notifyMatchFailure (
59685976 op, " stride, padding, dialtion and outputPadding is already fully "
59695977 " specified" );
59705978 }
59715979
5980+ if (transposed && isCanonical (outputPaddingInts)) {
5981+ return rewriter.notifyMatchFailure (
5982+ op, " output_padding is already fully specified" );
5983+ }
5984+
59725985 expand (strideInts, requiredSpatialDims);
59735986 expand (paddingInts, requiredSpatialDims);
59745987 expand (dilationInts, requiredSpatialDims);
5975- expand (outputPaddingInts, requiredSpatialDims);
5988+
5989+ if (transposed)
5990+ expand (outputPaddingInts, requiredSpatialDims);
59765991
59775992 // Construct the new List
59785993 // For example: If user provided padding=[1], and we need 2 or 3 dims, we
@@ -5991,8 +6006,9 @@ class CanonicalizeConvolutionWithSingleIntTuple
59916006 cstDilation.push_back (Torch::ConstantIntOp::create (
59926007 rewriter, loc, rewriter.getI64IntegerAttr (dilationInts[dim])));
59936008
5994- cstOutputPadding.push_back (Torch::ConstantIntOp::create (
5995- rewriter, loc, rewriter.getI64IntegerAttr (outputPaddingInts[dim])));
6009+ if (transposed)
6010+ cstOutputPadding.push_back (Torch::ConstantIntOp::create (
6011+ rewriter, loc, rewriter.getI64IntegerAttr (outputPaddingInts[dim])));
59966012 }
59976013
59986014 auto targetListType =
@@ -6005,8 +6021,14 @@ class CanonicalizeConvolutionWithSingleIntTuple
60056021 rewriter, loc, targetListType, cstPadding);
60066022 auto dilationsList = Torch::PrimListConstructOp::create (
60076023 rewriter, loc, targetListType, cstDilation);
6008- auto outputPaddingList = Torch::PrimListConstructOp::create (
6009- rewriter, loc, targetListType, cstOutputPadding);
6024+
6025+ Value outputPaddingList;
6026+ if (transposed) {
6027+ outputPaddingList = Torch::PrimListConstructOp::create (
6028+ rewriter, loc, targetListType, cstOutputPadding);
6029+ } else {
6030+ outputPaddingList = op.getOutputPadding ();
6031+ }
60106032
60116033 // Replace the Op
60126034 // We create a new convolution op, keeping all operands the same except
@@ -6015,8 +6037,8 @@ class CanonicalizeConvolutionWithSingleIntTuple
60156037 rewriter.replaceOpWithNewOp <AtenConvolutionOp>(
60166038 op, op.getType (), op.getInput (), op.getWeight (), op.getBias (),
60176039 stridesList.getResult (), paddingList.getResult (),
6018- dilationsList.getResult (), op.getTransposed (),
6019- outputPaddingList. getResult (), op.getGroups ());
6040+ dilationsList.getResult (), op.getTransposed (), outputPaddingList,
6041+ op.getGroups ());
60206042
60216043 return success ();
60226044 }
0 commit comments