Skip to content

Commit 545a89e

Browse files
committed
drop the extend if it becomes a no-op.
1 parent b65e5f2 commit 545a89e

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18771,10 +18771,6 @@ struct ExtendSlice final
1877118771
if (!extendOp)
1877218772
return rewriter.notifyMatchFailure(op, "Operand is not an ExtendOp");
1877318773

18774-
if (extendOp.getResult().getNumUses() > 1)
18775-
return rewriter.notifyMatchFailure(
18776-
op, "ExtendOp result is used multiple times");
18777-
1877818774
// This transformation is simplified if strides are 1.
1877918775
if (llvm::any_of(op.getStrides(), [](int64_t s) { return s != 1; }))
1878018776
return rewriter.notifyMatchFailure(op, "Requires strides of 1");
@@ -18811,6 +18807,17 @@ struct ExtendSlice final
1881118807
int64_t new_rhs =
1881218808
std::max((int64_t)0, limit_d - std::max(start_d, lhs + size_d));
1881318809

18810+
if (new_lhs == 0 && new_rhs == 0) {
18811+
auto newSlice = rewriter.replaceOpWithNewOp<stablehlo::SliceOp>(
18812+
op, op.getType(), operand, new_starts, new_limits, new_strides);
18813+
return success();
18814+
}
18815+
18816+
if (extendOp.getResult().getNumUses() > 1) {
18817+
return rewriter.notifyMatchFailure(
18818+
op, "ExtendOp result is used multiple times");
18819+
}
18820+
1881418821
// Create the new slice on the original tensor.
1881518822
auto newSlice = rewriter.create<stablehlo::SliceOp>(
1881618823
op.getLoc(), operand, new_starts, new_limits, new_strides);

test/lit_tests/extendslice.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,16 @@ func.func @f_multiple_uses(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf
2424
%d = stablehlo.slice %b [0:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<4x1520x3056xf64>
2525
return %c, %d : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
2626
}
27+
28+
// CHECK: func.func @f_multiple_uses_superfluous_extend(%arg0: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
29+
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:3, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<3x1520x3056xf64>
30+
// CHECK-NEXT: %1 = "enzymexla.extend"(%0) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<3x1520x3056xf64>) -> tensor<4x1520x3056xf64>
31+
// CHECK-NEXT: %2 = stablehlo.slice %arg0 [0:3, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<3x1520x3056xf64>
32+
// CHECK-NEXT: return %2, %1 : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
33+
// CHECK-NEXT: }
34+
func.func @f_multiple_uses_superfluous_extend(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
35+
%b = "enzymexla.extend"(%a) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
36+
%c = stablehlo.slice %b [0:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<4x1520x3056xf64>
37+
%d = stablehlo.slice %b [1:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
38+
return %d, %c : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
39+
}

0 commit comments

Comments
 (0)