|
 // Also available under a BSD-style license. See LICENSE.
 //
 //===----------------------------------------------------------------------===//
+#include "llvm/ADT/SmallVector.h"
 #define DEBUG_TYPE "torch-mlir-torch-dialect"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -4721,6 +4722,122 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) { |
   return DenseElementsAttr::get(attrty, attrs);
 }

+namespace {
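+// Expands a single-element padding list to one element per spatial dim: e.g.
+// padding=[1] on a 2-D convolution becomes padding=[1, 1]. At the IR level
+// (value names illustrative), a padding operand built as
+//
+//   %pad = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+//
+// is rewritten, for a rank-4 weight, into
+//
+//   %pad = torch.prim.ListConstruct %int1, %int1
+//              : (!torch.int, !torch.int) -> !torch.list<int>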
+class CanonicalizeConvolutionWithSingleIntTuple
+    : public OpRewritePattern<AtenConvolutionOp> {
+public:
+  using OpRewritePattern<AtenConvolutionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenConvolutionOp op,
+                                PatternRewriter &rewriter) const override {
+    auto weight = op.getWeight();
+    auto weightType = dyn_cast<ValueTensorType>(weight.getType());
+    if (!weightType) {
+      return rewriter.notifyMatchFailure(op, "weight is not a vtensor");
+    }
+    auto optionalSizes = weightType.getOptionalSizes();
+    if (!optionalSizes.has_value()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unranked weight tensor unsupported!");
+    }
+
+    // The rank is the size of the dimensions array.
+    int64_t weightRank = optionalSizes.value().size();
+
+    // We canonicalize rank 4 (2-D conv) or rank 5 (3-D conv) weights.
+    if (weightRank < 4 || weightRank > 5) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported weight rank (must be 4 or 5)");
+    }
+    int64_t requiredSpatialDims = weightRank - 2;
+
+    // Validate that stride, padding, output_padding, and dilation are
+    // constant int lists.
+    SmallVector<int64_t> strideInts;
+    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int stride unsupported!");
+    }
+    SmallVector<int64_t> paddingInts;
+    if (!matchPattern(op.getPadding(),
+                      m_TorchListOfConstantInts(paddingInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int padding unsupported!");
+    }
+    SmallVector<int64_t> outputPaddingInts;
+    if (!matchPattern(op.getOutputPadding(),
+                      m_TorchListOfConstantInts(outputPaddingInts))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const int output_padding unsupported!");
+    }
+    SmallVector<int64_t> dilationInts;
+    if (!matchPattern(op.getDilation(),
+                      m_TorchListOfConstantInts(dilationInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int dilation unsupported!");
+    }
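+    // Each matcher above succeeds on a list built from constant ints, e.g.
+    // (value names illustrative):
+    //   %int1 = torch.constant.int 1
+    //   %stride = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+    // and fills the corresponding SmallVector with the values, here {1}.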
+
+    // Canonicalization logic: only rewrite when the provided lists carry a
+    // single element but the convolution requires 2 or 3.
+    if (strideInts.size() == static_cast<size_t>(requiredSpatialDims)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "stride is already fully specified");
+    }
+    if (paddingInts.size() == static_cast<size_t>(requiredSpatialDims)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "padding is already fully specified");
+    }
+    if (outputPaddingInts.size() == static_cast<size_t>(requiredSpatialDims)) {
+      return rewriter.notifyMatchFailure(
+          op, "output_padding is already fully specified");
+    }
+    if (dilationInts.size() == static_cast<size_t>(requiredSpatialDims)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "dilation is already fully specified");
+    }
+
+    // Construct the new padding list: if the user provided padding=[1] and we
+    // need 2 or 3 dims, we create padding=[1, 1] or padding=[1, 1, 1]. Guard
+    // against other sizes (e.g. an empty list) before indexing.
+    if (paddingInts.size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected a single-element padding list");
+    }
+    int64_t padVal = paddingInts[0];
+    Location loc = op.getLoc();
+
+    SmallVector<Value> newPaddingValues;
+    Value paddingConst = ConstantIntOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(padVal));
+
+    for (int i = 0; i < requiredSpatialDims; ++i) {
+      newPaddingValues.push_back(paddingConst);
+    }
+
+    // Create the list construct op.
+    auto newListOp = PrimListConstructOp::create(
+        rewriter, loc, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
+        newPaddingValues);
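+    // For requiredSpatialDims == 2 this materializes (names illustrative):
+    //   %c = torch.constant.int <padVal>
+    //   %list = torch.prim.ListConstruct %c, %c
+    //               : (!torch.int, !torch.int) -> !torch.list<int>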
+
+    // Replace the op: we create a new convolution, keeping all operands the
+    // same except padding.
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op.getType(), op.getInput(), op.getWeight(), op.getBias(),
+        op.getStride(), newListOp.getResult(), op.getDilation(),
+        op.getTransposed(), op.getOutputPadding(), op.getGroups());
+
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// AtenConvolutionOp Registration
+//===----------------------------------------------------------------------===//
+void AtenConvolutionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<CanonicalizeConvolutionWithSingleIntTuple>(context);
+}
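+// Once registered here, the pattern runs under any greedy rewrite driver,
+// e.g. a standard `-canonicalize` pass pipeline.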
+
 //===----------------------------------------------------------------------===//
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//
|