@@ -14,17 +14,17 @@ if VERSION >= v"1.7.0-DEV.1188"
14
14
to_stdlib_pivot (:: Val{false} ) = LinearAlgebra. NoPivot ()
15
15
end
16
16
17
- function lu (A:: AbstractMatrix , pivot = Val (true ); kwargs... )
18
- return lu! (copy (A), normalize_pivot (pivot); kwargs... )
17
+ function lu (A:: AbstractMatrix , pivot = Val (true ), thread = Val ( true ) ; kwargs... )
18
+ return lu! (copy (A), normalize_pivot (pivot), thread ; kwargs... )
19
19
end
20
20
21
- function lu! (A, pivot = Val (true ); check= true , kwargs... )
21
+ function lu! (A, pivot = Val (true ), thread = Val ( true ) ; check= true , kwargs... )
22
22
m, n = size (A)
23
23
minmn = min (m, n)
24
24
F = if minmn < 10 # avx introduces small performance degradation
25
25
LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check= check)
26
26
else
27
- lu! (A, Vector {BlasInt} (undef, minmn), normalize_pivot (pivot); check= check, kwargs... )
27
+ lu! (A, Vector {BlasInt} (undef, minmn), normalize_pivot (pivot), thread ; check= check, kwargs... )
28
28
end
29
29
return F
30
30
end
@@ -46,7 +46,7 @@ recurse(_) = false
46
46
47
47
function lu! (
48
48
A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
49
- pivot = Val (true );
49
+ pivot = Val (true ), thread = Val ( true ) ;
50
50
check:: Bool = true ,
51
51
# the performance is not sensitive wrt blocksize, and 8 is a good default
52
52
blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
@@ -59,10 +59,10 @@ function lu!(
59
59
if recurse (A) && mnmin > threshold
60
60
if T <: Union{Float32,Float64}
61
61
GC. @preserve ipiv A begin
62
- info = recurse! (PtrArray (A), pivot, m, n, mnmin, PtrArray (ipiv), info, blocksize)
62
+ info = recurse! (PtrArray (A), pivot, m, n, mnmin, PtrArray (ipiv), info, blocksize, thread )
63
63
end
64
64
else
65
- info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize)
65
+ info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread )
66
66
end
67
67
else # generic fallback
68
68
info = _generic_lufact! (A, pivot, ipiv, info)
@@ -71,26 +71,36 @@ function lu!(
71
71
LU {T, typeof(A)} (A, ipiv, info)
72
72
end
73
73
74
- @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize) where {Pivot}
75
- thread = length (A) * _sizeof (eltype (A)) > 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (1 ))
76
- info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, thread)
74
+ @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{true} ) where {Pivot}
75
+ if length (A) * _sizeof (eltype (A)) > 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (1 ))
76
+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
77
+ else
78
+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
79
+ end
80
+ end
81
+ @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{false} ) where {Pivot}
82
+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
83
+ end
84
+ @inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{Thread} ) where {Pivot,Thread}
85
+ info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread))
77
86
@inbounds if m < n # fat matrix
78
87
# [AL AR]
79
88
AL = @view A[:, 1 : m]
80
89
AR = @view A[:, m+ 1 : n]
81
- apply_permutation! (ipiv, AR, thread )
82
- ldiv! (UnitLowerTriangular (AL), AR)
90
+ apply_permutation! (ipiv, AR, Val (Thread) )
91
+ ldiv! (UnitLowerTriangular (AL), AR, Val (Thread) )
83
92
end
84
93
info
85
94
end
86
95
96
+
87
97
@inline function nsplit (:: Type{T} , n) where T
88
98
k = 512 ÷ (isbitstype (T) ? sizeof (T) : 8 )
89
99
k_2 = k ÷ 2
90
100
return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
91
101
end
92
102
93
- function apply_permutation_threaded ! (P, A)
103
+ function apply_permutation ! (P, A, :: Val{true} )
94
104
batchsize = cld (2000 , length (P))
95
105
@batch minbatch= batchsize for j in axes (A, 2 )
96
106
@inbounds for i in axes (P, 1 )
@@ -103,9 +113,7 @@ function apply_permutation_threaded!(P, A)
103
113
nothing
104
114
end
105
115
_sizeof (:: Type{T} ) where {T} = Base. isbitstype (T) ? sizeof (T) : sizeof (Int)
106
- Base. @propagate_inbounds function apply_permutation! (P, A, thread)
107
- thread && return apply_permutation_threaded! (P, A)
108
- # length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1)) && return apply_permutation_threaded!(P, A)
116
+ Base. @propagate_inbounds function apply_permutation! (P, A, :: Val{false} )
109
117
for i in axes (P, 1 )
110
118
i′ = P[i]
111
119
i′ == i && continue
@@ -162,7 +170,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
162
170
# [ A22 ] [ 0 ] [ A22 ]
163
171
Pivot && apply_permutation! (P1, AR, thread)
164
172
# A12 = L11 U12 => U12 = L11 \ A12
165
- ldiv! (UnitLowerTriangular (A11), A12)
173
+ ldiv! (UnitLowerTriangular (A11), A12, thread )
166
174
# Schur complement:
167
175
# We have A22 = L21 U12 + A′22, hence
168
176
# A′22 = A22 - L21 U12
@@ -176,7 +184,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
176
184
Pivot && apply_permutation! (P2, A21, thread)
177
185
178
186
info != previnfo && (info += n1)
179
- @avx for i in 1 : n2
187
+ @turbo warn_check_args = false for i in 1 : n2
180
188
P2[i] += n1
181
189
end
182
190
return info
@@ -226,15 +234,15 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
226
234
end
227
235
# Scale first column
228
236
Akkinv = inv (A[k,k])
229
- @avx check_empty= true for i = k+ 1 : m
237
+ @turbo check_empty= true warn_check_args = false for i = k+ 1 : m
230
238
A[i,k] *= Akkinv
231
239
end
232
240
elseif info == 0
233
241
info = k
234
242
end
235
243
k == minmn && break
236
244
# Update the rest
237
- @avx for j = k+ 1 : n
245
+ @turbo warn_check_args = false for j = k+ 1 : n
238
246
for i = k+ 1 : m
239
247
A[i,j] -= A[i,k]* A[k,j]
240
248
end
0 commit comments