Skip to content

Commit c1ef162

Browse files
authored
Add ldiv! for LU decomposition (#1532)
1 parent 68db0c7 commit c1ef162

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

lib/cusolver/linalg.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,16 @@ LinearAlgebra.ipiv2perm(v::CuVector{T}, maxi::Integer) where T =
338338

339339
end
340340

341+
function LinearAlgebra.ldiv!(F::LU{T,<:StridedCuMatrix{T}}, B::CuVecOrMat{T}) where {T}
342+
return getrs!('N', F.factors, F.ipiv, B)
343+
end
344+
345+
# LinearAlgebra.jl defines a comparable method with these restrictions on T, so we need
346+
# to define with the same type parameters to resolve method ambiguity for these cases
347+
function LinearAlgebra.ldiv!(F::LU{T,<:StridedCuMatrix{T}}, B::CuVecOrMat{T}) where {T <: Union{Float32, Float64, ComplexF32, ComplexF64}}
348+
return getrs!('N', F.factors, F.ipiv, B)
349+
end
350+
341351
## cholesky
342352

343353
if VERSION >= v"1.8-"

test/cusolver/dense.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,15 @@ k = 1
579579

580580
@test_throws LinearAlgebra.SingularException lu(CUDA.zeros(elty,n,n))
581581
end
582+
@testset "lu ldiv! elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
583+
A = rand(elty, m, m)
584+
B = rand(elty, m, m)
585+
A_d = CuArray(A)
586+
B_d = CuArray(B)
587+
lu_cpu = lu(A)
588+
lu_gpu = lu(A_d)
589+
@test ldiv!(lu_cpu, B) collect(ldiv!(lu_gpu, B_d))
590+
end
582591
end
583592
end
584593

0 commit comments

Comments
 (0)