Skip to content

Commit bcb7d56

Browse files
committed
test: update LU tests
1 parent b789363 commit bcb7d56

File tree

3 files changed

+80
-80
lines changed

3 files changed

+80
-80
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ struct LUFactorizationOpLowering
522522

523523
SmallVector<int64_t> permutationShape(inputShape.begin(),
524524
inputShape.end() - 2);
525-
permutationShape.push_back(inputShape[0]);
525+
permutationShape.push_back(inputShape[inputRank - 2]);
526526
auto permutationType =
527527
RankedTensorType::get(permutationShape, rewriter.getI32Type());
528528

test/lit_tests/linalg/lu.mlir

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,32 @@ module {
99
}
1010
}
1111

12-
// CPU: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
13-
// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_wrapper_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
14-
// CPU-NEXT: %0 = llvm.mlir.constant(64 : i64) : i64
15-
// CPU-NEXT: %1 = llvm.mlir.constant(1 : i64) : i64
16-
// CPU-NEXT: %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
17-
// CPU-NEXT: %3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
18-
// CPU-NEXT: llvm.store %0, %2 : i64, !llvm.ptr
19-
// CPU-NEXT: llvm.store %0, %3 : i64, !llvm.ptr
20-
// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%2, %3, %arg0, %2, %arg1, %arg2) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
21-
// CPU-NEXT: llvm.return
12+
// CPU: func.func private @enzymexla_lapack_sgetrf_[[WRAPPER_ID:[0-9]+]](%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>) {
13+
// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i64>
14+
// CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64>
15+
// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor<i64>
16+
// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %arg0, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
17+
// CPU-NEXT: stablehlo.return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64xi64>, tensor<i64>
2218
// CPU-NEXT: }
19+
// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
2320
// CPU-NEXT: func.func @main(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi32>, tensor<64xi32>, tensor<i32>) {
2421
// CPU-NEXT: %c = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
2522
// CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor<i32>
2623
// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor<i32>
2724
// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor<64xi64>
2825
// CPU-NEXT: %c_3 = stablehlo.constant dense<0> : tensor<i32>
29-
// CPU-NEXT: %c_4 = stablehlo.constant dense<-1> : tensor<64xi64>
30-
// CPU-NEXT: %c_5 = stablehlo.constant dense<-1> : tensor<i64>
31-
// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper_[[WRAPPER_ID]] (%arg0, %c_4, %c_5) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 2, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
26+
// CPU-NEXT: %0:3 = call @enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0) : (tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
3227
// CPU-NEXT: %1 = stablehlo.subtract %0#1, %c_2 : tensor<64xi64>
33-
// CPU-NEXT: %2:2 = stablehlo.while(%iterArg = %c_3, %iterArg_6 = %c) : tensor<i32>, tensor<64xi64>
34-
// CPU-NEXT: cond {
28+
// CPU-NEXT: %2:2 = stablehlo.while(%iterArg = %c_3, %iterArg_4 = %c) : tensor<i32>, tensor<64xi64>
29+
// CPU-NEXT: cond {
3530
// CPU-NEXT: %7 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
3631
// CPU-NEXT: stablehlo.return %7 : tensor<i1>
3732
// CPU-NEXT: } do {
3833
// CPU-NEXT: %7 = stablehlo.add %iterArg, %c_0 : tensor<i32>
3934
// CPU-NEXT: %8 = stablehlo.dynamic_slice %1, %iterArg, sizes = [1] : (tensor<64xi64>, tensor<i32>) -> tensor<1xi64>
40-
// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_6, %iterArg, sizes = [1] : (tensor<64xi64>, tensor<i32>) -> tensor<1xi64>
41-
// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_6, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], start_index_map = [0]>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<64xi64>, tensor<1xi64>) -> tensor<1xi64>
42-
// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_6, %10, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor<i32>) -> tensor<64xi64>
35+
// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_4, %iterArg, sizes = [1] : (tensor<64xi64>, tensor<i32>) -> tensor<1xi64>
36+
// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_4, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], start_index_map = [0]>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<64xi64>, tensor<1xi64>) -> tensor<1xi64>
37+
// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor<i32>) -> tensor<64xi64>
4338
// CPU-NEXT: %12 = stablehlo.reshape %9 : (tensor<1xi64>) -> tensor<i64>
4439
// CPU-NEXT: %13 = "stablehlo.scatter"(%11, %8, %12) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = false}> ({
4540
// CPU-NEXT: ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
@@ -77,27 +72,27 @@ module {
7772
// TPU-NEXT: }
7873

7974
module {
75+
// CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_
8076
// CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<i32>) {
8177
func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<i32>) {
82-
// CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_wrapper_[[WRAPPER_ID:[0-9]+]]
8378
%0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
8479
return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor<i32>
8580
}
8681
}
8782

8883
module {
84+
// CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_
8985
// CPU: func.func @main(%arg0: tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>) {
9086
func.func @main(%arg0: tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>) {
91-
// CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_wrapper_[[WRAPPER_ID:[0-9]+]]
9287
%0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
9388
return %0#0, %0#1, %0#3 : tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>
9489
}
9590
}
9691

9792
module {
93+
// CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_
9894
// CPU: func.func @main(%arg0: tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>) {
9995
func.func @main(%arg0: tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>) {
100-
// CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_wrapper_[[WRAPPER_ID:[0-9]+]]
10196
%0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
10297
return %0#0, %0#1, %0#3 : tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>
10398
}

0 commit comments

Comments
 (0)