@@ -4722,122 +4722,6 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
47224722 return DenseElementsAttr::get (attrty, attrs);
47234723}
47244724
4725- namespace {
4726- class CanonicalizeConvolutionWithSingleIntTuple
4727- : public OpRewritePattern<AtenConvolutionOp> {
4728- public:
4729- using OpRewritePattern<AtenConvolutionOp>::OpRewritePattern;
4730-
4731- LogicalResult matchAndRewrite (AtenConvolutionOp op,
4732- PatternRewriter &rewriter) const override {
4733-
4734- auto weight = op.getWeight ();
4735- auto weightType = dyn_cast<ValueTensorType>(weight.getType ());
4736-
4737- if (!weightType) {
4738- return rewriter.notifyMatchFailure (op, " weight is not a vtensor" );
4739- }
4740- auto optionalSizes = weightType.getOptionalSizes ();
4741- if (!optionalSizes.has_value ()) {
4742- return rewriter.notifyMatchFailure (op,
4743- " unranked weight tensor unsupported!" );
4744- }
4745-
4746- // The rank is the size of the dimensions array
4747- int64_t weightRank = optionalSizes.value ().size ();
4748-
4749- // We canonicalize Rank 4 (2D Conv) or Rank 5 (3D Conv).
4750- if (weightRank < 4 || weightRank > 5 ) {
4751- return rewriter.notifyMatchFailure (
4752- op, " unsupported weight rank (must be 4 or 5)" );
4753- }
4754- int64_t requiredSpatialDims = weightRank - 2 ;
4755-
4756- // Validate stride, padding, output_padding, and dilation are constant
4757- // lists.
4758- SmallVector<int64_t > strideInts;
4759- if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts))) {
4760- return rewriter.notifyMatchFailure (op,
4761- " non-const int stride unsupported!" );
4762- }
4763- SmallVector<int64_t > paddingInts;
4764- if (!matchPattern (op.getPadding (),
4765- m_TorchListOfConstantInts (paddingInts))) {
4766- return rewriter.notifyMatchFailure (op,
4767- " non-const int padding unsupported!" );
4768- }
4769- SmallVector<int64_t > outputPaddingInts;
4770- if (!matchPattern (op.getOutputPadding (),
4771- m_TorchListOfConstantInts (outputPaddingInts))) {
4772- return rewriter.notifyMatchFailure (
4773- op, " non-const int output_padding unsupported!" );
4774- }
4775- SmallVector<int64_t > dilationInts;
4776- if (!matchPattern (op.getDilation (),
4777- m_TorchListOfConstantInts (dilationInts))) {
4778- return rewriter.notifyMatchFailure (op,
4779- " non-const int dilation unsupported!" );
4780- }
4781-
4782- // Canonicalization Logic: Only rewrite if padding provided is 1 element
4783- // but the convolution requires 2 or 3 elements.
4784- if (strideInts.size () == static_cast <size_t >(requiredSpatialDims)) {
4785- return rewriter.notifyMatchFailure (op,
4786- " stride is already fully specified" );
4787- }
4788- if (paddingInts.size () == static_cast <size_t >(requiredSpatialDims)) {
4789- return rewriter.notifyMatchFailure (op,
4790- " padding is already fully specified" );
4791- }
4792- if (outputPaddingInts.size () == static_cast <size_t >(requiredSpatialDims)) {
4793- return rewriter.notifyMatchFailure (
4794- op, " output_padding is already fully specified" );
4795- }
4796- if (dilationInts.size () == static_cast <size_t >(requiredSpatialDims)) {
4797- return rewriter.notifyMatchFailure (op,
4798- " dialtion is already fully specified" );
4799- }
4800-
4801- // Construct the new Padding List
4802- // If user provided padding=[1], and we need 2 or 3 dims, we create
4803- // padding=[1, 1] or padding = [1,1,1]
4804- int64_t padVal = paddingInts[0 ];
4805- Location loc = op.getLoc ();
4806-
4807- SmallVector<Value> newPaddingValues;
4808- Value paddingConst = ConstantIntOp::create (
4809- rewriter, loc, rewriter.getI64IntegerAttr (padVal));
4810-
4811- for (int i = 0 ; i < requiredSpatialDims; ++i) {
4812- newPaddingValues.push_back (paddingConst);
4813- }
4814-
4815- // Create the list construct op
4816- auto newListOp = PrimListConstructOp::create (
4817- rewriter, loc, Torch::ListType::get (rewriter.getType <Torch::IntType>()),
4818- newPaddingValues);
4819-
4820- // Replace the Op
4821- // We create a new convolution op, keeping all operands the same except
4822- // padding
4823- rewriter.replaceOpWithNewOp <AtenConvolutionOp>(
4824- op, op.getType (), op.getInput (), op.getWeight (), op.getBias (),
4825- op.getStride (), newListOp.getResult (), op.getDilation (),
4826- op.getTransposed (), op.getOutputPadding (), op.getGroups ());
4827-
4828- return success ();
4829- }
4830- };
4831- } // namespace
4832-
4833- // ===----------------------------------------------------------------------===//
4834- // AtenConvolutionOp Registration
4835- // ===----------------------------------------------------------------------===//
4836- void AtenConvolutionOp::getCanonicalizationPatterns (RewritePatternSet &results,
4837- MLIRContext *context) {
4838- results.add <CanonicalizeConvolutionWithSingleIntTuple>(context);
4839- }
4840-
48414725// ===----------------------------------------------------------------------===//
48424726// AtenIntTensorOp
48434727// ===----------------------------------------------------------------------===//
@@ -6015,6 +5899,138 @@ void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
60155899 patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
60165900}
60175901
5902+ namespace {
5903+ class CanonicalizeConvolutionWithSingleIntTuple
5904+ : public OpRewritePattern<AtenConvolutionOp> {
5905+ public:
5906+ using OpRewritePattern<AtenConvolutionOp>::OpRewritePattern;
5907+
5908+ LogicalResult matchAndRewrite (AtenConvolutionOp op,
5909+ PatternRewriter &rewriter) const override {
5910+
5911+ auto weight = op.getWeight ();
5912+ auto weightType = dyn_cast<ValueTensorType>(weight.getType ());
5913+
5914+ if (!weightType) {
5915+ return rewriter.notifyMatchFailure (op, " weight is not a vtensor" );
5916+ }
5917+ auto optionalSizes = weightType.getOptionalSizes ();
5918+ if (!optionalSizes.has_value ()) {
5919+ return rewriter.notifyMatchFailure (op,
5920+ " unranked weight tensor unsupported!" );
5921+ }
5922+
5923+ // The rank is the size of the dimensions array
5924+ int64_t weightRank = optionalSizes.value ().size ();
5925+
5926+ // We canonicalize Rank 4 (2D Conv) or Rank 5 (3D Conv).
5927+ if (weightRank < 4 || weightRank > 5 ) {
5928+ return rewriter.notifyMatchFailure (
5929+ op, " unsupported weight rank (must be 4 or 5)" );
5930+ }
5931+ int requiredSpatialDims = weightRank - 2 ;
5932+
5933+ // Validate stride, padding, output_padding, and dilation are constant
5934+ // lists.
5935+ SmallVector<int64_t , 3 > strideInts;
5936+ if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts))) {
5937+ return rewriter.notifyMatchFailure (op,
5938+ " non-const int stride unsupported!" );
5939+ }
5940+ SmallVector<int64_t , 3 > paddingInts;
5941+ if (!matchPattern (op.getPadding (),
5942+ m_TorchListOfConstantInts (paddingInts))) {
5943+ return rewriter.notifyMatchFailure (op,
5944+ " non-const int padding unsupported!" );
5945+ }
5946+ SmallVector<int64_t , 3 > outputPaddingInts;
5947+ if (!matchPattern (op.getOutputPadding (),
5948+ m_TorchListOfConstantInts (outputPaddingInts))) {
5949+ return rewriter.notifyMatchFailure (
5950+ op, " non-const int output_padding unsupported!" );
5951+ }
5952+ SmallVector<int64_t , 3 > dilationInts;
5953+ if (!matchPattern (op.getDilation (),
5954+ m_TorchListOfConstantInts (dilationInts))) {
5955+ return rewriter.notifyMatchFailure (op,
5956+ " non-const int dilation unsupported!" );
5957+ }
5958+
5959+ // Canonicalization Logic: Only rewrite if padding provided is 1 element
5960+ // but the convolution requires 2 or 3 elements.
5961+ auto isCanonical = [requiredSpatialDims](ArrayRef<int64_t > param) {
5962+ return param.size () == static_cast <size_t >(requiredSpatialDims);
5963+ };
5964+
5965+ if (isCanonical (strideInts) && isCanonical (paddingInts) &&
5966+ isCanonical (dilationInts) && isCanonical (outputPaddingInts)) {
5967+ return rewriter.notifyMatchFailure (
5968+ op, " stride, padding, dialtion and outputPadding is already fully "
5969+ " specified" );
5970+ }
5971+
5972+ expand (strideInts, requiredSpatialDims);
5973+ expand (paddingInts, requiredSpatialDims);
5974+ expand (dilationInts, requiredSpatialDims);
5975+ expand (outputPaddingInts, requiredSpatialDims);
5976+
5977+ // Construct the new List
5978+ // For example: If user provided padding=[1], and we need 2 or 3 dims, we
5979+ // create padding=[1, 1] or padding = [1,1,1]
5980+ Location loc = op.getLoc ();
5981+ SmallVector<Value> cstPadding, cstStrides, cstDilation, cstOutputPadding;
5982+
5983+ for (auto dim : llvm::seq<int >(0 , requiredSpatialDims)) {
5984+
5985+ cstStrides.push_back (Torch::ConstantIntOp::create (
5986+ rewriter, loc, rewriter.getI64IntegerAttr (strideInts[dim])));
5987+
5988+ cstPadding.push_back (Torch::ConstantIntOp::create (
5989+ rewriter, loc, rewriter.getI64IntegerAttr (paddingInts[dim])));
5990+
5991+ cstDilation.push_back (Torch::ConstantIntOp::create (
5992+ rewriter, loc, rewriter.getI64IntegerAttr (dilationInts[dim])));
5993+
5994+ cstOutputPadding.push_back (Torch::ConstantIntOp::create (
5995+ rewriter, loc, rewriter.getI64IntegerAttr (outputPaddingInts[dim])));
5996+ }
5997+
5998+ auto targetListType =
5999+ Torch::ListType::get (Torch::IntType::get (op->getContext ()));
6000+
6001+ // Create the list construct op
6002+ auto stridesList = Torch::PrimListConstructOp::create (
6003+ rewriter, loc, targetListType, cstStrides);
6004+ auto paddingList = Torch::PrimListConstructOp::create (
6005+ rewriter, loc, targetListType, cstPadding);
6006+ auto dilationsList = Torch::PrimListConstructOp::create (
6007+ rewriter, loc, targetListType, cstDilation);
6008+ auto outputPaddingList = Torch::PrimListConstructOp::create (
6009+ rewriter, loc, targetListType, cstOutputPadding);
6010+
6011+ // Replace the Op
6012+ // We create a new convolution op, keeping all operands the same except
6013+ // stride, padding,dilation, and output_padding which are now fully
6014+ // specified
6015+ rewriter.replaceOpWithNewOp <AtenConvolutionOp>(
6016+ op, op.getType (), op.getInput (), op.getWeight (), op.getBias (),
6017+ stridesList.getResult (), paddingList.getResult (),
6018+ dilationsList.getResult (), op.getTransposed (),
6019+ outputPaddingList.getResult (), op.getGroups ());
6020+
6021+ return success ();
6022+ }
6023+ };
6024+ } // namespace
6025+
//===----------------------------------------------------------------------===//
// AtenConvolutionOp Registration
//===----------------------------------------------------------------------===//

// Registers the canonicalization pattern that expands single-element
// stride/padding/dilation/output_padding lists on aten.convolution to one
// entry per spatial dimension.
void AtenConvolutionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<CanonicalizeConvolutionWithSingleIntTuple>(context);
}
6033+
60186034// ===----------------------------------------------------------------------===//
60196035// AtenLinalgCrossOp
60206036// ===----------------------------------------------------------------------===//
0 commit comments