Skip to content

Commit 811d710

Browse files
committed
Revert "Revert "avoid allocating ipiv when not pivoting""
This reverts commit 9dd99e5.
1 parent 0544d01 commit 811d710

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

src/lu.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,37 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
2222
return lu!(copy(A), normalize_pivot(pivot), thread; kwargs...)
2323
end
2424

25+
"""
    NotIPIV(len)

Allocation-free stand-in for a pivot vector, used when no pivoting is requested.

Behaves as a read-only `AbstractVector{BlasInt}` of length `len` whose `i`-th entry is
`i` itself, so no storage is needed.
"""
struct NotIPIV <: AbstractVector{BlasInt}
    len::Int
end

# Minimal read-only AbstractVector interface.
Base.size(A::NotIPIV) = (A.len,)

# Entry `i` is always `i`. Convert so the returned value matches the declared
# `BlasInt` element type even on builds where `BlasInt != Int`.
Base.getindex(::NotIPIV, i::Int) = convert(BlasInt, i)

# A view over a unit range is another `NotIPIV` of the range's length; its entries
# are again numbered from 1, independent of where the range starts in the parent.
Base.view(::NotIPIV, r::AbstractUnitRange) = NotIPIV(length(r))
31+
# No pivoting: hand back a `NotIPIV` of length `minmn` instead of allocating a vector.
function init_pivot(::Val{false}, minmn)
    return NotIPIV(minmn)
end
32+
# Pivoting: allocate an uninitialized `BlasInt` vector of length `minmn`; the caller
# is expected to fill it before reading.
function init_pivot(::Val{true}, minmn)
    return Vector{BlasInt}(undef, minmn)
end
33+
34+
# `_ipiv_cols!` is a LinearAlgebra internal, so the method is only added when the
# running LinearAlgebra version actually defines it (checked when this file loads).
if isdefined(LinearAlgebra, :_ipiv_cols!)
    # For an `LU` whose pivot vector is a `NotIPIV`, return `B` untouched.
    # NOTE(review): presumably valid because `NotIPIV` encodes "no swaps" — confirm.
    LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
        B::StridedVecOrMat) = B
end
40+
# `_ipiv_rows!` is a LinearAlgebra internal, so the method is only added when the
# running LinearAlgebra version actually defines it (checked when this file loads).
if isdefined(LinearAlgebra, :_ipiv_rows!)
    # For an `LU` whose pivot vector is a `NotIPIV`, return `B` untouched.
    # NOTE(review): presumably valid because `NotIPIV` encodes "no swaps" — confirm.
    LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
        B::StridedVecOrMat) = B
end
46+
2547
function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
2648
m, n = size(A)
2749
minmn = min(m, n)
28-
F = if minmn < 10 # avx introduces small performance degradation
50+
# we want the type on both branches to match. When pivot = Val(false), we construct
51+
# a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
52+
F = if pivot === Val(true) && minmn < 10 # avx introduces small performance degradation
2953
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check = check)
3054
else
31-
lu!(A, Vector{BlasInt}(undef, minmn), normalize_pivot(pivot), thread; check = check,
55+
lu!(A, init_pivot(pivot, minmn), normalize_pivot(pivot), thread; check = check,
3256
kwargs...)
3357
end
3458
return F
@@ -44,6 +68,8 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
4468
# Predicate method: strided arrays answer `true`.
# NOTE(review): presumably selects the recursive kernel; the call site is not visible here.
function recurse(::StridedArray)
    return true
end
4569
# Catch-all method: any argument not matched by a more specific method answers `false`.
function recurse(_)
    return false
end
4670

71+
# Generic fallback: wrap the pivot vector in a `PtrArray`.
function _ptrarray(ipiv)
    return PtrArray(ipiv)
end
72+
# A `NotIPIV` is passed through as-is rather than being wrapped in a `PtrArray`.
function _ptrarray(ipiv::NotIPIV)
    return ipiv
end
4773
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
4874
pivot = Val(true), thread = Val(true);
4975
check::Bool = true,
@@ -58,7 +84,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
5884
if T <: Union{Float32, Float64}
5985
GC.@preserve ipiv A begin info = recurse!(view(PtrArray(A), axes(A)...), pivot,
6086
m, n, mnmin,
61-
PtrArray(ipiv), info, blocksize,
87+
_ptrarray(ipiv), info, blocksize,
6288
thread) end
6389
else
6490
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
@@ -90,8 +116,8 @@ end
90116
# [AL AR]
91117
AL = @view A[:, 1:m]
92118
AR = @view A[:, (m + 1):n]
93-
apply_permutation!(ipiv, AR, Val(Thread))
94-
ldiv!(_unit_lower_triangular(AL), AR, Val(Thread))
119+
apply_permutation!(ipiv, AR, Val{Thread}())
120+
ldiv!(_unit_lower_triangular(AL), AR, Val{Thread}())
95121
end
96122
info
97123
end
@@ -187,8 +213,10 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
187213
Pivot && apply_permutation!(P2, A21, thread)
188214

189215
info != previnfo && (info += n1)
190-
@turbo warn_check_args=false for i in 1:n2
191-
P2[i] += n1
216+
if Pivot
217+
@turbo warn_check_args=false for i in 1:n2
218+
P2[i] += n1
219+
end
192220
end
193221
return info
194222
end # inbounds
@@ -234,8 +262,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
234262
amax = absi
235263
end
236264
end
265+
ipiv[k] = kp
237266
end
238-
ipiv[k] = kp
239267
if !iszero(A[kp, k])
240268
if k != kp
241269
# Interchange

0 commit comments

Comments
 (0)