Commit c051c60

feat: enable raising more operations from loops (#1567)

* feat: enable raising more operations from loops
* test: add test cases

Parent: bca1e1c
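
In short: the hard-coded allowlist of liftable ops in GreedyWhileLoopBatchFission is replaced by a check for BatchOpInterface, so any op implementing that interface (or carrying the Elementwise trait) can now be raised out of a while loop. ConcatInsertDimToBatch and SliceToBatch instantiations, together with the corresponding transform-op pattern definitions, are added for stablehlo.concatenate, stablehlo.get_dimension_size, and stablehlo.reverse, and three lit tests cover the new raisings.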

File tree

5 files changed: +128 -5 lines changed


src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 8 additions & 5 deletions

@@ -867,11 +867,8 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
     if (!validReshapes)
       continue;

-    // TODO: add scatter here once batch interface is
-    if (isa<stablehlo::DotGeneralOp, stablehlo::GatherOp, stablehlo::ReduceOp,
-            stablehlo::SortOp, stablehlo::TransposeOp,
-            stablehlo::BroadcastInDimOp, stablehlo::ReduceWindowOp>(op) ||
-        op->hasTrait<OpTrait::Elementwise>()) {
+    auto batchInterface = dyn_cast<BatchOpInterface>(op);
+    if (batchInterface || op->hasTrait<OpTrait::Elementwise>()) {
       if (liftOperationByBatching(rewriter, whileOp, slices, op, info,
                                   intermediateReshape)) {
         anyOpRewritten = true;
@@ -1247,6 +1244,9 @@ struct AutoBatchingPass
         // op interface is implemented
         ConcatInsertDimToBatch<stablehlo::SortOp>,
         ConcatInsertDimToBatch<stablehlo::ReduceWindowOp>,
+        ConcatInsertDimToBatch<stablehlo::ConcatenateOp>,
+        ConcatInsertDimToBatch<stablehlo::GetDimensionSizeOp>,
+        ConcatInsertDimToBatch<stablehlo::ReverseOp>,
         ConcatInsertDimElementwiseToBatch>(context);
   }

@@ -1258,6 +1258,9 @@ struct AutoBatchingPass
         SliceToBatch<stablehlo::TransposeOp>,
         SliceToBatch<stablehlo::BroadcastInDimOp>,
         SliceToBatch<stablehlo::ReduceWindowOp>,
+        SliceToBatch<stablehlo::ConcatenateOp>,
+        SliceToBatch<stablehlo::GetDimensionSizeOp>,
+        SliceToBatch<stablehlo::ReverseOp>,
         // SliceToBatchReshape,
         SliceToBatchElementwise>(context);
   }
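
The practical effect of the interface check: any op implementing BatchOpInterface whose operands come from per-iteration slices can now be lifted out of the loop as one batched op, not just the ops on the old allowlist. A minimal sketch of the kind of rewrite this enables (a simplified, hypothetical variant of the tests below; the function name and shapes are invented for illustration):

// Before raising: each row is sliced out, reversed, and re-concatenated.
func.func @example(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
  %0 = stablehlo.slice %arg0 [0:1, 0:4] : (tensor<2x4xf32>) -> tensor<1x4xf32>
  %1 = stablehlo.reverse %0, dims = [1] : tensor<1x4xf32>
  %2 = stablehlo.slice %arg0 [1:2, 0:4] : (tensor<2x4xf32>) -> tensor<1x4xf32>
  %3 = stablehlo.reverse %2, dims = [1] : tensor<1x4xf32>
  %4 = stablehlo.concatenate %1, %3, dim = 0 : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<2x4xf32>
  return %4 : tensor<2x4xf32>
}

// After --auto-batching and cleanup, the per-slice reverses should collapse
// into (roughly) a single batched op:
//   %0 = stablehlo.reverse %arg0, dims = [1] : tensor<2x4xf32>
//   return %0 : tensor<2x4xf32>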

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 24 additions & 0 deletions

@@ -2418,6 +2418,18 @@ def ApplyConcatInsertDimSortPatterns : EnzymeHLOPatternOp<
     "concat_insert_dim_sort"> {
   let patterns = ["ConcatInsertDimToBatch<stablehlo::SortOp>"];
 }
+def ApplyConcatInsertDimConcatenatePatterns : EnzymeHLOPatternOp<
+    "concat_insert_dim_concatenate"> {
+  let patterns = ["ConcatInsertDimToBatch<stablehlo::ConcatenateOp>"];
+}
+def ApplyConcatInsertDimGetDimensionSizePatterns : EnzymeHLOPatternOp<
+    "concat_insert_dim_get_dimension_size"> {
+  let patterns = ["ConcatInsertDimToBatch<stablehlo::GetDimensionSizeOp>"];
+}
+def ApplyConcatInsertDimReversePatterns : EnzymeHLOPatternOp<
+    "concat_insert_dim_reverse"> {
+  let patterns = ["ConcatInsertDimToBatch<stablehlo::ReverseOp>"];
+}
 def ApplyConcatInsertDimReduceWindowPatterns : EnzymeHLOPatternOp<
     "concat_insert_dim_reduce_window"> {
   let patterns = ["ConcatInsertDimToBatch<stablehlo::ReduceWindowOp>"];
@@ -2456,6 +2468,18 @@ def ApplyBroadcastInDimSliceToBatchPatterns : EnzymeHLOPatternOp<
     "broadcastindim_slice_to_batch"> {
   let patterns = ["SliceToBatch<stablehlo::BroadcastInDimOp>"];
 }
+def ApplyConcatenateSliceToBatchPatterns : EnzymeHLOPatternOp<
+    "concatenate_slice_to_batch"> {
+  let patterns = ["SliceToBatch<stablehlo::ConcatenateOp>"];
+}
+def ApplyGetDimensionSizeSliceToBatchPatterns : EnzymeHLOPatternOp<
+    "get_dimension_size_slice_to_batch"> {
+  let patterns = ["SliceToBatch<stablehlo::GetDimensionSizeOp>"];
+}
+def ApplyReverseSliceToBatchPatterns : EnzymeHLOPatternOp<
+    "reverse_slice_to_batch"> {
+  let patterns = ["SliceToBatch<stablehlo::ReverseOp>"];
+}
 def ApplyReduceWindowSliceToBatchPatterns : EnzymeHLOPatternOp<
     "reducewindow_slice_to_batch"> {
   let patterns = ["SliceToBatch<stablehlo::ReduceWindowOp>"];
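
The string argument of each definition is the mnemonic under which the pattern is exposed. A hedged usage sketch (the generate-td driver and its flags below follow how existing EnzymeHLOPatternOp mnemonics are exercised elsewhere in the repository, and are an assumption here, not part of this commit):

// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td="patterns=reverse_slice_to_batch;concatenate_slice_to_batch" --transform-interpreter --enzyme-hlo-remove-transform

The mnemonics added by this commit are concat_insert_dim_concatenate, concat_insert_dim_get_dimension_size, concat_insert_dim_reverse, concatenate_slice_to_batch, get_dimension_size_slice_to_batch, and reverse_slice_to_batch.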
Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+// RUN: enzymexlamlir-opt --auto-batching --enzyme-hlo-opt %s | FileCheck %s
+
+module @reactant_loop1 attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
+  func.func @main(%arg0: tensor<4x2x3xf32> {enzymexla.memory_effects = []}) -> tensor<4x2x3xf32> attributes {enzymexla.memory_effects = []} {
+    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<4x2x3xf32>) -> tensor<3x2x4xf32>
+    %1 = stablehlo.slice %0 [0:3, 0:1, 0:4] : (tensor<3x2x4xf32>) -> tensor<3x1x4xf32>
+    %2 = stablehlo.reshape %1 : (tensor<3x1x4xf32>) -> tensor<3x4xf32>
+    %3 = stablehlo.reverse %2, dims = [0, 1] : tensor<3x4xf32>
+    %4 = stablehlo.broadcast_in_dim %3, dims = [2, 0] : (tensor<3x4xf32>) -> tensor<4x1x3xf32>
+    %5 = stablehlo.slice %0 [0:3, 1:2, 0:4] : (tensor<3x2x4xf32>) -> tensor<3x1x4xf32>
+    %6 = stablehlo.reshape %5 : (tensor<3x1x4xf32>) -> tensor<3x4xf32>
+    %7 = stablehlo.reverse %6, dims = [0, 1] : tensor<3x4xf32>
+    %8 = stablehlo.broadcast_in_dim %7, dims = [2, 0] : (tensor<3x4xf32>) -> tensor<4x1x3xf32>
+    %9 = stablehlo.concatenate %4, %8, dim = 1 : (tensor<4x1x3xf32>, tensor<4x1x3xf32>) -> tensor<4x2x3xf32>
+    return %9 : tensor<4x2x3xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<4x2x3xf32> {enzymexla.memory_effects = []}) -> tensor<4x2x3xf32> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT:    %0 = stablehlo.transpose %arg0, dims = [1, 2, 0] : (tensor<4x2x3xf32>) -> tensor<2x3x4xf32>
+// CHECK-NEXT:    %1 = stablehlo.reverse %0, dims = [1, 2] : tensor<2x3x4xf32>
+// CHECK-NEXT:    %2 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<2x3x4xf32>) -> tensor<4x2x3xf32>
+// CHECK-NEXT:    return %2 : tensor<4x2x3xf32>
+// CHECK-NEXT:  }
Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+// RUN: enzymexlamlir-opt --auto-batching --enzyme-hlo-opt %s | FileCheck %s
+
+module {
+  func.func @main(%arg0: tensor<4x7x3xf32> {enzymexla.memory_effects = []}, %arg1: tensor<4x7x3xf32> {enzymexla.memory_effects = []}) -> tensor<8x7x3xf32> attributes {enzymexla.memory_effects = []} {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<1> : tensor<i64>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<8x7x3xf32>
+    %c_3 = stablehlo.constant dense<7> : tensor<i64>
+    %0:2 = stablehlo.while(%iterArg = %c_1, %iterArg_4 = %cst) : tensor<i64>, tensor<8x7x3xf32> attributes {enzyme.disable_mincut}
+    cond {
+      %1 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %1 : tensor<i1>
+    } do {
+      %1 = stablehlo.add %c_2, %iterArg : tensor<i64>
+      %2 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
+      %3 = stablehlo.subtract %2, %c_0 : tensor<i32>
+      %4 = stablehlo.dynamic_slice %arg0, %c, %3, %c, sizes = [4, 1, 3] : (tensor<4x7x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x1x3xf32>
+      %5 = stablehlo.dynamic_slice %arg1, %c, %3, %c, sizes = [4, 1, 3] : (tensor<4x7x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x1x3xf32>
+      %6 = stablehlo.concatenate %4, %5, dim = 0 : (tensor<4x1x3xf32>, tensor<4x1x3xf32>) -> tensor<8x1x3xf32>
+      %7 = stablehlo.dynamic_update_slice %iterArg_4, %6, %c, %3, %c : (tensor<8x7x3xf32>, tensor<8x1x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<8x7x3xf32>
+      stablehlo.return %1, %7 : tensor<i64>, tensor<8x7x3xf32>
+    }
+    return %0#1 : tensor<8x7x3xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<4x7x3xf32> {enzymexla.memory_effects = []}, %arg1: tensor<4x7x3xf32> {enzymexla.memory_effects = []}) -> tensor<8x7x3xf32> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT:    %0 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0, 3] : (tensor<4x7x3xf32>) -> tensor<7x4x1x3xf32>
+// CHECK-NEXT:    %1 = stablehlo.broadcast_in_dim %arg1, dims = [1, 0, 3] : (tensor<4x7x3xf32>) -> tensor<7x4x1x3xf32>
+// CHECK-NEXT:    %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<7x4x1x3xf32>, tensor<7x4x1x3xf32>) -> tensor<7x8x1x3xf32>
+// CHECK-NEXT:    %3 = stablehlo.reshape %2 : (tensor<7x8x1x3xf32>) -> tensor<7x8x3xf32>
+// CHECK-NEXT:    %4 = stablehlo.transpose %3, dims = [1, 0, 2] : (tensor<7x8x3xf32>) -> tensor<8x7x3xf32>
+// CHECK-NEXT:    return %4 : tensor<8x7x3xf32>
+// CHECK-NEXT:  }
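
Reading the CHECK lines above: each 4x7x3 input gets the loop's iteration dimension inserted in front via broadcast_in_dim (yielding 7x4x1x3, where dim 0 indexes the former loop counter), the per-iteration concatenate along dim 0 becomes a single concatenate along the batched dim 1, and the reshape plus transpose restore the 8x7x3 result the loop previously built one dynamic_update_slice at a time.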
Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+// RUN: enzymexlamlir-opt --auto-batching --enzyme-hlo-opt %s | FileCheck %s
+
+module {
+  func.func @main(%arg0: tensor<4x10x3xf32>) -> tensor<4x10x3xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x10x3xf32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<10> : tensor<i64>
+    %c_3 = stablehlo.constant dense<1> : tensor<i64>
+    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<4x10x3xf32>) -> tensor<3x10x4xf32>
+    %1:2 = stablehlo.while(%iterArg = %c_1, %iterArg_4 = %cst) : tensor<i64>, tensor<4x10x3xf32> attributes {enzyme.disable_mincut}
+    cond {
+      %2 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %2 : tensor<i1>
+    } do {
+      %2 = stablehlo.add %c_3, %iterArg : tensor<i64>
+      %3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c_0 : tensor<i32>
+      %5 = stablehlo.dynamic_slice %0, %c, %4, %c, sizes = [3, 1, 4] : (tensor<3x10x4xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3x1x4xf32>
+      %6 = stablehlo.reshape %5 : (tensor<3x1x4xf32>) -> tensor<3x4xf32>
+      %7 = stablehlo.reverse %6, dims = [0, 1] : tensor<3x4xf32>
+      %8 = stablehlo.broadcast_in_dim %7, dims = [2, 0] : (tensor<3x4xf32>) -> tensor<4x1x3xf32>
+      %9 = stablehlo.dynamic_update_slice %iterArg_4, %8, %c, %4, %c : (tensor<4x10x3xf32>, tensor<4x1x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x10x3xf32>
+      stablehlo.return %2, %9 : tensor<i64>, tensor<4x10x3xf32>
+    }
+    return %1#1 : tensor<4x10x3xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<4x10x3xf32>) -> tensor<4x10x3xf32> {
+// CHECK-NEXT:    %0 = stablehlo.transpose %arg0, dims = [1, 2, 0] : (tensor<4x10x3xf32>) -> tensor<10x3x4xf32>
+// CHECK-NEXT:    %1 = stablehlo.reverse %0, dims = [1, 2] : tensor<10x3x4xf32>
+// CHECK-NEXT:    %2 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<10x3x4xf32>) -> tensor<4x10x3xf32>
+// CHECK-NEXT:    return %2 : tensor<4x10x3xf32>
+// CHECK-NEXT:  }
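
The first and third tests describe the same computation (reverse a reshaped slice, broadcast it back, and recombine) in two forms: the first as straight-line code whose slices are concatenated directly, the third with the slices produced by dynamic_slice inside a stablehlo.while and written back through dynamic_update_slice. Both raise to the same transpose/reverse/transpose sequence, which is presumably why the new registrations appear in both the ConcatInsertDimToBatch and SliceToBatch pattern lists.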
