Skip to content

Commit d843527

Browse files
committed
Simplify slicebroadcast if no new op is created
1 parent 99d2b63 commit d843527

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,9 +2870,6 @@ struct SliceBroadcast final
28702870
if (!bcast)
28712871
return failure();
28722872

2873-
if (!llvm::hasSingleElement(bcast->getUsers()))
2874-
return failure();
2875-
28762873
SmallVector<int64_t> nbcast_idx;
28772874

28782875
auto preShape = cast<RankedTensorType>(bcast.getOperand().getType());

test/lit_tests/slicebroadcast.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ module {
1313
%2389 = stablehlo.slice %2388 [0:1, 0:8, 0:3, 0:1024, 1024:2048] : (tensor<1x8x3x1024x2048xf32>) -> tensor<1x8x3x1024x1024xf32>
1414
return %2389 : tensor<1x8x3x1024x1024xf32>
1515
}
16+
17+
func.func @main3(%20: tensor<f64>) -> (tensor<3056x6128xf64>, tensor<3055x6128xf64>) {
18+
%28 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<f64>) -> tensor<3056x6128xf64>
19+
%31 = stablehlo.slice %28 [1:3056, 0:6128] : (tensor<3056x6128xf64>) -> tensor<3055x6128xf64>
20+
return %28, %31 : tensor<3056x6128xf64>, tensor<3055x6128xf64>
21+
}
1622
}
1723

1824
// CHECK: func.func @main(%arg0: tensor<2x3x50xf32>) -> tensor<4x1x25x15x2x3xf32> {
@@ -25,3 +31,10 @@ module {
2531
// CHECK-NEXT: %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2, 3, 4] : (tensor<1x8x3x1024x1xf32>) -> tensor<1x8x3x1024x1024xf32>
2632
// CHECK-NEXT: return %0 : tensor<1x8x3x1024x1024xf32>
2733
// CHECK-NEXT: }
34+
35+
// CHECK: func.func @main3(%arg0: tensor<f64>) -> (tensor<3056x6128xf64>, tensor<3055x6128xf64>) {
36+
// CHECK-NEXT: %0 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f64>) -> tensor<3056x6128xf64>
37+
// CHECK-NEXT: %1 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f64>) -> tensor<3055x6128xf64>
38+
// CHECK-NEXT: return %0, %1 : tensor<3056x6128xf64>, tensor<3055x6128xf64>
39+
// CHECK-NEXT: }
40+

0 commit comments

Comments
 (0)