@@ -22,13 +22,23 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
22
22
return lu! (copy (A), normalize_pivot (pivot), thread; kwargs... )
23
23
end
24
24
25
"""
    NotIPIV(len) <: AbstractVector{BlasInt}

Read-only stand-in for a pivot vector of length `len` whose `i`-th entry is `i`
(the identity permutation). Only the length is stored, so creating one allocates
no pivot storage. `init_pivot(::Val{false}, minmn)` returns this type.
"""
struct NotIPIV <: AbstractVector{BlasInt}
    len::Int
end

# One-dimensional: the shape is just the stored length.
Base.size(A::NotIPIV) = (A.len,)

# Entry `i` is always `i`. Convert so the returned value matches the declared
# `BlasInt` element type even on builds where `BlasInt !== Int`.
# No bounds check is performed: any `i` is answered with `i`.
Base.getindex(::NotIPIV, i::Int) = convert(BlasInt, i)

# A view over a unit range is again an identity vector: it is re-based to start
# at 1 and has the same number of entries as the requested range.
Base.view(::NotIPIV, r::AbstractUnitRange) = NotIPIV(length(r))
29
"""
    init_pivot(pivot::Val, minmn)

Build the pivot container of length `minmn` that matches the pivoting choice.

- `Val(true)`: an uninitialized `Vector{BlasInt}` with `minmn` entries.
- `Val(false)`: a `NotIPIV(minmn)` placeholder.
"""
function init_pivot(::Val{true}, minmn)
    return Vector{BlasInt}(undef, minmn)
end

function init_pivot(::Val{false}, minmn)
    return NotIPIV(minmn)
end
31
+
32
+
25
33
function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
26
34
m, n = size (A)
27
35
minmn = min (m, n)
28
- F = if minmn < 10 # avx introduces small performance degradation
36
+ # we want the type on both branches to match. When pivot = Val(false), we construct
37
+ # a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
38
+ F = if pivot === Val (true ) && minmn < 10 # avx introduces small performance degradation
29
39
LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
30
40
else
31
- lu! (A, Vector {BlasInt} (undef , minmn), normalize_pivot (pivot), thread; check = check,
41
+ lu! (A, init_pivot (pivot , minmn), normalize_pivot (pivot), thread; check = check,
32
42
kwargs... )
33
43
end
34
44
return F
@@ -44,6 +54,8 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
44
54
# Predicate selected by dispatch: `true` for any `StridedArray`, `false` for every
# other argument. How callers act on the answer is not visible in this part of the
# file.
function recurse(::StridedArray)
    return true
end

function recurse(::Any)
    return false
end
46
56
57
# Prepare a pivot vector for the pointer-based code path: a `NotIPIV` placeholder is
# handed back unchanged, anything else is wrapped with `PtrArray`.
function _ptrarray(ipiv::NotIPIV)
    return ipiv
end

function _ptrarray(ipiv)
    return PtrArray(ipiv)
end
47
59
function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
48
60
pivot = Val (true ), thread = Val (true );
49
61
check:: Bool = true ,
@@ -58,7 +70,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
58
70
if T <: Union{Float32, Float64}
59
71
GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
60
72
m, n, mnmin,
61
- PtrArray (ipiv), info, blocksize,
73
+ _ptrarray (ipiv), info, blocksize,
62
74
thread) end
63
75
else
64
76
info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
@@ -187,8 +199,10 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
187
199
Pivot && apply_permutation! (P2, A21, thread)
188
200
189
201
info != previnfo && (info += n1)
190
- @turbo warn_check_args= false for i in 1 : n2
191
- P2[i] += n1
202
+ if Pivot
203
+ @turbo warn_check_args= false for i in 1 : n2
204
+ P2[i] += n1
205
+ end
192
206
end
193
207
return info
194
208
end # inbounds
@@ -234,8 +248,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
234
248
amax = absi
235
249
end
236
250
end
251
+ ipiv[k] = kp
237
252
end
238
- ipiv[k] = kp
239
253
if ! iszero (A[kp, k])
240
254
if k != kp
241
255
# Interchange
0 commit comments