Skip to content

Commit eb6b440

Browse files
committed
Revert "ExtendSlice (#1296)"
This reverts commit 5ee7816.
1 parent 5ee7816 commit eb6b440

File tree

3 files changed

+0
-82
lines changed

3 files changed

+0
-82
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -18759,66 +18759,6 @@ bool isAxisFusible(int dimension, ArrayRef<Value> vals) {
1875918759
return false;
1876018760
}
1876118761

18762-
// slice(extend x) -> extend(slice x)
18763-
// This pattern pushes a slice operation through an extend operation.
18764-
struct ExtendSlice final
18765-
: CheckedOpRewritePattern<stablehlo::SliceOp, ExtendSlice> {
18766-
using CheckedOpRewritePattern::CheckedOpRewritePattern;
18767-
18768-
LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
18769-
PatternRewriter &rewriter) const {
18770-
auto extendOp = op.getOperand().getDefiningOp<enzymexla::ExtendOp>();
18771-
if (!extendOp)
18772-
return rewriter.notifyMatchFailure(op, "Operand is not an ExtendOp");
18773-
18774-
// This transformation is simplified if strides are 1.
18775-
if (llvm::any_of(op.getStrides(), [](int64_t s) { return s != 1; }))
18776-
return rewriter.notifyMatchFailure(op, "Requires strides of 1");
18777-
18778-
Value operand = extendOp.getOperand();
18779-
auto originalShape = cast<RankedTensorType>(operand.getType()).getShape();
18780-
int64_t d = extendOp.getDimension();
18781-
int64_t lhs = extendOp.getLhs();
18782-
int64_t rhs = extendOp.getRhs();
18783-
18784-
auto starts = op.getStartIndices();
18785-
auto limits = op.getLimitIndices();
18786-
18787-
SmallVector<int64_t> new_starts = llvm::to_vector(starts);
18788-
SmallVector<int64_t> new_limits = llvm::to_vector(limits);
18789-
SmallVector<int64_t> new_strides = llvm::to_vector(op.getStrides());
18790-
18791-
int64_t start_d = starts[d];
18792-
int64_t limit_d = limits[d];
18793-
int64_t size_d = originalShape[d];
18794-
18795-
// Calculate the parameters for the new slice operation on the original
18796-
// operand. The new slice covers the part of the original tensor that is
18797-
// visible in the final output.
18798-
new_starts[d] = std::max((int64_t)0, start_d - lhs);
18799-
new_limits[d] = std::min(size_d, limit_d - lhs);
18800-
18801-
// Calculate the new padding amounts for the extend operation.
18802-
// new_lhs is the size of the overlap between the slice and the prepended
18803-
// padding.
18804-
int64_t new_lhs = std::max((int64_t)0, std::min(limit_d, lhs) - start_d);
18805-
// new_rhs is the size of the overlap between the slice and the appended
18806-
// padding.
18807-
int64_t new_rhs =
18808-
std::max((int64_t)0, limit_d - std::max(start_d, lhs + size_d));
18809-
18810-
// Create the new slice on the original tensor.
18811-
auto newSlice = rewriter.create<stablehlo::SliceOp>(
18812-
op.getLoc(), operand, new_starts, new_limits, new_strides);
18813-
18814-
// Create the new extend on the newly sliced tensor.
18815-
rewriter.replaceOpWithNewOp<enzymexla::ExtendOp>(op, op.getType(), newSlice,
18816-
new_lhs, new_rhs, d);
18817-
18818-
return success();
18819-
}
18820-
};
18821-
1882218762
struct SliceExtend final
1882318763
: CheckedOpRewritePattern<enzymexla::ExtendOp, SliceExtend> {
1882418764
using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -22345,7 +22285,6 @@ struct EnzymeHLOOptPass
2234522285
mlir::enzyme::populateWithGenerated(patterns);
2234622286

2234722287
patterns.add<SliceExtend>(context);
22348-
patterns.add<ExtendSlice>(context);
2234922288
patterns.add<SliceRotate>(context);
2235022289
patterns.add<SliceWrap>(context);
2235122290
patterns.add<ReshapeWrap>(context);

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,11 +1868,6 @@ def SliceExtend : EnzymeHLOPatternOp<
18681868
let patterns = ["SliceExtend"];
18691869
}
18701870

1871-
def ExtendSlice : EnzymeHLOPatternOp<
1872-
"extend_slice"> {
1873-
let patterns = ["ExtendSlice"];
1874-
}
1875-
18761871
def SliceRotate : EnzymeHLOPatternOp<
18771872
"slice_rotate"> {
18781873
let patterns = ["SliceRotate"];

test/lit_tests/extendslice.mlir

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)