@@ -1759,6 +1759,65 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
1759
1759
return success ();
1760
1760
}
1761
1761
1762
+ template <>
1763
+ LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
1764
+ AtenReflectionPad1dOp op, OpAdaptor adaptor,
1765
+ ConversionPatternRewriter &rewriter) const {
1766
+ Location loc = op.getLoc ();
1767
+ Value self = adaptor.getSelf ();
1768
+ auto selfTy = cast<RankedTensorType>(self.getType ());
1769
+ if (!selfTy.hasStaticShape ()) {
1770
+ return rewriter.notifyMatchFailure (op, " only support static shape" );
1771
+ }
1772
+ int64_t rank = selfTy.getRank ();
1773
+ int64_t dim = rank - 1 ;
1774
+
1775
+ SmallVector<int64_t > padInts;
1776
+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (padInts))) {
1777
+ return rewriter.notifyMatchFailure (op,
1778
+ " only support constant int pad ranges" );
1779
+ }
1780
+ if (padInts.size () != 2 ) {
1781
+ return rewriter.notifyMatchFailure (op, " pad size must be 2" );
1782
+ }
1783
+ if (padInts[0 ] >= selfTy.getDimSize (dim) ||
1784
+ padInts[1 ] >= selfTy.getDimSize (dim)) {
1785
+ return rewriter.notifyMatchFailure (op,
1786
+ " pad size must be less than dim size" );
1787
+ }
1788
+
1789
+ Value left;
1790
+ {
1791
+ SmallVector<int64_t > startIndices (rank, 0 );
1792
+ SmallVector<int64_t > limitIndices (selfTy.getShape ().begin (),
1793
+ selfTy.getShape ().end ());
1794
+ SmallVector<int64_t > strides (rank, 1 );
1795
+ startIndices[dim] = 1 ;
1796
+ limitIndices[dim] = padInts[0 ] + 1 ;
1797
+ left = rewriter.create <stablehlo::SliceOp>(loc, self, startIndices,
1798
+ limitIndices, strides);
1799
+ left = rewriter.create <stablehlo::ReverseOp>(loc, left,
1800
+ ArrayRef<int64_t >({dim}));
1801
+ }
1802
+ Value right;
1803
+ {
1804
+ SmallVector<int64_t > startIndices (rank, 0 );
1805
+ SmallVector<int64_t > limitIndices (selfTy.getShape ().begin (),
1806
+ selfTy.getShape ().end ());
1807
+ SmallVector<int64_t > strides (rank, 1 );
1808
+ startIndices[dim] = selfTy.getDimSize (dim) - 1 - padInts[1 ];
1809
+ limitIndices[dim] = selfTy.getDimSize (dim) - 1 ;
1810
+ right = rewriter.create <stablehlo::SliceOp>(loc, self, startIndices,
1811
+ limitIndices, strides);
1812
+ right = rewriter.create <stablehlo::ReverseOp>(loc, right,
1813
+ ArrayRef<int64_t >({dim}));
1814
+ }
1815
+ Value result = rewriter.create <stablehlo::ConcatenateOp>(
1816
+ loc, ValueRange{left, self, right}, dim);
1817
+ rewriter.replaceOp (op, result);
1818
+ return success ();
1819
+ }
1820
+
1762
1821
template <>
1763
1822
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
1764
1823
AtenGeluBackwardOp op, OpAdaptor adaptor,
@@ -2269,6 +2328,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
2269
2328
INSERT_ATENOP_PATTERN (AtenScalarImplicitOp);
2270
2329
INSERT_ATENOP_PATTERN (AtenContiguousOp);
2271
2330
INSERT_ATENOP_PATTERN (AtenConstantPadNdOp);
2331
+ INSERT_ATENOP_PATTERN (AtenReflectionPad1dOp);
2272
2332
2273
2333
INSERT_ATENOP_PATTERN (AtenReluOp);
2274
2334
INSERT_ATENOP_PATTERN (AtenGeluOp);
0 commit comments