@@ -40,43 +40,35 @@ module {
4040// CPU-NEXT: %6 = stablehlo.convert %0#2 : (tensor<4x3xi64>) -> tensor<4x3xi32>
4141// CPU-NEXT: return %0#0, %4, %5, %6 : tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32>
4242// CPU-NEXT: }
43- // CPU-NEXT: func.func private @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) {
44- // CPU-NEXT: %c = stablehlo.constant dense<4> : tensor<i64>
45- // CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor<i64>
46- // CPU-NEXT: %c_1 = stablehlo.constant dense<12> : tensor<i64>
43+ // CPU-NEXT: func.func private @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) {
44+ // CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i64>
45+ // CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64>
46+ // CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor<i64>
47+ // CPU-NEXT: %c_2 = stablehlo.constant dense<4> : tensor<i64>
48+ // CPU-NEXT: %c_3 = stablehlo.constant dense<1> : tensor<i64>
49+ // CPU-NEXT: %c_4 = stablehlo.constant dense<12> : tensor<i64>
4750// CPU-NEXT: %cst = arith.constant dense<0> : tensor<4x3xi64>
48- // CPU-NEXT: %cst_2 = arith.constant dense<0> : tensor<4x3x64xi64>
49- // CPU-NEXT: %cst_3 = arith.constant dense<0.000000e+00> : tensor<4x3x64x64xf32>
50- // CPU-NEXT: %c_4 = stablehlo.constant dense<0> : tensor<i64>
51- // CPU-NEXT: %c_5 = stablehlo.constant dense<-1> : tensor<4x3xi64>
52- // CPU-NEXT: %c_6 = stablehlo.constant dense<-1> : tensor<4x3x64xi64>
53- // CPU-NEXT: %c_7 = stablehlo.constant dense<64> : tensor<4x3xi64>
54- // CPU-NEXT: %0:4 = stablehlo.while(%iterArg = %c_4, %iterArg_8 = %cst_3, %iterArg_9 = %cst_2, %iterArg_10 = %cst) : tensor<i64>, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
51+ // CPU-NEXT: %cst_5 = arith.constant dense<0> : tensor<4x3x64xi64>
52+ // CPU-NEXT: %cst_6 = arith.constant dense<0.000000e+00> : tensor<4x3x64x64xf32>
53+ // CPU-NEXT: %c_7 = stablehlo.constant dense<0> : tensor<i64>
54+ // CPU-NEXT: %0:4 = stablehlo.while(%iterArg = %c_7, %iterArg_8 = %cst_6, %iterArg_9 = %cst_5, %iterArg_10 = %cst) : tensor<i64>, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
5555// CPU-NEXT: cond {
56- // CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
56+ // CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_4 : (tensor<i64>, tensor<i64>) -> tensor<i1>
5757// CPU-NEXT: stablehlo.return %1 : tensor<i1>
5858// CPU-NEXT: } do {
59- // CPU-NEXT: %1 = stablehlo.add %iterArg, %c_0 : tensor<i64>
60- // CPU-NEXT: %2 = stablehlo.remainder %iterArg, %c : tensor<i64>
61- // CPU-NEXT: %3 = stablehlo.divide %iterArg, %c : tensor<i64>
62- // CPU-NEXT: %4 = stablehlo.dynamic_slice %c_7, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor<i64>, tensor<i64>) -> tensor<1x1xi64>
63- // CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1xi64>) -> tensor<i64>
64- // CPU-NEXT: %6 = stablehlo.dynamic_slice %arg0, %2, %3, %c_4, %c_4, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x64x64xf32>
65- // CPU-NEXT: %7 = stablehlo.reshape %6 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32>
66- // CPU-NEXT: %8 = stablehlo.dynamic_slice %c_7, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor<i64>, tensor<i64>) -> tensor<1x1xi64>
67- // CPU-NEXT: %9 = stablehlo.reshape %8 : (tensor<1x1xi64>) -> tensor<i64>
68- // CPU-NEXT: %10 = stablehlo.dynamic_slice %c_6, %2, %3, %c_4, sizes = [1, 1, 64] : (tensor<4x3x64xi64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x64xi64>
69- // CPU-NEXT: %11 = stablehlo.reshape %10 : (tensor<1x1x64xi64>) -> tensor<64xi64>
70- // CPU-NEXT: %12 = stablehlo.dynamic_slice %c_5, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor<i64>, tensor<i64>) -> tensor<1x1xi64>
71- // CPU-NEXT: %13 = stablehlo.reshape %12 : (tensor<1x1xi64>) -> tensor<i64>
72- // CPU-NEXT: %14:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%9, %5, %7, %9, %11, %13) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
73- // CPU-NEXT: %15 = stablehlo.reshape %14#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32>
74- // CPU-NEXT: %16 = stablehlo.dynamic_update_slice %iterArg_8, %15, %2, %3, %c_4, %c_4 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64x64xf32>
75- // CPU-NEXT: %17 = stablehlo.reshape %14#1 : (tensor<64xi64>) -> tensor<1x1x64xi64>
76- // CPU-NEXT: %18 = stablehlo.dynamic_update_slice %iterArg_9, %17, %2, %3, %c_4 : (tensor<4x3x64xi64>, tensor<1x1x64xi64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64xi64>
77- // CPU-NEXT: %19 = stablehlo.reshape %14#2 : (tensor<i64>) -> tensor<1x1xi64>
78- // CPU-NEXT: %20 = stablehlo.dynamic_update_slice %iterArg_10, %19, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor<i64>, tensor<i64>) -> tensor<4x3xi64>
79- // CPU-NEXT: stablehlo.return %1, %16, %18, %20 : tensor<i64>, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
59+ // CPU-NEXT: %1 = stablehlo.add %iterArg, %c_3 : tensor<i64>
60+ // CPU-NEXT: %2 = stablehlo.remainder %iterArg, %c_2 : tensor<i64>
61+ // CPU-NEXT: %3 = stablehlo.divide %iterArg, %c_2 : tensor<i64>
62+ // CPU-NEXT: %4 = stablehlo.dynamic_slice %arg0, %2, %3, %c_7, %c_7, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x64x64xf32>
63+ // CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32>
64+ // CPU-NEXT: %6:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %5, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
65+ // CPU-NEXT: %7 = stablehlo.reshape %6#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32>
66+ // CPU-NEXT: %8 = stablehlo.dynamic_update_slice %iterArg_8, %7, %2, %3, %c_7, %c_7 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64x64xf32>
67+ // CPU-NEXT: %9 = stablehlo.reshape %6#1 : (tensor<64xi64>) -> tensor<1x1x64xi64>
68+ // CPU-NEXT: %10 = stablehlo.dynamic_update_slice %iterArg_9, %9, %2, %3, %c_7 : (tensor<4x3x64xi64>, tensor<1x1x64xi64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64xi64>
69+ // CPU-NEXT: %11 = stablehlo.reshape %6#2 : (tensor<i64>) -> tensor<1x1xi64>
70+ // CPU-NEXT: %12 = stablehlo.dynamic_update_slice %iterArg_10, %11, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor<i64>, tensor<i64>) -> tensor<4x3xi64>
71+ // CPU-NEXT: stablehlo.return %1, %8, %10, %12 : tensor<i64>, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
8072// CPU-NEXT: }
8173// CPU-NEXT: stablehlo.return %0#1, %0#2, %0#3 : tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
8274// CPU-NEXT: }
0 commit comments