Skip to content

Commit cd25294

Browse files
authored
fix: input permutation to lu (#1302)
* fix: input permutation to lu * test: add 3d test
1 parent f255ee6 commit cd25294

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/stdlibs/LinearAlgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ function _lu_overload(
594594
) where {T,N}
595595
# TODO: don't ignore the check and allowsingular flags
596596
# Batching here is in the last dimensions. `Ops.lu` expects the last dimensions
597-
permdims = vcat(Int64[N - 1, N], collect(Int64, 1:(N - 2)))
597+
permdims = vcat(collect(Int64, 3:N), 1, 2)
598598
A = Ops.transpose(materialize_traced_array(A), permdims)
599599
factors, ipiv, perm, info = Reactant.Ops.lu(A)
600600

test/integration/linear_algebra.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,4 +370,13 @@ end
370370
@test @jit(solve_with_lu(A_ra, B_ra)) solve_with_lu_batched(A, B)
371371
end
372372
end
373+
374+
@testset "Input Permutation" begin
375+
A = rand(Float32, 10, 10, 32)
376+
B = rand(Float32, 10, 32)
377+
A_ra = Reactant.to_rarray(A)
378+
B_ra = Reactant.to_rarray(B)
379+
380+
@test @jit(solve_with_lu(A_ra, B_ra)) solve_with_lu_batched(A, B)
381+
end
373382
end

0 commit comments

Comments
 (0)