
Commit af0d25d

feat: dynamic slice simplify
1 parent bcb7d56 commit af0d25d
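
This commit updates the CHECK lines of the batched LU lit test for a dynamic-slice simplification. Judging from the new expected IR, a stablehlo.dynamic_slice whose operand is a splat constant now folds directly to a splat constant of the sliced shape, so the per-iteration slice-and-reshape chains that only produced constant scalars for the LAPACK wrapper (apparently the matrix dimension, pivot buffer, and info operands) disappear, and those constants are hoisted above the stablehlo.while loop. A minimal sketch of the rewrite the test appears to exercise, in the same StableHLO dialect (the value names %splat, %i, %j, %s and %v are illustrative, not taken from the test):

  // Before: slice a splat constant at dynamic offsets %i, %j, then drop the unit dims.
  %splat = stablehlo.constant dense<-1> : tensor<4x3xi64>
  %s = stablehlo.dynamic_slice %splat, %i, %j, sizes = [1, 1] : (tensor<4x3xi64>, tensor<i64>, tensor<i64>) -> tensor<1x1xi64>
  %v = stablehlo.reshape %s : (tensor<1x1xi64>) -> tensor<i64>

  // After: every element of the splat is identical, so the slice (and the reshape of the
  // resulting 1x1 tensor) folds to a constant of the result shape, independent of the
  // in-bounds start indices.
  %v = stablehlo.constant dense<-1> : tensor<i64>

The same folding would explain why the old %c_5, %c_6, and %c_7 splat tensors and their dynamic_slice users are gone from the loop body in the diff below.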

File tree

1 file changed: +25 -33 lines changed


test/lit_tests/linalg/lu_batched.mlir

Lines changed: 25 additions & 33 deletions
@@ -40,43 +40,35 @@ module {
 // CPU-NEXT: %6 = stablehlo.convert %0#2 : (tensor<4x3xi64>) -> tensor<4x3xi32>
 // CPU-NEXT: return %0#0, %4, %5, %6 : tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32>
 // CPU-NEXT: }
-// CPU-NEXT: func.func private @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) {
-// CPU-NEXT: %c = stablehlo.constant dense<4> : tensor<i64>
-// CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor<i64>
-// CPU-NEXT: %c_1 = stablehlo.constant dense<12> : tensor<i64>
+// CPU-NEXT: func.func private @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) {
+// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i64>
+// CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64>
+// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor<i64>
+// CPU-NEXT: %c_2 = stablehlo.constant dense<4> : tensor<i64>
+// CPU-NEXT: %c_3 = stablehlo.constant dense<1> : tensor<i64>
+// CPU-NEXT: %c_4 = stablehlo.constant dense<12> : tensor<i64>
 // CPU-NEXT: %cst = arith.constant dense<0> : tensor<4x3xi64>
-// CPU-NEXT: %cst_2 = arith.constant dense<0> : tensor<4x3x64xi64>
-// CPU-NEXT: %cst_3 = arith.constant dense<0.000000e+00> : tensor<4x3x64x64xf32>
-// CPU-NEXT: %c_4 = stablehlo.constant dense<0> : tensor<i64>
-// CPU-NEXT: %c_5 = stablehlo.constant dense<-1> : tensor<4x3xi64>
-// CPU-NEXT: %c_6 = stablehlo.constant dense<-1> : tensor<4x3x64xi64>
-// CPU-NEXT: %c_7 = stablehlo.constant dense<64> : tensor<4x3xi64>
-// CPU-NEXT: %0:4 = stablehlo.while(%iterArg = %c_4, %iterArg_8 = %cst_3, %iterArg_9 = %cst_2, %iterArg_10 = %cst) : tensor<i64>, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
+// CPU-NEXT: %cst_5 = arith.constant dense<0> : tensor<4x3x64xi64>
+// CPU-NEXT: %cst_6 = arith.constant dense<0.000000e+00> : tensor<4x3x64x64xf32>
+// CPU-NEXT: %c_7 = stablehlo.constant dense<0> : tensor<i64>
+// CPU-NEXT: %0:4 = stablehlo.while(%iterArg = %c_7, %iterArg_8 = %cst_6, %iterArg_9 = %cst_5, %iterArg_10 = %cst) : tensor<i64>, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
 // CPU-NEXT: cond {
-// CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+// CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_4 : (tensor<i64>, tensor<i64>) -> tensor<i1>
 // CPU-NEXT: stablehlo.return %1 : tensor<i1>
 // CPU-NEXT: } do {
-// CPU-NEXT: %1 = stablehlo.add %iterArg, %c_0 : tensor<i64>
-// CPU-NEXT: %2 = stablehlo.remainder %iterArg, %c : tensor<i64>
-// CPU-NEXT: %3 = stablehlo.divide %iterArg, %c : tensor<i64>
-// CPU-NEXT: %4 = stablehlo.dynamic_slice %c_7, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor<i64>, tensor<i64>) -> tensor<1x1xi64>
-// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1xi64>) -> tensor<i64>
-// CPU-NEXT: %6 = stablehlo.dynamic_slice %arg0, %2, %3, %c_4, %c_4, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x64x64xf32>
-// CPU-NEXT: %7 = stablehlo.reshape %6 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32>
-// CPU-NEXT: %8 = stablehlo.dynamic_slice %c_7, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor<i64>, tensor<i64>) -> tensor<1x1xi64>
-// CPU-NEXT: %9 = stablehlo.reshape %8 : (tensor<1x1xi64>) -> tensor<i64>
-// CPU-NEXT: %10 = stablehlo.dynamic_slice %c_6, %2, %3, %c_4, sizes = [1, 1, 64] : (tensor<4x3x64xi64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x64xi64>
-// CPU-NEXT: %11 = stablehlo.reshape %10 : (tensor<1x1x64xi64>) -> tensor<64xi64>
-// CPU-NEXT: %12 = stablehlo.dynamic_slice %c_5, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor<i64>, tensor<i64>) -> tensor<1x1xi64>
-// CPU-NEXT: %13 = stablehlo.reshape %12 : (tensor<1x1xi64>) -> tensor<i64>
-// CPU-NEXT: %14:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%9, %5, %7, %9, %11, %13) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
-// CPU-NEXT: %15 = stablehlo.reshape %14#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32>
-// CPU-NEXT: %16 = stablehlo.dynamic_update_slice %iterArg_8, %15, %2, %3, %c_4, %c_4 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64x64xf32>
-// CPU-NEXT: %17 = stablehlo.reshape %14#1 : (tensor<64xi64>) -> tensor<1x1x64xi64>
-// CPU-NEXT: %18 = stablehlo.dynamic_update_slice %iterArg_9, %17, %2, %3, %c_4 : (tensor<4x3x64xi64>, tensor<1x1x64xi64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64xi64>
-// CPU-NEXT: %19 = stablehlo.reshape %14#2 : (tensor<i64>) -> tensor<1x1xi64>
-// CPU-NEXT: %20 = stablehlo.dynamic_update_slice %iterArg_10, %19, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor<i64>, tensor<i64>) -> tensor<4x3xi64>
-// CPU-NEXT: stablehlo.return %1, %16, %18, %20 : tensor<i64>, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
+// CPU-NEXT: %1 = stablehlo.add %iterArg, %c_3 : tensor<i64>
+// CPU-NEXT: %2 = stablehlo.remainder %iterArg, %c_2 : tensor<i64>
+// CPU-NEXT: %3 = stablehlo.divide %iterArg, %c_2 : tensor<i64>
+// CPU-NEXT: %4 = stablehlo.dynamic_slice %arg0, %2, %3, %c_7, %c_7, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x64x64xf32>
+// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32>
+// CPU-NEXT: %6:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %5, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
+// CPU-NEXT: %7 = stablehlo.reshape %6#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32>
+// CPU-NEXT: %8 = stablehlo.dynamic_update_slice %iterArg_8, %7, %2, %3, %c_7, %c_7 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64x64xf32>
+// CPU-NEXT: %9 = stablehlo.reshape %6#1 : (tensor<64xi64>) -> tensor<1x1x64xi64>
+// CPU-NEXT: %10 = stablehlo.dynamic_update_slice %iterArg_9, %9, %2, %3, %c_7 : (tensor<4x3x64xi64>, tensor<1x1x64xi64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64xi64>
+// CPU-NEXT: %11 = stablehlo.reshape %6#2 : (tensor<i64>) -> tensor<1x1xi64>
+// CPU-NEXT: %12 = stablehlo.dynamic_update_slice %iterArg_10, %11, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor<i64>, tensor<i64>) -> tensor<4x3xi64>
+// CPU-NEXT: stablehlo.return %1, %8, %10, %12 : tensor<i64>, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
 // CPU-NEXT: }
 // CPU-NEXT: stablehlo.return %0#1, %0#2, %0#3 : tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
 // CPU-NEXT: }
