1
- using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular
1
+ using LoopVectorization
2
+ using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, checknonsingular, BLAS, LinearAlgebra
2
3
3
- function lu (A:: AbstractMatrix , pivot:: Union{Val{false}, Val{true}} = Val (true );
4
- check:: Bool = true , blocksize:: Integer = 16 )
5
- lu! (copy (A), pivot; check = check, blocksize = blocksize)
4
+ function lu (A:: AbstractMatrix , pivot:: Union{Val{false}, Val{true}} = Val (true ); kwargs... )
5
+ return lu! (copy (A), pivot; kwargs... )
6
6
end
7
7
8
- function lu! (A, pivot:: Union{Val{false}, Val{true}} = Val (true );
9
- check:: Bool = true , blocksize:: Integer = 16 )
10
- lu! (A, Vector {BlasInt} (undef, min (size (A)... )), pivot;
11
- check = check, blocksize = blocksize)
8
+ function lu! (A, pivot:: Union{Val{false}, Val{true}} = Val (true ); check= true , kwargs... )
9
+ m, n = size (A)
10
+ minmn = min (m, n)
11
+ F = if minmn < 10 # avx introduces small performance degradation
12
+ LinearAlgebra. generic_lufact! (A, pivot; check= check)
13
+ else
14
+ lu! (A, Vector {BlasInt} (undef, minmn), pivot; check= check, kwargs... )
15
+ end
16
+ return F
12
17
end
13
18
19
+ # Use a function here to make sure it gets optimized away
20
+ # OpenBLAS' TRSM isn't very good, we use a higher threshold for recursion
21
+ pick_threshold () = BLAS. vendor () === :mkl ? 48 : 192
22
+
14
23
function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
15
24
pivot:: Union{Val{false}, Val{true}} = Val (true );
16
- check:: Bool = true , blocksize:: Integer = 16 ) where T
17
- info = Ref (zero (BlasInt))
25
+ check:: Bool = true ,
26
+ # the performance is not sensitive wrt blocksize, and 16 is a good default
27
+ blocksize:: Integer = 16 ,
28
+ threshold:: Integer = pick_threshold ()) where T
29
+ info = zero (BlasInt)
18
30
m, n = size (A)
19
31
mnmin = min (m, n)
20
- if T <: BlasFloat && A isa StridedArray
21
- reckernel! (A, pivot, m, mnmin, ipiv, info, blocksize)
32
+ if A isa StridedArray && mnmin > threshold
33
+ info = reckernel! (A, pivot, m, mnmin, ipiv, info, blocksize)
22
34
if m < n # fat matrix
23
35
# [AL AR]
24
36
AL = @view A[:, 1 : m]
25
37
AR = @view A[:, m+ 1 : n]
26
38
apply_permutation! (ipiv, AR)
27
39
ldiv! (UnitLowerTriangular (AL), AR)
28
40
end
29
- else # generic fallback
30
- _generic_lufact! (A, pivot, ipiv, info)
41
+ else # generic fallback
42
+ info = _generic_lufact! (A, pivot, ipiv, info)
31
43
end
32
- check && checknonsingular (info[] )
33
- LU {T, typeof(A)} (A, ipiv, info[] )
44
+ check && checknonsingular (info)
45
+ LU {T, typeof(A)} (A, ipiv, info)
34
46
end
35
47
36
48
function nsplit (:: Type{T} , n) where T
37
- k = 128 ÷ sizeof (T)
49
+ k = 512 ÷ ( isbitstype (T) ? sizeof (T) : 8 )
38
50
k_2 = k ÷ 2
39
51
return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
40
52
end
@@ -44,17 +56,19 @@ Base.@propagate_inbounds function apply_permutation!(P, A)
44
56
i′ = P[i]
45
57
i′ == i && continue
46
58
@simd for j in axes (A, 2 )
47
- A[i, j], A[i′, j] = A[i′, j], A[i, j]
59
+ tmp = A[i, j]
60
+ A[i, j] = A[i′, j]
61
+ A[i′, j] = tmp
48
62
end
49
63
end
50
64
nothing
51
65
end
52
66
53
- function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize):: Nothing where {T,Pivot}
67
+ function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize):: BlasInt where {T,Pivot}
54
68
@inbounds begin
55
69
if n <= max (blocksize, 1 )
56
- _generic_lufact! (A, pivot, ipiv, info)
57
- return nothing
70
+ info = _generic_lufact! (A, pivot, ipiv, info)
71
+ return info
58
72
end
59
73
n1 = nsplit (T, n)
60
74
n2 = n - n1
@@ -88,7 +102,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
88
102
# [ A11 ] [ L11 ]
89
103
# P [ ] = [ ] U11
90
104
# [ A21 ] [ L21 ]
91
- reckernel! (AL, pivot, m, n1, P1, info, blocksize)
105
+ info = reckernel! (AL, pivot, m, n1, P1, info, blocksize)
92
106
# [ A12 ] [ P1 ] [ A12 ]
93
107
# [ ] <- [ ] [ ]
94
108
# [ A22 ] [ 0 ] [ A22 ]
@@ -98,22 +112,33 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
98
112
# Schur complement:
99
113
# We have A22 = L21 U12 + A′22, hence
100
114
# A′22 = A22 - L21 U12
101
- BLAS. gemm! (' N' , ' N' , - one (T), A21, A12, one (T), A22)
115
+ # mul!(A22, A21, A12, -one(T), one(T))
116
+ schur_complement! (A22, A21, A12)
102
117
# record info
103
- previnfo = info[]
118
+ previnfo = info
104
119
# P2 A22 = L22 U22
105
- reckernel! (A22, pivot, m2, n2, P2, info, blocksize)
120
+ info = reckernel! (A22, pivot, m2, n2, P2, info, blocksize)
106
121
# A21 <- P2 A21
107
122
Pivot && apply_permutation! (P2, A21)
108
123
109
- info[] != previnfo && (info[] += n1)
110
- @simd for i in 1 : n2
124
+ info != previnfo && (info += n1)
125
+ @avx for i in 1 : n2
111
126
P2[i] += n1
112
127
end
113
- return nothing
128
+ return info
114
129
end # inbounds
115
130
end
116
131
132
+ function schur_complement! (𝐂, 𝐀, 𝐁)
133
+ @avx for m ∈ 1 : size (𝐀,1 ), n ∈ 1 : size (𝐁,2 )
134
+ 𝐂ₘₙ = zero (eltype (𝐂))
135
+ for k ∈ 1 : size (𝐀,2 )
136
+ 𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
137
+ end
138
+ 𝐂[m,n] = 𝐂ₘₙ + 𝐂[m,n]
139
+ end
140
+ end
141
+
117
142
#=
118
143
Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
119
144
License is MIT: https://julialang.org/license
@@ -147,19 +172,19 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
147
172
end
148
173
# Scale first column
149
174
Akkinv = inv (A[k,k])
150
- @simd for i = k+ 1 : m
175
+ @avx for i = k+ 1 : m
151
176
A[i,k] *= Akkinv
152
177
end
153
- elseif info[] == 0
154
- info[] = k
178
+ elseif info == 0
179
+ info = k
155
180
end
156
181
# Update the rest
157
- for j = k+ 1 : n
158
- @simd for i = k+ 1 : m
182
+ @avx for j = k+ 1 : n
183
+ for i = k+ 1 : m
159
184
A[i,j] -= A[i,k]* A[k,j]
160
185
end
161
186
end
162
187
end
163
188
end
164
- return nothing
189
+ return info
165
190
end
0 commit comments