@@ -22,13 +22,46 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
22
22
return lu! (copy (A), normalize_pivot (pivot), thread; kwargs... )
23
23
end
24
24
25
# `true` when running on Julia 1.8.0-DEV.1507 or newer. This flag gates every `NotIPIV`
# code path in this file; on older versions an ordinary pivot vector is used instead.
const CUSTOMIZABLE_PIVOT = VERSION >= v"1.8.0-DEV.1507"
26
+
27
"""
    NotIPIV(len) <: AbstractVector{BlasInt}

Lazy stand-in for the pivot vector of an unpivoted factorization. Only the length `len`
is stored; indexing at `i` returns `i`.
"""
struct NotIPIV <: AbstractVector{BlasInt}
    len::Int
end

# A one-dimensional array whose single axis has `len` entries.
function Base.size(A::NotIPIV)
    return (A.len,)
end

# Entry `i` is `i` itself; the index is returned as given, with no bounds check.
function Base.getindex(::NotIPIV, i::Int)
    return i
end

# A view keeps only the length of `r`: the result indexes from 1 again rather than
# from `first(r)`.
function Base.view(::NotIPIV, r::AbstractUnitRange)
    return NotIPIV(length(r))
end
33
# Pivot storage for an unpivoted (`Val(false)`) factorization whose smaller dimension is
# `minmn`. When `CUSTOMIZABLE_PIVOT` holds this is a `NotIPIV`; otherwise it defers to
# the `Val(true)` method. `@static` picks the branch at macro-expansion time, so only
# one of the two calls ends up in the method body.
function init_pivot(::Val{false}, minmn)
    return @static CUSTOMIZABLE_PIVOT ? NotIPIV(minmn) : init_pivot(Val(true), minmn)
end
40
# Pivot storage for a pivoted (`Val(true)`) factorization: a `Vector{BlasInt}` with
# `minmn` entries. The entries are uninitialized (`undef`) until something writes them.
function init_pivot(::Val{true}, minmn)
    return Vector{BlasInt}(undef, minmn)
end
41
+
42
# Add a `LinearAlgebra._ipiv_cols!` method only when `CUSTOMIZABLE_PIVOT` holds and the
# running `LinearAlgebra` actually defines that function. For an `LU` whose pivot vector
# is a `NotIPIV` (which indexes as the identity) the method returns `B` unchanged.
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!)
    LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
                              B::StridedVecOrMat) = B
end
48
# Add a `LinearAlgebra._ipiv_rows!` method only when `CUSTOMIZABLE_PIVOT` holds and the
# running `LinearAlgebra` actually defines that function. For an `LU` whose pivot vector
# is a `NotIPIV` (which indexes as the identity) the method returns `B` unchanged.
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!)
    LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
                              B::StridedVecOrMat) = B
end
54
+
25
55
function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
26
56
m, n = size (A)
27
57
minmn = min (m, n)
28
- F = if minmn < 10 # avx introduces small performance degradation
58
+ npivot = normalize_pivot (pivot)
59
+ # we want the type on both branches to match. When pivot = Val(false), we construct
60
+ # a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
61
+ F = if pivot === Val (true ) && minmn < 10 # avx introduces small performance degradation
29
62
LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
30
63
else
31
- lu! (A, Vector {BlasInt} (undef , minmn), normalize_pivot (pivot) , thread; check = check,
64
+ lu! (A, init_pivot (npivot , minmn), npivot , thread; check = check,
32
65
kwargs... )
33
66
end
34
67
return F
@@ -44,6 +77,8 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
44
77
recurse (:: StridedArray ) = true
45
78
recurse (_) = false
46
79
80
# Wrap a pivot vector in a `PtrArray` before it is handed to the recursive kernel.
function _ptrarray(ipiv)
    return PtrArray(ipiv)
end

# A `NotIPIV` is passed through as-is instead of being wrapped.
function _ptrarray(ipiv::NotIPIV)
    return ipiv
end
47
82
function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
48
83
pivot = Val (true ), thread = Val (true );
49
84
check:: Bool = true ,
@@ -54,11 +89,14 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
54
89
info = zero (BlasInt)
55
90
m, n = size (A)
56
91
mnmin = min (m, n)
92
+ if pivot === Val (false ) && ! CUSTOMIZABLE_PIVOT
93
+ copyto! (ipiv, 1 : mnmin)
94
+ end
57
95
if recurse (A) && mnmin > threshold
58
96
if T <: Union{Float32, Float64}
59
97
GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
60
98
m, n, mnmin,
61
- PtrArray (ipiv), info, blocksize,
99
+ _ptrarray (ipiv), info, blocksize,
62
100
thread) end
63
101
else
64
102
info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
90
128
# [AL AR]
91
129
AL = @view A[:, 1 : m]
92
130
AR = @view A[:, (m + 1 ): n]
93
- apply_permutation! (ipiv, AR, Val ( Thread))
94
- ldiv! (_unit_lower_triangular (AL), AR, Val ( Thread))
131
+ Pivot && apply_permutation! (ipiv, AR, Val { Thread} ( ))
132
+ ldiv! (_unit_lower_triangular (AL), AR, Val { Thread} ( ))
95
133
end
96
134
info
97
135
end
@@ -187,8 +225,10 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
187
225
Pivot && apply_permutation! (P2, A21, thread)
188
226
189
227
info != previnfo && (info += n1)
190
- @turbo warn_check_args= false for i in 1 : n2
191
- P2[i] += n1
228
+ if Pivot
229
+ @turbo warn_check_args= false for i in 1 : n2
230
+ P2[i] += n1
231
+ end
192
232
end
193
233
return info
194
234
end # inbounds
@@ -234,8 +274,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
234
274
amax = absi
235
275
end
236
276
end
277
+ ipiv[k] = kp
237
278
end
238
- ipiv[k] = kp
239
279
if ! iszero (A[kp, k])
240
280
if k != kp
241
281
# Interchange
0 commit comments