1
1
using LoopVectorization
2
- using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, checknonsingular, BLAS, LinearAlgebra
2
+ using TriangularSolve: ldiv!
3
+ using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS, LinearAlgebra, Adjoint, Transpose
4
+ using StrideArraysCore
5
+ using Polyester: @batch
3
6
4
7
# 1.7 compat
5
8
# Identity method: a pivot choice that is already a `Val` is passed through unchanged.
function normalize_pivot(t::Val{T}) where {T}
    return t
end
@@ -26,43 +29,40 @@ function lu!(A, pivot = Val(true); check=true, kwargs...)
26
29
return F
27
30
end
28
31
32
# For lazily wrapped inputs (`Adjoint` / `Transpose`), run `lu` / `lu!` on the parent
# array and re-apply the same wrapper (`adjoint` / `transpose`) to the result.
for (wrap, Wrapper) in [(:adjoint, :Adjoint), (:transpose, :Transpose)]
    for factorize in (:lu, :lu!)
        @eval function $factorize(A::$Wrapper, args...; kwargs...)
            return $wrap($factorize(parent(A), args...; kwargs...))
        end
    end
end
35
+
29
36
const RECURSION_THRESHOLD = Ref (- 1 )
30
37
31
38
# Matrix size (min(m, n)) above which `lu!` uses the recursive kernel.
# A non-negative value stored in `RECURSION_THRESHOLD` overrides the built-in choice;
# otherwise the threshold is 48 when `LoopVectorization.register_size() == 64`
# and 40 for any other register size.
# NOTE(review): an earlier comment here said "AVX512 needs a smaller recursion limit",
# but the `register_size() == 64` branch now yields the larger value - confirm intent.
function pick_threshold()
    override = RECURSION_THRESHOLD[]
    if override >= 0
        return override
    end
    return LoopVectorization.register_size() == 64 ? 48 : 40
end
45
43
44
# Trait deciding whether `lu!` may take the recursive code path for a given array:
# by default no array qualifies, and only `StridedArray`s opt in.
recurse(::Any) = false
recurse(::StridedArray) = true
46
+
46
47
function lu! (
47
48
A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
48
49
pivot = Val (true );
49
50
check:: Bool = true ,
50
- # the performance is not sensitive wrt blocksize, and 16 is a good default
51
- blocksize:: Integer = 16 ,
51
+ # the performance is not very sensitive to blocksize; default to 8 once A has at least 40_000 elements and to 16 otherwise
52
+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
52
53
threshold:: Integer = pick_threshold ()
53
54
) where T
54
55
pivot = normalize_pivot (pivot)
55
56
info = zero (BlasInt)
56
57
m, n = size (A)
57
58
mnmin = min (m, n)
58
- if A isa StridedArray && mnmin > threshold
59
- info = reckernel! (A, pivot, m, mnmin, ipiv, info, blocksize)
60
- if m < n # fat matrix
61
- # [AL AR]
62
- AL = @view A[:, 1 : m]
63
- AR = @view A[:, m+ 1 : n]
64
- apply_permutation! (ipiv, AR)
65
- ldiv! (UnitLowerTriangular (AL), AR)
59
+ if recurse (A) && mnmin > threshold
60
+ if T <: Union{Float32,Float64}
61
+ GC. @preserve ipiv A begin
62
+ info = recurse! (PtrArray (A), pivot, m, n, mnmin, PtrArray (ipiv), info, blocksize)
63
+ end
64
+ else
65
+ info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize)
66
66
end
67
67
else # generic fallback
68
68
info = _generic_lufact! (A, pivot, ipiv, info)
@@ -71,13 +71,41 @@ function lu!(
71
71
LU {T, typeof(A)} (A, ipiv, info)
72
72
end
73
73
74
- function nsplit (:: Type{T} , n) where T
74
# Driver for the recursive factorization of an `m × n` matrix `A`.
# Runs `reckernel!` on the leading `mnmin` columns, then, for a fat matrix (`m < n`),
# finishes the trailing columns: permute them with `ipiv` and solve against the
# unit-lower-triangular factor stored in the left `m × m` block.
# Threading is enabled when `A` occupies more than 0.92 of
# `cache_size(Val(1))` bytes (presumably the L1 cache size - confirm).
# Returns the updated `info` code.
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize) where {Pivot}
    nbytes = length(A) * _sizeof(eltype(A))
    thread = nbytes > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1))
    info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, thread)
    # Square and tall matrices are fully handled by the kernel.
    m < n || return info
    @inbounds begin
        # A = [left right], with `left` square (m × m)
        left = @view A[:, 1:m]
        right = @view A[:, m+1:n]
        apply_permutation!(ipiv, right, thread)
        ldiv!(UnitLowerTriangular(left), right)
    end
    return info
end
86
+
87
# Choose how many of `n` columns go to the left half of the recursive split.
# `k` is the number of elements of type `T` that fit in 512 bytes (non-isbits
# element types are counted as 8 bytes each). When `n >= k` the split is `n / 2`
# rounded to a multiple of `k ÷ 2`; otherwise it is simply `n ÷ 2`.
#
# `k` is clamped to at least 2: for isbits types larger than 256 bytes the raw
# value would be 1 (giving a split of 0, so the recursion never shrinks) or
# 0 (raising a `DivideError`). With the clamp, every `n >= 2` yields a split
# strictly between 0 and `n`.
@inline function nsplit(::Type{T}, n) where T
    k = max(2, 512 ÷ (isbitstype(T) ? sizeof(T) : 8))
    k_2 = k ÷ 2
    return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
end
79
92
80
- Base. @propagate_inbounds function apply_permutation! (P, A)
93
# Apply the row interchanges stored in `P` to every column of `A`, in place,
# distributing the columns over threads with `@batch`.
# For each `i` in order, rows `i` and `P[i]` of the column are swapped.
# `minbatch` is chosen so that one batch of columns covers at least about
# 2000 swaps. Returns `nothing`.
function apply_permutation_threaded!(P, A)
    # No interchanges to apply; also avoids dividing by zero in `cld` below.
    isempty(P) && return nothing
    batchsize = cld(2000, length(P))
    @batch minbatch=batchsize for j in axes(A, 2)
        # The swaps must run strictly in order: a row written by one swap may be
        # read or written again by a later one. That is a loop-carried memory
        # dependency, so this loop must not be annotated with `@simd ivdep`.
        @inbounds for i in axes(P, 1)
            i′ = P[i]
            tmp = A[i, j]
            A[i, j] = A[i′, j]
            A[i′, j] = tmp
        end
    end
    return nothing
end
105
# Size in bytes used to estimate an array's memory footprint: the real `sizeof(T)`
# for isbits element types, and `sizeof(Int)` for everything else.
function _sizeof(::Type{T}) where {T}
    Base.isbitstype(T) || return sizeof(Int)
    return sizeof(T)
end
106
+ Base. @propagate_inbounds function apply_permutation! (P, A, thread)
107
+ thread && return apply_permutation_threaded! (P, A)
108
+ # length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1)) && return apply_permutation_threaded!(P, A)
81
109
for i in axes (P, 1 )
82
110
i′ = P[i]
83
111
i′ == i && continue
@@ -90,10 +118,10 @@ Base.@propagate_inbounds function apply_permutation!(P, A)
90
118
nothing
91
119
end
92
120
93
- function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize):: BlasInt where {T,Pivot}
121
+ function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize, thread ):: BlasInt where {T,Pivot}
94
122
@inbounds begin
95
123
if n <= max (blocksize, 1 )
96
- info = _generic_lufact! (A, pivot , ipiv, info)
124
+ info = _generic_lufact! (A, Val (Pivot) , ipiv, info)
97
125
return info
98
126
end
99
127
n1 = nsplit (T, n)
@@ -128,11 +156,11 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
128
156
# [ A11 ] [ L11 ]
129
157
# P [ ] = [ ] U11
130
158
# [ A21 ] [ L21 ]
131
- info = reckernel! (AL, pivot , m, n1, P1, info, blocksize)
159
+ info = reckernel! (AL, Val (Pivot) , m, n1, P1, info, blocksize, thread )
132
160
# [ A12 ] [ P1 ] [ A12 ]
133
161
# [ ] <- [ ] [ ]
134
162
# [ A22 ] [ 0 ] [ A22 ]
135
- Pivot && apply_permutation! (P1, AR)
163
+ Pivot && apply_permutation! (P1, AR, thread )
136
164
# A12 = L11 U12 => U12 = L11 \ A12
137
165
ldiv! (UnitLowerTriangular (A11), A12)
138
166
# Schur complement:
@@ -143,9 +171,9 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
143
171
# record info
144
172
previnfo = info
145
173
# P2 A22 = L22 U22
146
- info = reckernel! (A22, pivot , m2, n2, P2, info, blocksize)
174
+ info = reckernel! (A22, Val (Pivot) , m2, n2, P2, info, blocksize, thread )
147
175
# A21 <- P2 A21
148
- Pivot && apply_permutation! (P2, A21)
176
+ Pivot && apply_permutation! (P2, A21, thread )
149
177
150
178
info != previnfo && (info += n1)
151
179
@avx for i in 1 : n2
@@ -156,7 +184,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
156
184
end
157
185
158
186
function schur_complement! (𝐂, 𝐀, 𝐁)
159
- @avx for m ∈ 1 : size (𝐀,1 ), n ∈ 1 : size (𝐁,2 )
187
+ @tturbo for m ∈ 1 : size (𝐀,1 ), n ∈ 1 : size (𝐁,2 )
160
188
𝐂ₘₙ = zero (eltype (𝐂))
161
189
for k ∈ 1 : size (𝐀,2 )
162
190
𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
0 commit comments