Skip to content

Commit b1053f8

Browse files
authored
[Stablehlo] support lowering aten.reflection_pad1d (#4295)
1 parent 155680c commit b1053f8

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

lib/Conversion/TorchToStablehlo/Basic.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,65 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
17591759
return success();
17601760
}
17611761

1762+
template <>
1763+
LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
1764+
AtenReflectionPad1dOp op, OpAdaptor adaptor,
1765+
ConversionPatternRewriter &rewriter) const {
1766+
Location loc = op.getLoc();
1767+
Value self = adaptor.getSelf();
1768+
auto selfTy = cast<RankedTensorType>(self.getType());
1769+
if (!selfTy.hasStaticShape()) {
1770+
return rewriter.notifyMatchFailure(op, "only support static shape");
1771+
}
1772+
int64_t rank = selfTy.getRank();
1773+
int64_t dim = rank - 1;
1774+
1775+
SmallVector<int64_t> padInts;
1776+
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) {
1777+
return rewriter.notifyMatchFailure(op,
1778+
"only support constant int pad ranges");
1779+
}
1780+
if (padInts.size() != 2) {
1781+
return rewriter.notifyMatchFailure(op, "pad size must be 2");
1782+
}
1783+
if (padInts[0] >= selfTy.getDimSize(dim) ||
1784+
padInts[1] >= selfTy.getDimSize(dim)) {
1785+
return rewriter.notifyMatchFailure(op,
1786+
"pad size must be less than dim size");
1787+
}
1788+
1789+
Value left;
1790+
{
1791+
SmallVector<int64_t> startIndices(rank, 0);
1792+
SmallVector<int64_t> limitIndices(selfTy.getShape().begin(),
1793+
selfTy.getShape().end());
1794+
SmallVector<int64_t> strides(rank, 1);
1795+
startIndices[dim] = 1;
1796+
limitIndices[dim] = padInts[0] + 1;
1797+
left = rewriter.create<stablehlo::SliceOp>(loc, self, startIndices,
1798+
limitIndices, strides);
1799+
left = rewriter.create<stablehlo::ReverseOp>(loc, left,
1800+
ArrayRef<int64_t>({dim}));
1801+
}
1802+
Value right;
1803+
{
1804+
SmallVector<int64_t> startIndices(rank, 0);
1805+
SmallVector<int64_t> limitIndices(selfTy.getShape().begin(),
1806+
selfTy.getShape().end());
1807+
SmallVector<int64_t> strides(rank, 1);
1808+
startIndices[dim] = selfTy.getDimSize(dim) - 1 - padInts[1];
1809+
limitIndices[dim] = selfTy.getDimSize(dim) - 1;
1810+
right = rewriter.create<stablehlo::SliceOp>(loc, self, startIndices,
1811+
limitIndices, strides);
1812+
right = rewriter.create<stablehlo::ReverseOp>(loc, right,
1813+
ArrayRef<int64_t>({dim}));
1814+
}
1815+
Value result = rewriter.create<stablehlo::ConcatenateOp>(
1816+
loc, ValueRange{left, self, right}, dim);
1817+
rewriter.replaceOp(op, result);
1818+
return success();
1819+
}
1820+
17621821
template <>
17631822
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
17641823
AtenGeluBackwardOp op, OpAdaptor adaptor,
@@ -2269,6 +2328,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
22692328
INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
22702329
INSERT_ATENOP_PATTERN(AtenContiguousOp);
22712330
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
2331+
INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
22722332

22732333
INSERT_ATENOP_PATTERN(AtenReluOp);
22742334
INSERT_ATENOP_PATTERN(AtenGeluOp);

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -829,10 +829,6 @@
829829
"RandnLikeDtypeModule_basic",
830830
"RandnLikeModule_basic",
831831
"RandnModule_basic",
832-
"ReflectionPad1dModule2dInput_Right",
833-
"ReflectionPad1dModule2dInput_basic",
834-
"ReflectionPad1dModule3dInput_Left",
835-
"ReflectionPad1dModule3dInput_basic",
836832
"ReflectionPad2dModule_Bottom",
837833
"ReflectionPad2dModule_Left",
838834
"ReflectionPad2dModule_Right",

0 commit comments

Comments
 (0)