Skip to content

Commit 3f73b45

Browse files
authored
Merge pull request #39 from chriselrod/optionalthreading
Add optional threading, and disable turbo type warnings
2 parents f8cceb4 + 7238894 commit 3f73b45

File tree

3 files changed

+32
-21
lines changed

3 files changed

+32
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveFactorization"
22
uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
33
authors = ["Yingbo Ma <[email protected]>"]
4-
version = "0.2.7"
4+
version = "0.2.8"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/lu.jl

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@ if VERSION >= v"1.7.0-DEV.1188"
1414
to_stdlib_pivot(::Val{false}) = LinearAlgebra.NoPivot()
1515
end
1616

17-
function lu(A::AbstractMatrix, pivot = Val(true); kwargs...)
18-
return lu!(copy(A), normalize_pivot(pivot); kwargs...)
17+
function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
18+
return lu!(copy(A), normalize_pivot(pivot), thread; kwargs...)
1919
end
2020

21-
function lu!(A, pivot = Val(true); check=true, kwargs...)
21+
function lu!(A, pivot = Val(true), thread = Val(true); check=true, kwargs...)
2222
m, n = size(A)
2323
minmn = min(m, n)
2424
F = if minmn < 10 # avx introduces small performance degradation
2525
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check=check)
2626
else
27-
lu!(A, Vector{BlasInt}(undef, minmn), normalize_pivot(pivot); check=check, kwargs...)
27+
lu!(A, Vector{BlasInt}(undef, minmn), normalize_pivot(pivot), thread; check=check, kwargs...)
2828
end
2929
return F
3030
end
@@ -46,7 +46,7 @@ recurse(_) = false
4646

4747
function lu!(
4848
A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
49-
pivot = Val(true);
49+
pivot = Val(true), thread = Val(true);
5050
check::Bool=true,
5151
# the performance is not sensitive wrt blocksize, and 8 is a good default
5252
blocksize::Integer=length(A) ≥ 40_000 ? 8 : 16,
@@ -59,10 +59,10 @@ function lu!(
5959
if recurse(A) && mnmin > threshold
6060
if T <: Union{Float32,Float64}
6161
GC.@preserve ipiv A begin
62-
info = recurse!(PtrArray(A), pivot, m, n, mnmin, PtrArray(ipiv), info, blocksize)
62+
info = recurse!(PtrArray(A), pivot, m, n, mnmin, PtrArray(ipiv), info, blocksize, thread)
6363
end
6464
else
65-
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize)
65+
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
6666
end
6767
else # generic fallback
6868
info = _generic_lufact!(A, pivot, ipiv, info)
@@ -71,26 +71,36 @@ function lu!(
7171
LU{T, typeof(A)}(A, ipiv, info)
7272
end
7373

74-
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize) where {Pivot}
75-
thread = length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1))
76-
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, thread)
74+
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, ::Val{true}) where {Pivot}
75+
if length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1))
76+
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true))
77+
else
78+
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
79+
end
80+
end
81+
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, ::Val{false}) where {Pivot}
82+
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
83+
end
84+
@inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, ::Val{Thread}) where {Pivot,Thread}
85+
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))
7786
@inbounds if m < n # fat matrix
7887
# [AL AR]
7988
AL = @view A[:, 1:m]
8089
AR = @view A[:, m+1:n]
81-
apply_permutation!(ipiv, AR, thread)
82-
ldiv!(UnitLowerTriangular(AL), AR)
90+
apply_permutation!(ipiv, AR, Val(Thread))
91+
ldiv!(UnitLowerTriangular(AL), AR, Val(Thread))
8392
end
8493
info
8594
end
8695

96+
8797
@inline function nsplit(::Type{T}, n) where T
8898
k = 512 ÷ (isbitstype(T) ? sizeof(T) : 8)
8999
k_2 = k ÷ 2
90100
return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
91101
end
92102

93-
function apply_permutation_threaded!(P, A)
103+
function apply_permutation!(P, A, ::Val{true})
94104
batchsize = cld(2000, length(P))
95105
@batch minbatch=batchsize for j in axes(A, 2)
96106
@inbounds for i in axes(P, 1)
@@ -103,9 +113,7 @@ function apply_permutation_threaded!(P, A)
103113
nothing
104114
end
105115
_sizeof(::Type{T}) where {T} = Base.isbitstype(T) ? sizeof(T) : sizeof(Int)
106-
Base.@propagate_inbounds function apply_permutation!(P, A, thread)
107-
thread && return apply_permutation_threaded!(P, A)
108-
# length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1)) && return apply_permutation_threaded!(P, A)
116+
Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
109117
for i in axes(P, 1)
110118
i′ = P[i]
111119
i′ == i && continue
@@ -162,7 +170,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
162170
# [ A22 ] [ 0 ] [ A22 ]
163171
Pivot && apply_permutation!(P1, AR, thread)
164172
# A12 = L11 U12 => U12 = L11 \ A12
165-
ldiv!(UnitLowerTriangular(A11), A12)
173+
ldiv!(UnitLowerTriangular(A11), A12, thread)
166174
# Schur complement:
167175
# We have A22 = L21 U12 + A′22, hence
168176
# A′22 = A22 - L21 U12
@@ -176,7 +184,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
176184
Pivot && apply_permutation!(P2, A21, thread)
177185

178186
info != previnfo && (info += n1)
179-
@avx for i in 1:n2
187+
@turbo warn_check_args=false for i in 1:n2
180188
P2[i] += n1
181189
end
182190
return info
@@ -226,15 +234,15 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
226234
end
227235
# Scale first column
228236
Akkinv = inv(A[k,k])
229-
@avx check_empty=true for i = k+1:m
237+
@turbo check_empty=true warn_check_args=false for i = k+1:m
230238
A[i,k] *= Akkinv
231239
end
232240
elseif info == 0
233241
info = k
234242
end
235243
k == minmn && break
236244
# Update the rest
237-
@avx for j = k+1:n
245+
@turbo warn_check_args=false for j = k+1:n
238246
for i = k+1:m
239247
A[i,j] -= A[i,k]*A[k,j]
240248
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,17 @@ testlu(A::Adjoint, MF::Adjoint, BF) = testlu(parent(A), parent(MF), BF)
3333
MF = mylu(A, p)
3434
BF = baselu(A, p)
3535
testlu(A, MF, BF)
36+
testlu(A, mylu(A, p, Val(false)), BF)
3637
A′ = permutedims(A)
3738
MF′ = mylu(A′', p)
3839
testlu(A′', MF′, BF)
40+
testlu(A′', mylu(A′', p, Val(false)), BF)
3941
i = rand(1:s) # test `MF.info`
4042
A[:, i] .= 0
4143
MF = mylu(A, p, check=false)
4244
BF = baselu(A, p, check=false)
4345
testlu(A, MF, BF)
46+
testlu(A, mylu(A, p, Val(false), check=false), BF)
4447
end
4548
end
4649
end

0 commit comments

Comments
 (0)