Skip to content

Commit b642d9d

Browse files
committed
Make lu more generic
1 parent ed26db8 commit b642d9d

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

src/lu.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
using LinearAlgebra: BlasInt, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular
1+
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular
22

3-
lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true);
4-
check::Bool = true, blocksize::Integer = 16) = lu!(copy(A), pivot;
5-
check = check, blocksize = blocksize)
3+
function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true);
4+
check::Bool = true, blocksize::Integer = 16)
5+
lu!(copy(A), pivot; check = check, blocksize = blocksize)
6+
end
67

7-
lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true);
8-
check::Bool = true, blocksize::Integer = 16) = lu!(copy(A), Vector{BlasInt}(undef, min(size(A)...)), pivot;
9-
check = check, blocksize = blocksize)
8+
function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true);
9+
check::Bool = true, blocksize::Integer = 16)
10+
lu!(A, Vector{BlasInt}(undef, min(size(A)...)), pivot;
11+
check = check, blocksize = blocksize)
12+
end
1013

1114
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
1215
pivot::Union{Val{false}, Val{true}} = Val(true);
1316
check::Bool=true, blocksize::Integer=16) where T
1417
info = Ref(zero(BlasInt))
1518
m, n = size(A)
1619
mnmin = min(m, n)
17-
if isconcretetype(T)
20+
if T <: BlasFloat && A isa StridedArray
1821
reckernel!(A, pivot, m, mnmin, ipiv, info, blocksize)
1922
if m < n # fat matrix
2023
# [AL AR]
@@ -115,15 +118,15 @@ end
115118
Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
116119
License is MIT: https://julialang.org/license
117120
=#
118-
function _generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot}, ipiv, info) where {Pivot,T}
121+
function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
119122
m, n = size(A)
120123
minmn = length(ipiv)
121124
@inbounds begin
122125
for k = 1:minmn
123126
# find index max
124127
kp = k
125128
if Pivot
126-
amax = abs(zero(T))
129+
amax = abs(zero(eltype(A)))
127130
for i = k:m
128131
absi = abs(A[i,k])
129132
if absi > amax

0 commit comments

Comments
 (0)