Skip to content

Commit 9dd99e5

Browse files
authored
Revert "avoid allocating ipiv when not pivoting"
1 parent 973cf28 commit 9dd99e5

File tree

1 file changed

+8
-36
lines changed

1 file changed

+8
-36
lines changed

src/lu.jl

Lines changed: 8 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -22,37 +22,13 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
2222
return lu!(copy(A), normalize_pivot(pivot), thread; kwargs...)
2323
end
2424

25-
struct NotIPIV <: AbstractVector{BlasInt}
26-
len::Int
27-
end
28-
Base.size(A::NotIPIV) = (A.len,)
29-
Base.getindex(::NotIPIV, i::Int) = i
30-
Base.view(::NotIPIV, r::AbstractUnitRange) = NotIPIV(length(r))
31-
init_pivot(::Val{false}, minmn) = NotIPIV(minmn)
32-
init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
33-
34-
if isdefined(LinearAlgebra, :_ipiv_cols!)
35-
function LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
36-
B::StridedVecOrMat)
37-
return B
38-
end
39-
end
40-
if isdefined(LinearAlgebra, :_ipiv_rows!)
41-
function LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
42-
B::StridedVecOrMat)
43-
return B
44-
end
45-
end
46-
4725
function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
4826
m, n = size(A)
4927
minmn = min(m, n)
50-
# we want the type on both branches to match. When pivot = Val(false), we construct
51-
# a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
52-
F = if pivot === Val(true) && minmn < 10 # avx introduces small performance degradation
28+
F = if minmn < 10 # avx introduces small performance degradation
5329
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check = check)
5430
else
55-
lu!(A, init_pivot(pivot, minmn), normalize_pivot(pivot), thread; check = check,
31+
lu!(A, Vector{BlasInt}(undef, minmn), normalize_pivot(pivot), thread; check = check,
5632
kwargs...)
5733
end
5834
return F
@@ -68,8 +44,6 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
6844
recurse(::StridedArray) = true
6945
recurse(_) = false
7046

71-
_ptrarray(ipiv) = PtrArray(ipiv)
72-
_ptrarray(ipiv::NotIPIV) = ipiv
7347
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
7448
pivot = Val(true), thread = Val(true);
7549
check::Bool = true,
@@ -84,7 +58,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
8458
if T <: Union{Float32, Float64}
8559
GC.@preserve ipiv A begin info = recurse!(view(PtrArray(A), axes(A)...), pivot,
8660
m, n, mnmin,
87-
_ptrarray(ipiv), info, blocksize,
61+
PtrArray(ipiv), info, blocksize,
8862
thread) end
8963
else
9064
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
@@ -116,8 +90,8 @@ end
11690
# [AL AR]
11791
AL = @view A[:, 1:m]
11892
AR = @view A[:, (m + 1):n]
119-
apply_permutation!(ipiv, AR, Val{Thread}())
120-
ldiv!(_unit_lower_triangular(AL), AR, Val{Thread}())
93+
apply_permutation!(ipiv, AR, Val(Thread))
94+
ldiv!(_unit_lower_triangular(AL), AR, Val(Thread))
12195
end
12296
info
12397
end
@@ -213,10 +187,8 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
213187
Pivot && apply_permutation!(P2, A21, thread)
214188

215189
info != previnfo && (info += n1)
216-
if Pivot
217-
@turbo warn_check_args=false for i in 1:n2
218-
P2[i] += n1
219-
end
190+
@turbo warn_check_args=false for i in 1:n2
191+
P2[i] += n1
220192
end
221193
return info
222194
end # inbounds
@@ -262,8 +234,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
262234
amax = absi
263235
end
264236
end
265-
ipiv[k] = kp
266237
end
238+
ipiv[k] = kp
267239
if !iszero(A[kp, k])
268240
if k != kp
269241
# Interchange

0 commit comments

Comments
 (0)