|
1 |
| -using LinearAlgebra: BlasInt, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular |
| 1 | +using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular |
2 | 2 |
|
3 |
| -lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); |
4 |
| - check::Bool = true, blocksize::Integer = 16) = lu!(copy(A), pivot; |
5 |
| - check = check, blocksize = blocksize) |
| 3 | +function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); |
| 4 | + check::Bool = true, blocksize::Integer = 16) |
| 5 | + lu!(copy(A), pivot; check = check, blocksize = blocksize) |
| 6 | +end |
6 | 7 |
|
7 |
| -lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true); |
8 |
| - check::Bool = true, blocksize::Integer = 16) = lu!(copy(A), Vector{BlasInt}(undef, min(size(A)...)), pivot; |
9 |
| - check = check, blocksize = blocksize) |
| 8 | +function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true); |
| 9 | + check::Bool = true, blocksize::Integer = 16) |
| 10 | + lu!(A, Vector{BlasInt}(undef, min(size(A)...)), pivot; |
| 11 | + check = check, blocksize = blocksize) |
| 12 | +end |
10 | 13 |
|
11 | 14 | function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
|
12 | 15 | pivot::Union{Val{false}, Val{true}} = Val(true);
|
13 | 16 | check::Bool=true, blocksize::Integer=16) where T
|
14 | 17 | info = Ref(zero(BlasInt))
|
15 | 18 | m, n = size(A)
|
16 | 19 | mnmin = min(m, n)
|
17 |
| - if isconcretetype(T) |
| 20 | + if T <: BlasFloat && A isa StridedArray |
18 | 21 | reckernel!(A, pivot, m, mnmin, ipiv, info, blocksize)
|
19 | 22 | if m < n # fat matrix
|
20 | 23 | # [AL AR]
|
@@ -115,15 +118,15 @@ end
|
115 | 118 | Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
|
116 | 119 | License is MIT: https://julialang.org/license
|
117 | 120 | =#
|
118 |
| -function _generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot}, ipiv, info) where {Pivot,T} |
| 121 | +function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot |
119 | 122 | m, n = size(A)
|
120 | 123 | minmn = length(ipiv)
|
121 | 124 | @inbounds begin
|
122 | 125 | for k = 1:minmn
|
123 | 126 | # find index max
|
124 | 127 | kp = k
|
125 | 128 | if Pivot
|
126 |
| - amax = abs(zero(T)) |
| 129 | + amax = abs(zero(eltype(A))) |
127 | 130 | for i = k:m
|
128 | 131 | absi = abs(A[i,k])
|
129 | 132 | if absi > amax
|
|
0 commit comments