Skip to content

Commit b65e5f2

Browse files
committed
Don't apply the ExtendSlice pattern if the ExtendOp has multiple uses
1 parent 115bfa1 commit b65e5f2

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18771,6 +18771,10 @@ struct ExtendSlice final
1877118771
if (!extendOp)
1877218772
return rewriter.notifyMatchFailure(op, "Operand is not an ExtendOp");
1877318773

18774+
if (extendOp.getResult().getNumUses() > 1)
18775+
return rewriter.notifyMatchFailure(
18776+
op, "ExtendOp result is used multiple times");
18777+
1877418778
// This transformation is simplified if strides are 1.
1877518779
if (llvm::any_of(op.getStrides(), [](int64_t s) { return s != 1; }))
1877618780
return rewriter.notifyMatchFailure(op, "Requires strides of 1");

test/lit_tests/extendslice.mlir

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=extend_slice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
22

3-
4-
// CHECK: func.func @f(%arg0: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
3+
// CHECK: func.func @f_single_use(%arg0: tensor<4x1520x3056xf64>) -> tensor<3x1520x3056xf64> {
54
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:2, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<2x1520x3056xf64>
65
// CHECK-NEXT: %1 = "enzymexla.extend"(%0) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<2x1520x3056xf64>) -> tensor<3x1520x3056xf64>
7-
// CHECK-NEXT: %2 = stablehlo.slice %arg0 [0:3, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<3x1520x3056xf64>
8-
// CHECK-NEXT: %3 = "enzymexla.extend"(%2) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<3x1520x3056xf64>) -> tensor<4x1520x3056xf64>
9-
// CHECK-NEXT: return %1, %3 : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
6+
// CHECK-NEXT: return %1 : tensor<3x1520x3056xf64>
7+
// CHECK-NEXT: }
8+
func.func @f_single_use(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>) {
9+
%b = "enzymexla.extend"(%a) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
10+
%c = stablehlo.slice %b [0:3, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
11+
return %c : tensor<3x1520x3056xf64>
12+
}
13+
14+
15+
// CHECK: func.func @f_multiple_uses(%arg0: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
16+
// CHECK-NEXT: %0 = "enzymexla.extend"(%arg0) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
17+
// CHECK-NEXT: %1 = stablehlo.slice %0 [0:3, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
18+
// CHECK-NEXT: %2 = stablehlo.slice %0 [0:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<4x1520x3056xf64>
19+
// CHECK-NEXT: return %1, %2 : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
1020
// CHECK-NEXT: }
11-
func.func @f(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
21+
func.func @f_multiple_uses(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
1222
%b = "enzymexla.extend"(%a) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
1323
%c = stablehlo.slice %b [0:3, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
1424
%d = stablehlo.slice %b [0:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<4x1520x3056xf64>

0 commit comments

Comments
 (0)