Skip to content

Commit 5ee7816

Browse files
jumerckxwsmoses
andauthored
ExtendSlice (#1296)
* ExtendSlice * formatting --------- Co-authored-by: William Moses <[email protected]>
1 parent 6718483 commit 5ee7816

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18759,6 +18759,66 @@ bool isAxisFusible(int dimension, ArrayRef<Value> vals) {
1875918759
return false;
1876018760
}
1876118761

18762+
// slice(extend x) -> extend(slice x)
18763+
// This pattern pushes a slice operation through an extend operation.
18764+
struct ExtendSlice final
18765+
: CheckedOpRewritePattern<stablehlo::SliceOp, ExtendSlice> {
18766+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
18767+
18768+
LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
18769+
PatternRewriter &rewriter) const {
18770+
auto extendOp = op.getOperand().getDefiningOp<enzymexla::ExtendOp>();
18771+
if (!extendOp)
18772+
return rewriter.notifyMatchFailure(op, "Operand is not an ExtendOp");
18773+
18774+
// This transformation is simplified if strides are 1.
18775+
if (llvm::any_of(op.getStrides(), [](int64_t s) { return s != 1; }))
18776+
return rewriter.notifyMatchFailure(op, "Requires strides of 1");
18777+
18778+
Value operand = extendOp.getOperand();
18779+
auto originalShape = cast<RankedTensorType>(operand.getType()).getShape();
18780+
int64_t d = extendOp.getDimension();
18781+
int64_t lhs = extendOp.getLhs();
18782+
int64_t rhs = extendOp.getRhs();
18783+
18784+
auto starts = op.getStartIndices();
18785+
auto limits = op.getLimitIndices();
18786+
18787+
SmallVector<int64_t> new_starts = llvm::to_vector(starts);
18788+
SmallVector<int64_t> new_limits = llvm::to_vector(limits);
18789+
SmallVector<int64_t> new_strides = llvm::to_vector(op.getStrides());
18790+
18791+
int64_t start_d = starts[d];
18792+
int64_t limit_d = limits[d];
18793+
int64_t size_d = originalShape[d];
18794+
18795+
// Calculate the parameters for the new slice operation on the original
18796+
// operand. The new slice covers the part of the original tensor that is
18797+
// visible in the final output.
18798+
new_starts[d] = std::max((int64_t)0, start_d - lhs);
18799+
new_limits[d] = std::min(size_d, limit_d - lhs);
18800+
18801+
// Calculate the new padding amounts for the extend operation.
18802+
// new_lhs is the size of the overlap between the slice and the prepended
18803+
// padding.
18804+
int64_t new_lhs = std::max((int64_t)0, std::min(limit_d, lhs) - start_d);
18805+
// new_rhs is the size of the overlap between the slice and the appended
18806+
// padding.
18807+
int64_t new_rhs =
18808+
std::max((int64_t)0, limit_d - std::max(start_d, lhs + size_d));
18809+
18810+
// Create the new slice on the original tensor.
18811+
auto newSlice = rewriter.create<stablehlo::SliceOp>(
18812+
op.getLoc(), operand, new_starts, new_limits, new_strides);
18813+
18814+
// Create the new extend on the newly sliced tensor.
18815+
rewriter.replaceOpWithNewOp<enzymexla::ExtendOp>(op, op.getType(), newSlice,
18816+
new_lhs, new_rhs, d);
18817+
18818+
return success();
18819+
}
18820+
};
18821+
1876218822
struct SliceExtend final
1876318823
: CheckedOpRewritePattern<enzymexla::ExtendOp, SliceExtend> {
1876418824
using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -22285,6 +22345,7 @@ struct EnzymeHLOOptPass
2228522345
mlir::enzyme::populateWithGenerated(patterns);
2228622346

2228722347
patterns.add<SliceExtend>(context);
22348+
patterns.add<ExtendSlice>(context);
2228822349
patterns.add<SliceRotate>(context);
2228922350
patterns.add<SliceWrap>(context);
2229022351
patterns.add<ReshapeWrap>(context);

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,6 +1868,11 @@ def SliceExtend : EnzymeHLOPatternOp<
18681868
let patterns = ["SliceExtend"];
18691869
}
18701870

1871+
def ExtendSlice : EnzymeHLOPatternOp<
1872+
"extend_slice"> {
1873+
let patterns = ["ExtendSlice"];
1874+
}
1875+
18711876
def SliceRotate : EnzymeHLOPatternOp<
18721877
"slice_rotate"> {
18731878
let patterns = ["SliceRotate"];

test/lit_tests/extendslice.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=extend_slice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
2+
3+
4+
// CHECK: func.func @f(%arg0: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
5+
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:2, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<2x1520x3056xf64>
6+
// CHECK-NEXT: %1 = "enzymexla.extend"(%0) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<2x1520x3056xf64>) -> tensor<3x1520x3056xf64>
7+
// CHECK-NEXT: %2 = stablehlo.slice %arg0 [0:3, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<3x1520x3056xf64>
8+
// CHECK-NEXT: %3 = "enzymexla.extend"(%2) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<3x1520x3056xf64>) -> tensor<4x1520x3056xf64>
9+
// CHECK-NEXT: return %1, %3 : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
10+
// CHECK-NEXT: }
11+
func.func @f(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
12+
%b = "enzymexla.extend"(%a) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
13+
%c = stablehlo.slice %b [0:3, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
14+
%d = stablehlo.slice %b [0:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<4x1520x3056xf64>
15+
return %c, %d : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
16+
}

0 commit comments

Comments
 (0)