Commit a6af9b3

feat: remove more intermediate reshape operations (#1870)
Parent: 2680f71

24 files changed (+758, -257 lines)

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 308 additions & 42 deletions
(Large diff not rendered.)

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 8 additions & 0 deletions
@@ -301,6 +301,10 @@ def ReshapeEmptyBroadcastPatterns : EnzymeHLOPatternOp<
     "reshape_empty_broadcast"> {
   let patterns = ["ReshapeEmptyBroadcast"];
 }
+def ReshapeBroadcastPatterns : EnzymeHLOPatternOp<
+    "reshape_broadcast"> {
+  let patterns = ["ReshapeBroadcast"];
+}
 def ApplySliceReshapePadPatterns : EnzymeHLOPatternOp<
     "slice_reshape_pad"> {
   let patterns = ["SliceReshapePad"];
@@ -2561,6 +2565,10 @@ def ApplyReduceMulToDotGeneralPatterns : EnzymeHLOPatternOp<
     "reduce_mul_to_dot_general"> {
   let patterns = ["ReduceMulToDotGeneral"];
 }
+def ApplySplitReduceAddMulToAddDotGeneralPatterns : EnzymeHLOPatternOp<
+    "split_reduce_add_mul_to_add_dot_general"> {
+  let patterns = ["SplitReduceAddMulToAddDotGeneral"];
+}
 
 def ApplyDotGeneralOnlyDiagonalAccessPatterns : EnzymeHLOPatternOp<
     "dot_general_only_diagonal_access"> {

src/enzyme_ad/jax/Utils.cpp

Lines changed: 11 additions & 8 deletions
@@ -1020,14 +1020,16 @@ SmallVector<int64_t> findReshapeInsertionDims(ArrayRef<int64_t> inputShape,
 
 bool isInsertDimOp(stablehlo::ReshapeOp reshapeOp) {
   RankedTensorType inputTy = reshapeOp.getOperand().getType();
-  auto inputShape = inputTy.getShape();
   RankedTensorType outputTy = reshapeOp.getType();
-  auto outputShape = outputTy.getShape();
-  auto insertDims = findReshapeInsertionDims(inputShape, outputShape);
-  if (insertDims.empty()) {
-    return false;
-  }
-  return true;
+  auto insertDims = findReshapeInsertionDims(inputTy, outputTy);
+  return !insertDims.empty();
+}
+
+bool isDeleteDimOp(stablehlo::ReshapeOp reshapeOp) {
+  RankedTensorType inputTy = reshapeOp.getOperand().getType();
+  RankedTensorType outputTy = reshapeOp.getType();
+  auto deleteDims = findReshapeInsertionDims(outputTy, inputTy);
+  return !deleteDims.empty();
 }
 
 void getSingletonInsertionDims(stablehlo::BroadcastInDimOp bcastOp,
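
For intuition, findReshapeInsertionDims answers "at which output positions did this reshape insert size-1 dimensions?", returning an empty result when the reshape is not a pure unit-dim insertion; the new isDeleteDimOp reuses it with the two types swapped, treating a unit-dim-deleting reshape as the inverse case. A stand-alone sketch of that assumed semantics (not the project's implementation, which now also accepts the tensor types directly):

#include <cstdint>
#include <vector>

// Returns the output positions where size-1 dims were inserted, or {} if
// the output is not simply the input shape with unit dims spliced in.
std::vector<int64_t> findInsertionDimsSketch(const std::vector<int64_t> &in,
                                             const std::vector<int64_t> &out) {
  std::vector<int64_t> inserted;
  size_t i = 0;
  for (size_t o = 0; o < out.size(); ++o) {
    if (i < in.size() && in[i] == out[o])
      ++i;                    // dimension carried over unchanged
    else if (out[o] == 1)
      inserted.push_back(o);  // unit dim inserted by the reshape
    else
      return {};              // a genuine reshape, not a pure insertion
  }
  return i == in.size() ? inserted : std::vector<int64_t>{};
}

For example, tensor<10xf32> -> tensor<10x1xf32> yields {1}, while tensor<2x3xf32> -> tensor<6xf32> yields {} because a non-unit dimension changes.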
@@ -2427,7 +2429,6 @@ bool isFusible(stablehlo::TransposeOp transpose, Operation *op) {
     return false;
 }
 
-// TODO: implement more conditions especially for fusions with transpose
 bool isFusible(Operation *op, stablehlo::ReshapeOp reshape) {
   return TypeSwitch<Operation *, bool>(op)
       .Case<stablehlo::ReshapeOp>([](auto prevOp) { return true; })
@@ -2450,6 +2451,8 @@ bool isFusible(Operation *op, stablehlo::ReshapeOp reshape) {
         }
         return false;
       })
+      .Case<stablehlo::ReduceOp>(
+          [&](auto redOp) { return isDeleteDimOp(reshape); })
       .Default([](auto other) { return matchPattern(other, m_Constant()); });
 }
 
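
The new ReduceOp case above makes a reshape fusible with a reduce only when the reshape does nothing but delete size-1 dimensions (isDeleteDimOp); since a reduce already drops its reduced dimensions, absorbing a unit-dim deletion requires no extra data movement. A self-contained illustration of that shape condition (assumed semantics, written independently of the project's helpers):

#include <cassert>
#include <cstdint>
#include <vector>

// True iff `out` is `in` with one or more size-1 dims removed.
static bool onlyDeletesUnitDims(const std::vector<int64_t> &in,
                                const std::vector<int64_t> &out) {
  size_t o = 0;
  for (int64_t d : in) {
    if (o < out.size() && d == out[o])
      ++o;          // kept dimension
    else if (d != 1)
      return false; // a non-unit dim was dropped: not a pure deletion
  }
  return o == out.size() && in.size() > out.size();
}

int main() {
  assert(onlyDeletesUnitDims({1, 100, 100, 1, 1}, {100, 100})); // fusible
  assert(!onlyDeletesUnitDims({200, 100}, {100, 200}));         // not a deletion
  assert(!onlyDeletesUnitDims({100, 100}, {100, 100}));         // identity
  return 0;
}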

src/enzyme_ad/jax/Utils.h

Lines changed: 1 addition & 0 deletions
@@ -905,6 +905,7 @@ SmallVector<int64_t> findReshapeInsertionDims(ArrayRef<int64_t> inputShape,
                                               ArrayRef<int64_t> outputShape);
 
 bool isInsertDimOp(stablehlo::ReshapeOp reshapeOp);
+bool isDeleteDimOp(stablehlo::ReshapeOp reshapeOp);
 
 void getSingletonInsertionDims(stablehlo::BroadcastInDimOp bcastOp,
                                SmallVectorImpl<int64_t> &insertionDims);

src/enzyme_ad/jax/primitives.py

Lines changed: 2 additions & 0 deletions
@@ -202,6 +202,7 @@ def optimization_passes(
         "dot_general_simplify<16>",
         "transpose_simplify<16>",
         "reshape_empty_broadcast<1>",
+        "reshape_broadcast<1>",
         "broadcast_reshape<1>",
         "transpose_dot_reorder<1>",
         "dot_transpose<1>",
@@ -306,6 +307,7 @@ def optimization_passes(
         "trivial_reduce_window_to_reduce_op",
         "case_to_if",
         "reduce_mul_to_dot_general",
+        "split_reduce_add_mul_to_add_dot_general",
         "dot_general_add_distributive_simplify",
         "dot_general_subtract_distributive_simplify",
         "remove_no_ops_from_while_loop",

test/lit_tests/autobatching/elementwise_loop.mlir

Lines changed: 6 additions & 6 deletions
@@ -123,19 +123,19 @@ module {
 // CHECK-NEXT: %c_3 = stablehlo.constant dense<2> : tensor<i32>
 // CHECK-NEXT: %0 = stablehlo.dynamic_slice %arg0, %c_3, sizes = [10] : (tensor<10xf64>, tensor<i32>) -> tensor<10xf64>
 // CHECK-NEXT: %1 = stablehlo.slice %0 [0:10:3] : (tensor<10xf64>) -> tensor<4xf64>
-// CHECK-NEXT: %2 = stablehlo.dynamic_slice %arg0, %c_3, sizes = [10] : (tensor<10xf64>, tensor<i32>) -> tensor<10xf64>
-// CHECK-NEXT: %3 = stablehlo.slice %2 [0:10:3] : (tensor<10xf64>) -> tensor<4xf64>
-// CHECK-NEXT: %4 = stablehlo.sine %3 : tensor<4xf64>
-// CHECK-NEXT: %5 = stablehlo.cosine %1 : tensor<4xf64>
+// CHECK-NEXT: %2 = stablehlo.cosine %1 : tensor<4xf64>
+// CHECK-NEXT: %3 = stablehlo.dynamic_slice %arg0, %c_3, sizes = [10] : (tensor<10xf64>, tensor<i32>) -> tensor<10xf64>
+// CHECK-NEXT: %4 = stablehlo.slice %3 [0:10:3] : (tensor<10xf64>) -> tensor<4xf64>
+// CHECK-NEXT: %5 = stablehlo.sine %4 : tensor<4xf64>
 // CHECK-NEXT: %6:2 = stablehlo.while(%iterArg = %c_0, %iterArg_4 = %cst) : tensor<i64>, tensor<10xf64>
 // CHECK-NEXT: cond {
 // CHECK-NEXT: %7 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
 // CHECK-NEXT: stablehlo.return %7 : tensor<i1>
 // CHECK-NEXT: } do {
 // CHECK-NEXT: %7 = stablehlo.add %c_2, %iterArg {enzymexla.bounds = {{.*}}} : tensor<i64>
 // CHECK-NEXT: %8 = stablehlo.divide %iterArg, %c_2 {enzymexla.bounds = {{.*}}} : tensor<i64>
-// CHECK-NEXT: %9 = stablehlo.dynamic_slice %4, %8, sizes = [1] : (tensor<4xf64>, tensor<i64>) -> tensor<1xf64>
-// CHECK-NEXT: %10 = stablehlo.dynamic_slice %5, %8, sizes = [1] : (tensor<4xf64>, tensor<i64>) -> tensor<1xf64>
+// CHECK-NEXT: %9 = stablehlo.dynamic_slice %5, %8, sizes = [1] : (tensor<4xf64>, tensor<i64>) -> tensor<1xf64>
+// CHECK-NEXT: %10 = stablehlo.dynamic_slice %2, %8, sizes = [1] : (tensor<4xf64>, tensor<i64>) -> tensor<1xf64>
 // CHECK-NEXT: %11 = stablehlo.subtract %10, %9 : tensor<1xf64>
 // CHECK-NEXT: %12 = stablehlo.convert %7 {enzymexla.bounds = {{.*}}} : (tensor<i64>) -> tensor<i32>
 // CHECK-NEXT: %13 = stablehlo.subtract %12, %c {enzymexla.bounds = {{.*}}} : tensor<i32>

test/lit_tests/autobatching/elementwise_loop_affine.mlir

Lines changed: 14 additions & 18 deletions
@@ -35,16 +35,14 @@ func.func @main1(%arg0: tensor<25xf32>) -> tensor<13xf32> {
 
 // CHECK: func.func @main1(%arg0: tensor<25xf32>) -> tensor<13xf32> {
 // CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<10x1xf32>
-// CHECK-NEXT: %cst_1 = stablehlo.constant dense<3.000000e+00> : tensor<10x1xf32>
+// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<10xf32>
+// CHECK-NEXT: %cst_1 = stablehlo.constant dense<3.000000e+00> : tensor<10xf32>
 // CHECK-NEXT: %0 = stablehlo.slice %arg0 [6:25:2] : (tensor<25xf32>) -> tensor<10xf32>
-// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<10xf32>) -> tensor<10x1xf32>
-// CHECK-NEXT: %2 = stablehlo.multiply %1, %cst_1 : tensor<10x1xf32>
-// CHECK-NEXT: %3 = stablehlo.subtract %2, %cst_0 : tensor<10x1xf32>
-// CHECK-NEXT: %4 = stablehlo.sine %3 : tensor<10x1xf32>
-// CHECK-NEXT: %5 = stablehlo.reshape %4 : (tensor<10x1xf32>) -> tensor<10xf32>
-// CHECK-NEXT: %6 = stablehlo.pad %5, %cst, low = [2], high = [1], interior = [0] : (tensor<10xf32>, tensor<f32>) -> tensor<13xf32>
-// CHECK-NEXT: return %6 : tensor<13xf32>
+// CHECK-NEXT: %1 = stablehlo.multiply %0, %cst_1 : tensor<10xf32>
+// CHECK-NEXT: %2 = stablehlo.subtract %1, %cst_0 : tensor<10xf32>
+// CHECK-NEXT: %3 = stablehlo.sine %2 : tensor<10xf32>
+// CHECK-NEXT: %4 = stablehlo.pad %3, %cst, low = [2], high = [1], interior = [0] : (tensor<10xf32>, tensor<f32>) -> tensor<13xf32>
+// CHECK-NEXT: return %4 : tensor<13xf32>
 // CHECK-NEXT: }
 
 func.func @main2(%arg0: tensor<25xf32>) -> tensor<13xf32> {
@@ -82,14 +80,12 @@ func.func @main2(%arg0: tensor<25xf32>) -> tensor<13xf32> {
 
 // CHECK: func.func @main2(%arg0: tensor<25xf32>) -> tensor<13xf32> {
 // CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<10x1xf32>
-// CHECK-NEXT: %cst_1 = stablehlo.constant dense<3.000000e+00> : tensor<10x1xf32>
+// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<10xf32>
+// CHECK-NEXT: %cst_1 = stablehlo.constant dense<3.000000e+00> : tensor<10xf32>
 // CHECK-NEXT: %0 = stablehlo.slice %arg0 [6:25:2] : (tensor<25xf32>) -> tensor<10xf32>
-// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<10xf32>) -> tensor<10x1xf32>
-// CHECK-NEXT: %2 = stablehlo.multiply %1, %cst_1 : tensor<10x1xf32>
-// CHECK-NEXT: %3 = stablehlo.subtract %2, %cst_0 : tensor<10x1xf32>
-// CHECK-NEXT: %4 = stablehlo.sine %3 : tensor<10x1xf32>
-// CHECK-NEXT: %5 = stablehlo.reshape %4 : (tensor<10x1xf32>) -> tensor<10xf32>
-// CHECK-NEXT: %6 = stablehlo.pad %5, %cst, low = [2], high = [1], interior = [0] : (tensor<10xf32>, tensor<f32>) -> tensor<13xf32>
-// CHECK-NEXT: return %6 : tensor<13xf32>
+// CHECK-NEXT: %1 = stablehlo.multiply %0, %cst_1 : tensor<10xf32>
+// CHECK-NEXT: %2 = stablehlo.subtract %1, %cst_0 : tensor<10xf32>
+// CHECK-NEXT: %3 = stablehlo.sine %2 : tensor<10xf32>
+// CHECK-NEXT: %4 = stablehlo.pad %3, %cst, low = [2], high = [1], interior = [0] : (tensor<10xf32>, tensor<f32>) -> tensor<13xf32>
+// CHECK-NEXT: return %4 : tensor<13xf32>
 // CHECK-NEXT: }

test/lit_tests/autobatching/indirect_iota_indexing.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ module {
3232
}
3333

3434
// CHECK: func.func @main(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf32> {
35-
// CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 : tensor<10xf64>
36-
// CHECK-NEXT: %1 = stablehlo.maximum %arg0, %arg1 : tensor<10xf64>
37-
// CHECK-NEXT: %2 = stablehlo.add %0, %1 : tensor<10xf64>
35+
// CHECK-NEXT: %0 = stablehlo.maximum %arg0, %arg1 : tensor<10xf64>
36+
// CHECK-NEXT: %1 = stablehlo.add %arg0, %arg1 : tensor<10xf64>
37+
// CHECK-NEXT: %2 = stablehlo.add %1, %0 : tensor<10xf64>
3838
// CHECK-NEXT: %3 = stablehlo.convert %2 : (tensor<10xf64>) -> tensor<10xf32>
3939
// CHECK-NEXT: return %3 : tensor<10xf32>
4040
// CHECK-NEXT: }

test/lit_tests/autobatching/indirect_iota_indexing2.mlir

Lines changed: 3 additions & 3 deletions
@@ -33,8 +33,8 @@ module {
 // CHECK: func.func @main(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<6xf64> {
 // CHECK-NEXT: %0 = stablehlo.slice %arg0 [2:8] : (tensor<10xf64>) -> tensor<6xf64>
 // CHECK-NEXT: %1 = stablehlo.slice %arg1 [2:8] : (tensor<10xf64>) -> tensor<6xf64>
-// CHECK-NEXT: %2 = stablehlo.add %0, %1 : tensor<6xf64>
-// CHECK-NEXT: %3 = stablehlo.maximum %0, %1 : tensor<6xf64>
-// CHECK-NEXT: %4 = stablehlo.add %2, %3 : tensor<6xf64>
+// CHECK-NEXT: %2 = stablehlo.maximum %0, %1 : tensor<6xf64>
+// CHECK-NEXT: %3 = stablehlo.add %0, %1 : tensor<6xf64>
+// CHECK-NEXT: %4 = stablehlo.add %3, %2 : tensor<6xf64>
 // CHECK-NEXT: return %4 : tensor<6xf64>
 // CHECK-NEXT: }

test/lit_tests/autobatching/nbody.mlir

Lines changed: 39 additions & 39 deletions
@@ -73,48 +73,48 @@ module {
 // CHECK-NEXT: func.func @main(%arg0: tensor<100x3xf32>, %arg1: tensor<100xf32>) -> tensor<100x100x3xf32> {
 // CHECK-NEXT: %cst = stablehlo.constant dense<1.000000e+00> : tensor<100x100xf32>
 // CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<100x100xi64>
-// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:100, 2:3] : (tensor<100x3xf32>) -> tensor<100x1xf32>
-// CHECK-NEXT: %1 = stablehlo.broadcast_in_dim %0, dims = [0, 2] : (tensor<100x1xf32>) -> tensor<100x100x1x1xf32>
-// CHECK-NEXT: %2 = stablehlo.slice %arg0 [0:100, 0:2] : (tensor<100x3xf32>) -> tensor<100x2xf32>
-// CHECK-NEXT: %3 = stablehlo.broadcast_in_dim %2, dims = [1, 0] : (tensor<100x2xf32>) -> tensor<2x100x100x1x1xf32>
-// CHECK-NEXT: %4 = stablehlo.iota dim = 1 : tensor<100x100xi64>
-// CHECK-NEXT: %5 = stablehlo.add %c, %4 : tensor<100x100xi64>
-// CHECK-NEXT: %6 = stablehlo.broadcast_in_dim %2, dims = [2, 0] : (tensor<100x2xf32>) -> tensor<2x100x100x1x1xf32>
-// CHECK-NEXT: %7 = stablehlo.slice %6 [0:1, 0:100, 0:100, 0:1, 0:1] : (tensor<2x100x100x1x1xf32>) -> tensor<1x100x100x1x1xf32>
-// CHECK-NEXT: %8 = stablehlo.reshape %7 : (tensor<1x100x100x1x1xf32>) -> tensor<100x100x1x1xf32>
-// CHECK-NEXT: %9 = stablehlo.slice %6 [1:2, 0:100, 0:100, 0:1, 0:1] : (tensor<2x100x100x1x1xf32>) -> tensor<1x100x100x1x1xf32>
-// CHECK-NEXT: %10 = stablehlo.reshape %9 : (tensor<1x100x100x1x1xf32>) -> tensor<100x100x1x1xf32>
-// CHECK-NEXT: %11 = stablehlo.concatenate %8, %10, dim = 0 : (tensor<100x100x1x1xf32>, tensor<100x100x1x1xf32>) -> tensor<200x100x1x1xf32>
-// CHECK-NEXT: %12 = stablehlo.reshape %11 : (tensor<200x100x1x1xf32>) -> tensor<2x100x100x1x1xf32>
-// CHECK-NEXT: %13 = stablehlo.subtract %3, %12 : tensor<2x100x100x1x1xf32>
-// CHECK-NEXT: %14 = stablehlo.slice %13 [0:1, 0:100, 0:100, 0:1, 0:1] : (tensor<2x100x100x1x1xf32>) -> tensor<1x100x100x1x1xf32>
-// CHECK-NEXT: %15 = stablehlo.slice %13 [1:2, 0:100, 0:100, 0:1, 0:1] : (tensor<2x100x100x1x1xf32>) -> tensor<1x100x100x1x1xf32>
-// CHECK-NEXT: %16 = stablehlo.iota dim = 0 : tensor<100x100xi64>
-// CHECK-NEXT: %17 = stablehlo.add %c, %16 : tensor<100x100xi64>
-// CHECK-NEXT: %18 = stablehlo.compare EQ, %5, %17 : (tensor<100x100xi64>, tensor<100x100xi64>) -> tensor<100x100xi1>
-// CHECK-NEXT: %19 = stablehlo.broadcast_in_dim %0, dims = [1, 2] : (tensor<100x1xf32>) -> tensor<100x100x1x1xf32>
-// CHECK-NEXT: %20 = stablehlo.subtract %1, %19 : tensor<100x100x1x1xf32>
-// CHECK-NEXT: %21 = stablehlo.reshape %20 : (tensor<100x100x1x1xf32>) -> tensor<100x100x1x1x1xf32>
-// CHECK-NEXT: %22 = stablehlo.transpose %13, dims = [1, 2, 3, 4, 0] : (tensor<2x100x100x1x1xf32>) -> tensor<100x100x1x1x2xf32>
-// CHECK-NEXT: %23 = stablehlo.concatenate %22, %21, dim = 4 : (tensor<100x100x1x1x2xf32>, tensor<100x100x1x1x1xf32>) -> tensor<100x100x1x1x3xf32>
-// CHECK-NEXT: %24 = stablehlo.multiply %20, %20 : tensor<100x100x1x1xf32>
-// CHECK-NEXT: %25 = stablehlo.reshape %14 : (tensor<1x100x100x1x1xf32>) -> tensor<100x100xf32>
-// CHECK-NEXT: %26 = stablehlo.multiply %25, %25 : tensor<100x100xf32>
-// CHECK-NEXT: %27 = stablehlo.reshape %15 : (tensor<1x100x100x1x1xf32>) -> tensor<100x100xf32>
-// CHECK-NEXT: %28 = stablehlo.multiply %27, %27 : tensor<100x100xf32>
-// CHECK-NEXT: %29 = stablehlo.add %26, %28 : tensor<100x100xf32>
-// CHECK-NEXT: %30 = stablehlo.reshape %24 : (tensor<100x100x1x1xf32>) -> tensor<100x100xf32>
-// CHECK-NEXT: %31 = stablehlo.add %29, %30 : tensor<100x100xf32>
-// CHECK-NEXT: %32 = stablehlo.divide %cst, %31 : tensor<100x100xf32>
-// CHECK-NEXT: %33 = stablehlo.select %18, %25, %32 : tensor<100x100xi1>, tensor<100x100xf32>
-// CHECK-NEXT: %34 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<100xf32>) -> tensor<100x100xf32>
-// CHECK-NEXT: %35 = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<100xf32>) -> tensor<100x100xf32>
-// CHECK-NEXT: %36 = stablehlo.multiply %34, %35 : tensor<100x100xf32>
-// CHECK-NEXT: %37 = stablehlo.multiply %36, %33 : tensor<100x100xf32>
+// CHECK-NEXT: %0 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<100xf32>) -> tensor<100x100xf32>
+// CHECK-NEXT: %1 = stablehlo.slice %arg0 [0:100, 2:3] : (tensor<100x3xf32>) -> tensor<100x1xf32>
+// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %1, dims = [0, 2] : (tensor<100x1xf32>) -> tensor<100x100x1x1xf32>
+// CHECK-NEXT: %3 = stablehlo.slice %arg0 [0:100, 0:2] : (tensor<100x3xf32>) -> tensor<100x2xf32>
+// CHECK-NEXT: %4 = stablehlo.broadcast_in_dim %3, dims = [1, 0] : (tensor<100x2xf32>) -> tensor<2x100x100x1x1xf32>
+// CHECK-NEXT: %5 = stablehlo.iota dim = 1 : tensor<100x100xi64>
+// CHECK-NEXT: %6 = stablehlo.add %c, %5 : tensor<100x100xi64>
+// CHECK-NEXT: %7 = stablehlo.broadcast_in_dim %3, dims = [2, 0] : (tensor<100x2xf32>) -> tensor<2x100x100x1x1xf32>
+// CHECK-NEXT: %8 = stablehlo.slice %7 [0:1, 0:100, 0:100, 0:1, 0:1] : (tensor<2x100x100x1x1xf32>) -> tensor<1x100x100x1x1xf32>
+// CHECK-NEXT: %9 = stablehlo.reshape %8 : (tensor<1x100x100x1x1xf32>) -> tensor<100x100x1x1xf32>
+// CHECK-NEXT: %10 = stablehlo.slice %7 [1:2, 0:100, 0:100, 0:1, 0:1] : (tensor<2x100x100x1x1xf32>) -> tensor<1x100x100x1x1xf32>
+// CHECK-NEXT: %11 = stablehlo.reshape %10 : (tensor<1x100x100x1x1xf32>) -> tensor<100x100x1x1xf32>
+// CHECK-NEXT: %12 = stablehlo.concatenate %9, %11, dim = 0 : (tensor<100x100x1x1xf32>, tensor<100x100x1x1xf32>) -> tensor<200x100x1x1xf32>
+// CHECK-NEXT: %13 = stablehlo.reshape %12 : (tensor<200x100x1x1xf32>) -> tensor<2x100x100x1x1xf32>
+// CHECK-NEXT: %14 = stablehlo.subtract %4, %13 : tensor<2x100x100x1x1xf32>
+// CHECK-NEXT: %15 = stablehlo.slice %14 [0:1, 0:100, 0:100, 0:1, 0:1] : (tensor<2x100x100x1x1xf32>) -> tensor<1x100x100x1x1xf32>
+// CHECK-NEXT: %16 = stablehlo.slice %14 [1:2, 0:100, 0:100, 0:1, 0:1] : (tensor<2x100x100x1x1xf32>) -> tensor<1x100x100x1x1xf32>
+// CHECK-NEXT: %17 = stablehlo.reshape %16 : (tensor<1x100x100x1x1xf32>) -> tensor<100x100xf32>
+// CHECK-NEXT: %18 = stablehlo.iota dim = 0 : tensor<100x100xi64>
+// CHECK-NEXT: %19 = stablehlo.add %c, %18 : tensor<100x100xi64>
+// CHECK-NEXT: %20 = stablehlo.compare EQ, %6, %19 : (tensor<100x100xi64>, tensor<100x100xi64>) -> tensor<100x100xi1>
+// CHECK-NEXT: %21 = stablehlo.broadcast_in_dim %1, dims = [1, 2] : (tensor<100x1xf32>) -> tensor<100x100x1x1xf32>
+// CHECK-NEXT: %22 = stablehlo.subtract %2, %21 : tensor<100x100x1x1xf32>
+// CHECK-NEXT: %23 = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<100xf32>) -> tensor<100x100xf32>
+// CHECK-NEXT: %24 = stablehlo.multiply %0, %23 : tensor<100x100xf32>
+// CHECK-NEXT: %25 = stablehlo.reshape %22 : (tensor<100x100x1x1xf32>) -> tensor<100x100x1x1x1xf32>
+// CHECK-NEXT: %26 = stablehlo.transpose %14, dims = [1, 2, 3, 4, 0] : (tensor<2x100x100x1x1xf32>) -> tensor<100x100x1x1x2xf32>
+// CHECK-NEXT: %27 = stablehlo.concatenate %26, %25, dim = 4 : (tensor<100x100x1x1x2xf32>, tensor<100x100x1x1x1xf32>) -> tensor<100x100x1x1x3xf32>
+// CHECK-NEXT: %28 = stablehlo.multiply %22, %22 : tensor<100x100x1x1xf32>
+// CHECK-NEXT: %29 = stablehlo.multiply %17, %17 : tensor<100x100xf32>
+// CHECK-NEXT: %30 = stablehlo.reshape %15 : (tensor<1x100x100x1x1xf32>) -> tensor<100x100xf32>
+// CHECK-NEXT: %31 = stablehlo.multiply %30, %30 : tensor<100x100xf32>
+// CHECK-NEXT: %32 = stablehlo.add %31, %29 : tensor<100x100xf32>
+// CHECK-NEXT: %33 = stablehlo.reshape %28 : (tensor<100x100x1x1xf32>) -> tensor<100x100xf32>
+// CHECK-NEXT: %34 = stablehlo.add %32, %33 : tensor<100x100xf32>
+// CHECK-NEXT: %35 = stablehlo.divide %cst, %34 : tensor<100x100xf32>
+// CHECK-NEXT: %36 = stablehlo.select %20, %30, %35 : tensor<100x100xi1>, tensor<100x100xf32>
+// CHECK-NEXT: %37 = stablehlo.multiply %24, %36 : tensor<100x100xf32>
 // CHECK-NEXT: %38 = stablehlo.broadcast_in_dim %37, dims = [0, 1] : (tensor<100x100xf32>) -> tensor<100x100x1x1x2xf32>
 // CHECK-NEXT: %39 = stablehlo.reshape %37 : (tensor<100x100xf32>) -> tensor<100x100x1x1x1xf32>
 // CHECK-NEXT: %40 = stablehlo.concatenate %38, %39, dim = 4 : (tensor<100x100x1x1x2xf32>, tensor<100x100x1x1x1xf32>) -> tensor<100x100x1x1x3xf32>
-// CHECK-NEXT: %41 = stablehlo.multiply %40, %23 : tensor<100x100x1x1x3xf32>
+// CHECK-NEXT: %41 = stablehlo.multiply %40, %27 : tensor<100x100x1x1x3xf32>
 // CHECK-NEXT: %42 = stablehlo.reshape %41 : (tensor<100x100x1x1x3xf32>) -> tensor<100x100x3xf32>
 // CHECK-NEXT: %43 = stablehlo.transpose %42, dims = [1, 0, 2] : (tensor<100x100x3xf32>) -> tensor<100x100x3xf32>
 // CHECK-NEXT: return %43 : tensor<100x100x3xf32>
