Skip to content

Commit ca26d78

Browse files
committed
Add CUSTOMIZABLE_PIVOT compat
1 parent 811d710 commit ca26d78

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/lu.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,30 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
2222
return lu!(copy(A), normalize_pivot(pivot), thread; kwargs...)
2323
end
2424

25+
const CUSTOMIZABLE_PIVOT = VERSION >= v"1.8.0-DEV.1507"
26+
2527
struct NotIPIV <: AbstractVector{BlasInt}
2628
len::Int
2729
end
2830
Base.size(A::NotIPIV) = (A.len,)
2931
Base.getindex(::NotIPIV, i::Int) = i
3032
Base.view(::NotIPIV, r::AbstractUnitRange) = NotIPIV(length(r))
31-
init_pivot(::Val{false}, minmn) = NotIPIV(minmn)
33+
function init_pivot(::Val{false}, minmn)
34+
@static if CUSTOMIZABLE_PIVOT
35+
NotIPIV(minmn)
36+
else
37+
init_pivot(Val(true), minmn)
38+
end
39+
end
3240
init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
3341

34-
if isdefined(LinearAlgebra, :_ipiv_cols!)
42+
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!)
3543
function LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
3644
B::StridedVecOrMat)
3745
return B
3846
end
3947
end
40-
if isdefined(LinearAlgebra, :_ipiv_rows!)
48+
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!)
4149
function LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
4250
B::StridedVecOrMat)
4351
return B
@@ -47,12 +55,13 @@ end
4755
function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
4856
m, n = size(A)
4957
minmn = min(m, n)
58+
npivot = normalize_pivot(pivot)
5059
# we want the type on both branches to match. When pivot = Val(false), we construct
51-
# a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
60+
# a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
5261
F = if pivot === Val(true) && minmn < 10 # avx introduces small performance degradation
5362
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check = check)
5463
else
55-
lu!(A, init_pivot(pivot, minmn), normalize_pivot(pivot), thread; check = check,
64+
lu!(A, init_pivot(npivot, minmn), npivot, thread; check = check,
5665
kwargs...)
5766
end
5867
return F
@@ -80,6 +89,9 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
8089
info = zero(BlasInt)
8190
m, n = size(A)
8291
mnmin = min(m, n)
92+
if pivot === Val(false) && !CUSTOMIZABLE_PIVOT
93+
copyto!(ipiv, 1:mnmin)
94+
end
8395
if recurse(A) && mnmin > threshold
8496
if T <: Union{Float32, Float64}
8597
GC.@preserve ipiv A begin info = recurse!(view(PtrArray(A), axes(A)...), pivot,
@@ -116,7 +128,7 @@ end
116128
# [AL AR]
117129
AL = @view A[:, 1:m]
118130
AR = @view A[:, (m + 1):n]
119-
apply_permutation!(ipiv, AR, Val{Thread}())
131+
Pivot && apply_permutation!(ipiv, AR, Val{Thread}())
120132
ldiv!(_unit_lower_triangular(AL), AR, Val{Thread}())
121133
end
122134
info

0 commit comments

Comments
 (0)