@@ -18759,6 +18759,63 @@ bool isAxisFusible(int dimension, ArrayRef<Value> vals) {
1875918759 return false;
1876018760}
1876118761
18762+ // slice(extend x) -> extend(slice x)
18763+ // This pattern pushes a slice operation through an extend operation.
18764+ struct ExtendSlice final
18765+ : CheckedOpRewritePattern<stablehlo::SliceOp, ExtendSlice> {
18766+ using CheckedOpRewritePattern::CheckedOpRewritePattern;
18767+
18768+ LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
18769+ PatternRewriter &rewriter) const {
18770+ auto extendOp = op.getOperand().getDefiningOp<enzymexla::ExtendOp>();
18771+ if (!extendOp)
18772+ return rewriter.notifyMatchFailure(op, "Operand is not an ExtendOp");
18773+
18774+ // This transformation is simplified if strides are 1.
18775+ if (llvm::any_of(op.getStrides(), [](int64_t s) { return s != 1; }))
18776+ return rewriter.notifyMatchFailure(op, "Requires strides of 1");
18777+
18778+ Value operand = extendOp.getOperand();
18779+ auto originalShape = cast<RankedTensorType>(operand.getType()).getShape();
18780+ int64_t d = extendOp.getDimension();
18781+ int64_t lhs = extendOp.getLhs();
18782+ int64_t rhs = extendOp.getRhs();
18783+
18784+ auto starts = op.getStartIndices();
18785+ auto limits = op.getLimitIndices();
18786+
18787+ SmallVector<int64_t> new_starts = llvm::to_vector(starts);
18788+ SmallVector<int64_t> new_limits = llvm::to_vector(limits);
18789+ SmallVector<int64_t> new_strides = llvm::to_vector(op.getStrides());
18790+
18791+ int64_t start_d = starts[d];
18792+ int64_t limit_d = limits[d];
18793+ int64_t size_d = originalShape[d];
18794+
18795+ // Calculate the parameters for the new slice operation on the original operand.
18796+ // The new slice covers the part of the original tensor that is visible in the final output.
18797+ new_starts[d] = std::max((int64_t)0, start_d - lhs);
18798+ new_limits[d] = std::min(size_d, limit_d - lhs);
18799+
18800+ // Calculate the new padding amounts for the extend operation.
18801+ // new_lhs is the size of the overlap between the slice and the prepended padding.
18802+ int64_t new_lhs = std::max((int64_t)0, std::min(limit_d, lhs) - start_d);
18803+ // new_rhs is the size of the overlap between the slice and the appended padding.
18804+ int64_t new_rhs = std::max((int64_t)0, limit_d - std::max(start_d, lhs + size_d));
18805+
18806+ // Create the new slice on the original tensor.
18807+ auto newSlice = rewriter.create<stablehlo::SliceOp>(
18808+ op.getLoc(), operand, new_starts, new_limits, new_strides);
18809+
18810+ // Create the new extend on the newly sliced tensor.
18811+ rewriter.replaceOpWithNewOp<enzymexla::ExtendOp>(
18812+ op, op.getType(), newSlice, new_lhs, new_rhs, d);
18813+
18814+ return success();
18815+ }
18816+ };
18817+
18818+
1876218819struct SliceExtend final
1876318820 : CheckedOpRewritePattern<enzymexla::ExtendOp, SliceExtend> {
1876418821 using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -22285,6 +22342,7 @@ struct EnzymeHLOOptPass
2228522342 mlir::enzyme::populateWithGenerated(patterns);
2228622343
2228722344 patterns.add<SliceExtend>(context);
22345+ patterns.add<ExtendSlice>(context);
2228822346 patterns.add<SliceRotate>(context);
2228922347 patterns.add<SliceWrap>(context);
2229022348 patterns.add<ReshapeWrap>(context);
0 commit comments