@@ -18759,66 +18759,6 @@ bool isAxisFusible(int dimension, ArrayRef<Value> vals) {
1875918759 return false;
1876018760}
1876118761
18762- // slice(extend x) -> extend(slice x)
18763- // This pattern pushes a slice operation through an extend operation.
18764- struct ExtendSlice final
18765- : CheckedOpRewritePattern<stablehlo::SliceOp, ExtendSlice> {
18766- using CheckedOpRewritePattern::CheckedOpRewritePattern;
18767-
18768- LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
18769- PatternRewriter &rewriter) const {
18770- auto extendOp = op.getOperand().getDefiningOp<enzymexla::ExtendOp>();
18771- if (!extendOp)
18772- return rewriter.notifyMatchFailure(op, "Operand is not an ExtendOp");
18773-
18774- // This transformation is simplified if strides are 1.
18775- if (llvm::any_of(op.getStrides(), [](int64_t s) { return s != 1; }))
18776- return rewriter.notifyMatchFailure(op, "Requires strides of 1");
18777-
18778- Value operand = extendOp.getOperand();
18779- auto originalShape = cast<RankedTensorType>(operand.getType()).getShape();
18780- int64_t d = extendOp.getDimension();
18781- int64_t lhs = extendOp.getLhs();
18782- int64_t rhs = extendOp.getRhs();
18783-
18784- auto starts = op.getStartIndices();
18785- auto limits = op.getLimitIndices();
18786-
18787- SmallVector<int64_t> new_starts = llvm::to_vector(starts);
18788- SmallVector<int64_t> new_limits = llvm::to_vector(limits);
18789- SmallVector<int64_t> new_strides = llvm::to_vector(op.getStrides());
18790-
18791- int64_t start_d = starts[d];
18792- int64_t limit_d = limits[d];
18793- int64_t size_d = originalShape[d];
18794-
18795- // Calculate the parameters for the new slice operation on the original
18796- // operand. The new slice covers the part of the original tensor that is
18797- // visible in the final output.
18798- new_starts[d] = std::max((int64_t)0, start_d - lhs);
18799- new_limits[d] = std::min(size_d, limit_d - lhs);
18800-
18801- // Calculate the new padding amounts for the extend operation.
18802- // new_lhs is the size of the overlap between the slice and the prepended
18803- // padding.
18804- int64_t new_lhs = std::max((int64_t)0, std::min(limit_d, lhs) - start_d);
18805- // new_rhs is the size of the overlap between the slice and the appended
18806- // padding.
18807- int64_t new_rhs =
18808- std::max((int64_t)0, limit_d - std::max(start_d, lhs + size_d));
18809-
18810- // Create the new slice on the original tensor.
18811- auto newSlice = rewriter.create<stablehlo::SliceOp>(
18812- op.getLoc(), operand, new_starts, new_limits, new_strides);
18813-
18814- // Create the new extend on the newly sliced tensor.
18815- rewriter.replaceOpWithNewOp<enzymexla::ExtendOp>(op, op.getType(), newSlice,
18816- new_lhs, new_rhs, d);
18817-
18818- return success();
18819- }
18820- };
18821-
1882218762struct SliceExtend final
1882318763 : CheckedOpRewritePattern<enzymexla::ExtendOp, SliceExtend> {
1882418764 using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -22345,7 +22285,6 @@ struct EnzymeHLOOptPass
2234522285 mlir::enzyme::populateWithGenerated(patterns);
2234622286
2234722287 patterns.add<SliceExtend>(context);
22348- patterns.add<ExtendSlice>(context);
2234922288 patterns.add<SliceRotate>(context);
2235022289 patterns.add<SliceWrap>(context);
2235122290 patterns.add<ReshapeWrap>(context);
0 commit comments