@@ -22,13 +22,37 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
22
22
return lu! (copy (A), normalize_pivot (pivot), thread; kwargs... )
23
23
end
24
24
25
+ struct NotIPIV <: AbstractVector{BlasInt}
26
+ len:: Int
27
+ end
28
+ Base. size (A:: NotIPIV ) = (A. len,)
29
+ Base. getindex (:: NotIPIV , i:: Int ) = i
30
+ Base. view (:: NotIPIV , r:: AbstractUnitRange ) = NotIPIV (length (r))
31
+ init_pivot (:: Val{false} , minmn) = NotIPIV (minmn)
32
+ init_pivot (:: Val{true} , minmn) = Vector {BlasInt} (undef, minmn)
33
+
34
+ if isdefined (LinearAlgebra, :_ipiv_cols! )
35
+ function LinearAlgebra. _ipiv_cols! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
36
+ B:: StridedVecOrMat )
37
+ return B
38
+ end
39
+ end
40
+ if isdefined (LinearAlgebra, :_ipiv_rows! )
41
+ function LinearAlgebra. _ipiv_rows! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
42
+ B:: StridedVecOrMat )
43
+ return B
44
+ end
45
+ end
46
+
25
47
function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
26
48
m, n = size (A)
27
49
minmn = min (m, n)
28
- F = if minmn < 10 # avx introduces small performance degradation
50
+ # we want the type on both branches to match. When pivot = Val(false), we construct
51
+ # a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
52
+ F = if pivot === Val (true ) && minmn < 10 # avx introduces small performance degradation
29
53
LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
30
54
else
31
- lu! (A, Vector {BlasInt} (undef , minmn), normalize_pivot (pivot), thread; check = check,
55
+ lu! (A, init_pivot (pivot , minmn), normalize_pivot (pivot), thread; check = check,
32
56
kwargs... )
33
57
end
34
58
return F
@@ -44,6 +68,8 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
44
68
recurse (:: StridedArray ) = true
45
69
recurse (_) = false
46
70
71
+ _ptrarray (ipiv) = PtrArray (ipiv)
72
+ _ptrarray (ipiv:: NotIPIV ) = ipiv
47
73
function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
48
74
pivot = Val (true ), thread = Val (true );
49
75
check:: Bool = true ,
@@ -58,7 +84,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
58
84
if T <: Union{Float32, Float64}
59
85
GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
60
86
m, n, mnmin,
61
- PtrArray (ipiv), info, blocksize,
87
+ _ptrarray (ipiv), info, blocksize,
62
88
thread) end
63
89
else
64
90
info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
90
116
# [AL AR]
91
117
AL = @view A[:, 1 : m]
92
118
AR = @view A[:, (m + 1 ): n]
93
- apply_permutation! (ipiv, AR, Val ( Thread))
94
- ldiv! (_unit_lower_triangular (AL), AR, Val ( Thread))
119
+ apply_permutation! (ipiv, AR, Val { Thread} ( ))
120
+ ldiv! (_unit_lower_triangular (AL), AR, Val { Thread} ( ))
95
121
end
96
122
info
97
123
end
@@ -187,8 +213,10 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
187
213
Pivot && apply_permutation! (P2, A21, thread)
188
214
189
215
info != previnfo && (info += n1)
190
- @turbo warn_check_args= false for i in 1 : n2
191
- P2[i] += n1
216
+ if Pivot
217
+ @turbo warn_check_args= false for i in 1 : n2
218
+ P2[i] += n1
219
+ end
192
220
end
193
221
return info
194
222
end # inbounds
@@ -234,8 +262,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
234
262
amax = absi
235
263
end
236
264
end
265
+ ipiv[k] = kp
237
266
end
238
- ipiv[k] = kp
239
267
if ! iszero (A[kp, k])
240
268
if k != kp
241
269
# Interchange
0 commit comments