@@ -18759,6 +18759,66 @@ bool isAxisFusible(int dimension, ArrayRef<Value> vals) {
1875918759 return false;
1876018760}
1876118761
18762+ // slice(extend x) -> extend(slice x)
18763+ // This pattern pushes a slice operation through an extend operation.
18764+ struct ExtendSlice final
18765+ : CheckedOpRewritePattern<stablehlo::SliceOp, ExtendSlice> {
18766+ using CheckedOpRewritePattern::CheckedOpRewritePattern;
18767+
18768+ LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
18769+ PatternRewriter &rewriter) const {
18770+ auto extendOp = op.getOperand().getDefiningOp<enzymexla::ExtendOp>();
18771+ if (!extendOp)
18772+ return rewriter.notifyMatchFailure(op, "Operand is not an ExtendOp");
18773+
18774+ // This transformation is simplified if strides are 1.
18775+ if (llvm::any_of(op.getStrides(), [](int64_t s) { return s != 1; }))
18776+ return rewriter.notifyMatchFailure(op, "Requires strides of 1");
18777+
18778+ Value operand = extendOp.getOperand();
18779+ auto originalShape = cast<RankedTensorType>(operand.getType()).getShape();
18780+ int64_t d = extendOp.getDimension();
18781+ int64_t lhs = extendOp.getLhs();
18782+ int64_t rhs = extendOp.getRhs();
18783+
18784+ auto starts = op.getStartIndices();
18785+ auto limits = op.getLimitIndices();
18786+
18787+ SmallVector<int64_t> new_starts = llvm::to_vector(starts);
18788+ SmallVector<int64_t> new_limits = llvm::to_vector(limits);
18789+ SmallVector<int64_t> new_strides = llvm::to_vector(op.getStrides());
18790+
18791+ int64_t start_d = starts[d];
18792+ int64_t limit_d = limits[d];
18793+ int64_t size_d = originalShape[d];
18794+
18795+ // Calculate the parameters for the new slice operation on the original
18796+ // operand. The new slice covers the part of the original tensor that is
18797+ // visible in the final output.
18798+ new_starts[d] = std::max((int64_t)0, start_d - lhs);
18799+ new_limits[d] = std::min(size_d, limit_d - lhs);
18800+
18801+ // Calculate the new padding amounts for the extend operation.
18802+ // new_lhs is the size of the overlap between the slice and the prepended
18803+ // padding.
18804+ int64_t new_lhs = std::max((int64_t)0, std::min(limit_d, lhs) - start_d);
18805+ // new_rhs is the size of the overlap between the slice and the appended
18806+ // padding.
18807+ int64_t new_rhs =
18808+ std::max((int64_t)0, limit_d - std::max(start_d, lhs + size_d));
18809+
18810+ // Create the new slice on the original tensor.
18811+ auto newSlice = rewriter.create<stablehlo::SliceOp>(
18812+ op.getLoc(), operand, new_starts, new_limits, new_strides);
18813+
18814+ // Create the new extend on the newly sliced tensor.
18815+ rewriter.replaceOpWithNewOp<enzymexla::ExtendOp>(op, op.getType(), newSlice,
18816+ new_lhs, new_rhs, d);
18817+
18818+ return success();
18819+ }
18820+ };
18821+
1876218822struct SliceExtend final
1876318823 : CheckedOpRewritePattern<enzymexla::ExtendOp, SliceExtend> {
1876418824 using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -22285,6 +22345,7 @@ struct EnzymeHLOOptPass
2228522345 mlir::enzyme::populateWithGenerated(patterns);
2228622346
2228722347 patterns.add<SliceExtend>(context);
22348+ patterns.add<ExtendSlice>(context);
2228822349 patterns.add<SliceRotate>(context);
2228922350 patterns.add<SliceWrap>(context);
2229022351 patterns.add<ReshapeWrap>(context);
0 commit comments