1
- using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular
1
+ using LoopVectorization: @avx
2
+ using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, mul!, checknonsingular
2
3
3
"""
    lu(A::AbstractMatrix, pivot = Val(true); kwargs...)

Out-of-place LU factorization: factorize a copy of `A`, leaving `A` itself untouched.

`pivot` is either `Val(true)` or `Val(false)`. Every keyword argument is passed
through unchanged to `lu!`, so this method accepts whatever keywords `lu!` accepts.
"""
function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); kwargs...)
    # Work on a copy so the caller's matrix is never overwritten.
    workspace = copy(A)
    return lu!(workspace, pivot; kwargs...)
end
7
7
8
"""
    lu!(A, pivot = Val(true); kwargs...)

In-place LU factorization of `A` that allocates the pivot vector itself.

An uninitialized `Vector{BlasInt}` of length `min(size(A)...)` is created and handed,
together with `A`, `pivot` and all keyword arguments, to the `lu!` method that takes
an explicit pivot vector.
"""
function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true); kwargs...)
    # One pivot entry per elimination step, i.e. the smaller of the two dimensions.
    npivots = min(size(A)...)
    ipiv = Vector{BlasInt}(undef, npivots)
    return lu!(A, ipiv, pivot; kwargs...)
end
13
11
14
12
function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
15
13
pivot:: Union{Val{false}, Val{true}} = Val (true );
16
- check:: Bool = true , blocksize:: Integer = 16 ) where T
14
+ check:: Bool = true , blocksize:: Integer = 16 , threshold :: Integer = 192 ) where T
17
15
info = Ref (zero (BlasInt))
18
16
m, n = size (A)
19
17
mnmin = min (m, n)
20
- if T <: BlasFloat && A isa StridedArray
18
+ if A isa StridedArray && mnmin > threshold
21
19
reckernel! (A, pivot, m, mnmin, ipiv, info, blocksize)
22
20
if m < n # fat matrix
23
21
# [AL AR]
@@ -34,7 +32,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
34
32
end
35
33
36
34
"""
    nsplit(::Type{T}, n)

Return the length of the leading part when a dimension of length `n` is split in two
for element type `T`.

Let `k` be the number of elements that fit in 512 bytes (non-isbits element types are
counted as 8 bytes each). If `n >= k`, the result is `n / k` rounded to the nearest
integer, times `k ÷ 2` — a multiple of `k ÷ 2` near `n / 2`. Otherwise the result is
simply `n ÷ 2`.

`k` is clamped to at least 2. Without the clamp, isbits element types larger than
256 bytes give `k == 1` (the function then returns 0 for every `n`) or `k == 0`
(the function then throws `DivideError`).
"""
function nsplit(::Type{T}, n) where T
    # Guard against zero-size singleton types, which would otherwise divide by zero.
    elsize = isbitstype(T) ? max(sizeof(T), 1) : 8
    # Clamp so that `k ÷ 2` below is never zero; with k == 2 this degrades to a
    # plain halving of `n`, which is the sensible fallback for very wide elements.
    k = max(512 ÷ elsize, 2)
    k_2 = k ÷ 2
    return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
end
@@ -44,7 +42,9 @@ Base.@propagate_inbounds function apply_permutation!(P, A)
44
42
i′ = P[i]
45
43
i′ == i && continue
46
44
@simd for j in axes (A, 2 )
47
- A[i, j], A[i′, j] = A[i′, j], A[i, j]
45
+ tmp = A[i, j]
46
+ A[i, j] = A[i′, j]
47
+ A[i′, j] = tmp
48
48
end
49
49
end
50
50
nothing
@@ -98,7 +98,8 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
98
98
# Schur complement:
99
99
# We have A22 = L21 U12 + A′22, hence
100
100
# A′22 = A22 - L21 U12
101
- BLAS. gemm! (' N' , ' N' , - one (T), A21, A12, one (T), A22)
101
+ # mul!(A22, A21, A12, -one(T), one(T))
102
+ schur_complement! (A22, A21, A12)
102
103
# record info
103
104
previnfo = info[]
104
105
# P2 A22 = L22 U22
@@ -107,13 +108,23 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
107
108
Pivot && apply_permutation! (P2, A21)
108
109
109
110
info[] != previnfo && (info[] += n1)
110
- @simd for i in 1 : n2
111
+ @avx for i in 1 : n2
111
112
P2[i] += n1
112
113
end
113
114
return nothing
114
115
end # inbounds
115
116
end
116
117
118
"""
    schur_complement!(𝐂, 𝐀, 𝐁)

Overwrite `𝐂` with `𝐂 - 𝐀 * 𝐁` and return `𝐂`.

For each entry `(m, n)` the negated dot product of row `m` of `𝐀` and column `n` of
`𝐁` is accumulated in a scalar and then added to `𝐂[m, n]`.

# Throws
- `DimensionMismatch` if `𝐂` is not `size(𝐀, 1) × size(𝐁, 2)` or if
  `size(𝐀, 2) != size(𝐁, 1)`.
"""
function schur_complement!(𝐂, 𝐀, 𝐁)
    # Validate shapes up front. The loop bounds below come from 𝐀 and 𝐁 only, and
    # `@avx` loops are documented as not bounds-checked, so a mismatched 𝐂 or 𝐁
    # would otherwise be indexed out of range.
    if size(𝐀, 2) != size(𝐁, 1) || size(𝐂, 1) != size(𝐀, 1) || size(𝐂, 2) != size(𝐁, 2)
        throw(DimensionMismatch(
            "C has size $(size(𝐂)), A has size $(size(𝐀)), B has size $(size(𝐁))"))
    end
    @avx for m ∈ 1:size(𝐀,1), n ∈ 1:size(𝐁,2)
        # Accumulate -(𝐀[m, :] ⋅ 𝐁[:, n]) in a scalar before touching 𝐂.
        𝐂ₘₙ = zero(eltype(𝐂))
        for k ∈ 1:size(𝐀,2)
            𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
        end
        𝐂[m,n] = 𝐂ₘₙ + 𝐂[m,n]
    end
    return 𝐂
end
127
+
117
128
#=
118
129
Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
119
130
License is MIT: https://julialang.org/license
@@ -147,15 +158,15 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
147
158
end
148
159
# Scale first column
149
160
Akkinv = inv (A[k,k])
150
- @simd for i = k+ 1 : m
161
+ @avx for i = k+ 1 : m
151
162
A[i,k] *= Akkinv
152
163
end
153
164
elseif info[] == 0
154
165
info[] = k
155
166
end
156
167
# Update the rest
157
- for j = k+ 1 : n
158
- @simd for i = k+ 1 : m
168
+ @avx for j = k+ 1 : n
169
+ for i = k+ 1 : m
159
170
A[i,j] -= A[i,k]* A[k,j]
160
171
end
161
172
end
0 commit comments