Skip to content

Commit 2391c5e

Browse files
committed
avoid allocating ipiv when not pivoting
1 parent b33b0e9 commit 2391c5e

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

src/lu.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,23 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
2222
return lu!(copy(A), normalize_pivot(pivot), thread; kwargs...)
2323
end
2424

25+
"""
    NotIPIV(len)

Placeholder pivot vector of length `len` that stores no elements.

It subtypes `AbstractVector{BlasInt}` so it can be passed where a pivot vector is
expected; the only state it carries is its length.
"""
struct NotIPIV <: AbstractVector{BlasInt}
    len::Int
end
26+
# The placeholder is one-dimensional; its extent is the stored `len` field.
function Base.size(ipiv::NotIPIV)
    return (ipiv.len,)
end
27+
# Identity lookup: element `i` of the placeholder is `i` itself.
# The value is converted to `BlasInt` so that it agrees with the element type
# declared by `NotIPIV <: AbstractVector{BlasInt}`, even on builds where
# `BlasInt` differs from `Int`. No bounds check is performed.
function Base.getindex(::NotIPIV, i::Int)
    return convert(BlasInt, i)
end
28+
# Taking a unit-range "view" of the placeholder yields a fresh placeholder whose
# length equals the length of the range; the range's starting offset is not kept.
function Base.view(::NotIPIV, inds::AbstractUnitRange)
    return NotIPIV(length(inds))
end
29+
# Pivoting disabled: hand back the allocation-free `NotIPIV` placeholder of
# length `minmn` instead of a real pivot buffer.
function init_pivot(::Val{false}, minmn)
    return NotIPIV(minmn)
end
30+
# Pivoting enabled: allocate an uninitialized `Vector{BlasInt}` with `minmn`
# entries. The contents are undefined until written by the caller.
function init_pivot(::Val{true}, minmn)
    return Vector{BlasInt}(undef, minmn)
end
31+
32+
2533
function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
2634
m, n = size(A)
2735
minmn = min(m, n)
28-
F = if minmn < 10 # avx introduces small performance degradation
36+
# we want the type on both branches to match. When pivot = Val(false), we construct
37+
# a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
38+
F = if pivot === Val(true) && minmn < 10 # avx introduces small performance degradation
2939
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check = check)
3040
else
31-
lu!(A, Vector{BlasInt}(undef, minmn), normalize_pivot(pivot), thread; check = check,
41+
lu!(A, init_pivot(pivot, minmn), normalize_pivot(pivot), thread; check = check,
3242
kwargs...)
3343
end
3444
return F
@@ -44,6 +54,8 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
4454
# Predicate method: any `StridedArray` answers `true`.
# NOTE(review): presumably selects the recursive code path; callers are not
# visible in this excerpt, confirm against them.
function recurse(::StridedArray)
    return true
end
4555
# Fallback predicate method: any argument without a more specific `recurse`
# method answers `false`.
function recurse(::Any)
    return false
end
4656

57+
# Generic case: wrap the pivot container in a `PtrArray`.
function _ptrarray(ipiv)
    return PtrArray(ipiv)
end
58+
# `NotIPIV` is returned unchanged rather than being wrapped in a `PtrArray`.
function _ptrarray(ipiv::NotIPIV)
    return ipiv
end
4759
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
4860
pivot = Val(true), thread = Val(true);
4961
check::Bool = true,
@@ -58,7 +70,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
5870
if T <: Union{Float32, Float64}
5971
GC.@preserve ipiv A begin info = recurse!(view(PtrArray(A), axes(A)...), pivot,
6072
m, n, mnmin,
61-
PtrArray(ipiv), info, blocksize,
73+
_ptrarray(ipiv), info, blocksize,
6274
thread) end
6375
else
6476
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
@@ -187,8 +199,10 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
187199
Pivot && apply_permutation!(P2, A21, thread)
188200

189201
info != previnfo && (info += n1)
190-
@turbo warn_check_args=false for i in 1:n2
191-
P2[i] += n1
202+
if Pivot
203+
@turbo warn_check_args=false for i in 1:n2
204+
P2[i] += n1
205+
end
192206
end
193207
return info
194208
end # inbounds
@@ -234,8 +248,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
234248
amax = absi
235249
end
236250
end
251+
ipiv[k] = kp
237252
end
238-
ipiv[k] = kp
239253
if !iszero(A[kp, k])
240254
if k != kp
241255
# Interchange

0 commit comments

Comments
 (0)