Skip to content

Commit 1143e93

Browse files
authored
TriangularSolve.jl for ldiv! (#28)
1 parent a478c1f commit 1143e93

File tree

5 files changed

+87
-37
lines changed

5 files changed

+87
-37
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ on:
88
- master
99
jobs:
1010
test:
11+
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.threads }} - ${{ matrix.arch }} - ${{ github.event_name }}
1112
runs-on: ${{ matrix.os }}
1213
strategy:
1314
matrix:
1415
julia-version: ['1', '^1.7.0-0']
16+
threads:
17+
- '1'
18+
- '3'
1519
os: [ubuntu-latest, windows-latest, macOS-latest]
1620
steps:
1721
- uses: actions/checkout@v2
@@ -30,6 +34,8 @@ jobs:
3034
${{ runner.os }}-
3135
- uses: julia-actions/julia-buildpkg@v1
3236
- uses: julia-actions/julia-runtest@v1
37+
env:
38+
JULIA_NUM_THREADS: ${{ matrix.threads }}
3339
- uses: julia-actions/julia-processcoverage@v1
3440
- uses: codecov/codecov-action@v1
3541
with:

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
name = "RecursiveFactorization"
22
uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
33
authors = ["Yingbo Ma <[email protected]>"]
4-
version = "0.1.13"
4+
version = "0.2.0"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
9+
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
10+
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
11+
TriangularSolve = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf"
912

1013
[compat]
1114
LoopVectorization = "0.10,0.11, 0.12"
15+
Polyester = "0.3.2"
16+
StrideArraysCore = "0.1.13"
17+
TriangularSolve = "0.1.1"
1218
julia = "1.5"
1319

1420
[extras]

perf/lu.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using BenchmarkTools, Random
2-
using LinearAlgebra, RecursiveFactorization
3-
2+
using LinearAlgebra, RecursiveFactorization, VectorizationBase
3+
nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads())
4+
BLAS.set_num_threads(nc)
45
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5
56

67
function luflop(m, n=m; innerflop=2)
@@ -43,7 +44,12 @@ for n in ns
4344
end
4445

4546
using DataFrames, VegaLite
46-
blaslib = BLAS.vendor() === :mkl ? :MKL : :OpenBLAS
47+
blaslib = if VERSION ≥ v"1.7.0-beta2"
48+
config = BLAS.get_config().loaded_libs
49+
occursin("libmkl_rt", config[1].libname) ? :MKL : :OpenBLAS
50+
else
51+
BLAS.vendor() === :mkl ? :MKL : :OpenBLAS
52+
end
4753
df = DataFrame(Size = ns,
4854
Reference = ref_mflops)
4955
setproperty!(df, blaslib, bas_mflops)
@@ -60,7 +66,7 @@ plt = df |> @vlplot(
6066
x = {:Size}, y = {:GFLOPS},
6167
width = 1000, height = 600
6268
)
63-
save(joinpath(homedir(), "Pictures", "lu_float64.png"), plt)
69+
save(joinpath(homedir(), "Pictures", "lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)
6470

6571
#=
6672
using Plot

src/lu.jl

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using LoopVectorization
2-
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, checknonsingular, BLAS, LinearAlgebra
2+
using TriangularSolve: ldiv!
3+
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS, LinearAlgebra, Adjoint, Transpose
4+
using StrideArraysCore
5+
using Polyester: @batch
36

47
# 1.7 compat
58
normalize_pivot(t::Val{T}) where T = t
@@ -26,43 +29,40 @@ function lu!(A, pivot = Val(true); check=true, kwargs...)
2629
return F
2730
end
2831

32+
for (f, T) in [(:adjoint, :Adjoint), (:transpose, :Transpose)], lu in (:lu, :lu!)
33+
@eval $lu(A::$T, args...; kwargs...) = $f($lu(parent(A), args...; kwargs...))
34+
end
35+
2936
const RECURSION_THRESHOLD = Ref(-1)
3037

3138
# AVX512 needs a smaller recursion limit
3239
function pick_threshold()
3340
RECURSION_THRESHOLD[] >= 0 && return RECURSION_THRESHOLD[]
34-
blasvendor = @static if VERSION >= v"1.7.0-DEV.610"
35-
:openblas64
36-
else
37-
BLAS.vendor()
38-
end
39-
if blasvendor === :openblas || blasvendor === :openblas64
40-
LoopVectorization.register_size() == 64 ? 110 : 72
41-
else
42-
LoopVectorization.register_size() == 64 ? 48 : 72
43-
end
41+
LoopVectorization.register_size() == 64 ? 48 : 40
4442
end
4543

44+
recurse(::StridedArray) = true
45+
recurse(_) = false
46+
4647
function lu!(
4748
A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
4849
pivot = Val(true);
4950
check::Bool=true,
50-
# the performance is not sensitive wrt blocksize, and 16 is a good default
51-
blocksize::Integer=16,
51+
# the performance is not sensitive wrt blocksize; default to 8 for large matrices and 16 otherwise
52+
blocksize::Integer=length(A) ≥ 40_000 ? 8 : 16,
5253
threshold::Integer=pick_threshold()
5354
) where T
5455
pivot = normalize_pivot(pivot)
5556
info = zero(BlasInt)
5657
m, n = size(A)
5758
mnmin = min(m, n)
58-
if A isa StridedArray && mnmin > threshold
59-
info = reckernel!(A, pivot, m, mnmin, ipiv, info, blocksize)
60-
if m < n # fat matrix
61-
# [AL AR]
62-
AL = @view A[:, 1:m]
63-
AR = @view A[:, m+1:n]
64-
apply_permutation!(ipiv, AR)
65-
ldiv!(UnitLowerTriangular(AL), AR)
59+
if recurse(A) && mnmin > threshold
60+
if T <: Union{Float32,Float64}
61+
GC.@preserve ipiv A begin
62+
info = recurse!(PtrArray(A), pivot, m, n, mnmin, PtrArray(ipiv), info, blocksize)
63+
end
64+
else
65+
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize)
6666
end
6767
else # generic fallback
6868
info = _generic_lufact!(A, pivot, ipiv, info)
@@ -71,13 +71,41 @@ function lu!(
7171
LU{T, typeof(A)}(A, ipiv, info)
7272
end
7373

74-
function nsplit(::Type{T}, n) where T
74+
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize) where {Pivot}
75+
thread = length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1))
76+
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, thread)
77+
@inbounds if m < n # fat matrix
78+
# [AL AR]
79+
AL = @view A[:, 1:m]
80+
AR = @view A[:, m+1:n]
81+
apply_permutation!(ipiv, AR, thread)
82+
ldiv!(UnitLowerTriangular(AL), AR)
83+
end
84+
info
85+
end
86+
87+
@inline function nsplit(::Type{T}, n) where T
7588
k = 512 ÷ (isbitstype(T) ? sizeof(T) : 8)
7689
k_2 = k ÷ 2
7790
return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
7891
end
7992

80-
Base.@propagate_inbounds function apply_permutation!(P, A)
93+
function apply_permutation_threaded!(P, A)
94+
batchsize = cld(2000, length(P))
95+
@batch minbatch=batchsize for j in axes(A, 2)
96+
@inbounds @simd ivdep for i in axes(P, 1)
97+
i′ = P[i]
98+
tmp = A[i, j]
99+
A[i, j] = A[i′, j]
100+
A[i′, j] = tmp
101+
end
102+
end
103+
nothing
104+
end
105+
_sizeof(::Type{T}) where {T} = Base.isbitstype(T) ? sizeof(T) : sizeof(Int)
106+
Base.@propagate_inbounds function apply_permutation!(P, A, thread)
107+
thread && return apply_permutation_threaded!(P, A)
108+
# length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1)) && return apply_permutation_threaded!(P, A)
81109
for i in axes(P, 1)
82110
i′ = P[i]
83111
i′ == i && continue
@@ -90,10 +118,10 @@ Base.@propagate_inbounds function apply_permutation!(P, A)
90118
nothing
91119
end
92120

93-
function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize)::BlasInt where {T,Pivot}
121+
function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize, thread)::BlasInt where {T,Pivot}
94122
@inbounds begin
95123
if n <= max(blocksize, 1)
96-
info = _generic_lufact!(A, pivot, ipiv, info)
124+
info = _generic_lufact!(A, Val(Pivot), ipiv, info)
97125
return info
98126
end
99127
n1 = nsplit(T, n)
@@ -128,11 +156,11 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
128156
# [ A11 ] [ L11 ]
129157
# P [ ] = [ ] U11
130158
# [ A21 ] [ L21 ]
131-
info = reckernel!(AL, pivot, m, n1, P1, info, blocksize)
159+
info = reckernel!(AL, Val(Pivot), m, n1, P1, info, blocksize, thread)
132160
# [ A12 ] [ P1 ] [ A12 ]
133161
# [ ] <- [ ] [ ]
134162
# [ A22 ] [ 0 ] [ A22 ]
135-
Pivot && apply_permutation!(P1, AR)
163+
Pivot && apply_permutation!(P1, AR, thread)
136164
# A12 = L11 U12 => U12 = L11 \ A12
137165
ldiv!(UnitLowerTriangular(A11), A12)
138166
# Schur complement:
@@ -143,9 +171,9 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
143171
# record info
144172
previnfo = info
145173
# P2 A22 = L22 U22
146-
info = reckernel!(A22, pivot, m2, n2, P2, info, blocksize)
174+
info = reckernel!(A22, Val(Pivot), m2, n2, P2, info, blocksize, thread)
147175
# A21 <- P2 A21
148-
Pivot && apply_permutation!(P2, A21)
176+
Pivot && apply_permutation!(P2, A21, thread)
149177

150178
info != previnfo && (info += n1)
151179
@avx for i in 1:n2
@@ -156,7 +184,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
156184
end
157185

158186
function schur_complement!(𝐂, 𝐀, 𝐁)
159-
@avx for m ∈ 1:size(𝐀,1), n ∈ 1:size(𝐁,2)
187+
@tturbo for m ∈ 1:size(𝐀,1), n ∈ 1:size(𝐁,2)
160188
𝐂ₘₙ = zero(eltype(𝐂))
161189
for k ∈ 1:size(𝐀,2)
162190
𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]

test/runtests.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test
22
import RecursiveFactorization
33
import LinearAlgebra
4-
using LinearAlgebra: norm
4+
using LinearAlgebra: norm, Adjoint
55
using Random
66

77
Random.seed!(12)
@@ -11,9 +11,10 @@ const mylu = RecursiveFactorization.lu
1111

1212
function testlu(A, MF, BF)
1313
@test MF.info == BF.info
14-
@test norm(MF.L*MF.U - A[MF.p, :], Inf) < 100sqrt(eps(real(one(float(first(A))))))
14+
@test norm(MF.L*MF.U - A[MF.p, :], Inf) < length(A)*sqrt(eps(real(one(float(first(A))))))/16
1515
nothing
1616
end
17+
testlu(A::Adjoint, MF::Adjoint, BF) = testlu(parent(A), parent(MF), BF)
1718

1819
@testset "Test LU factorization" begin
1920
for _p in (true, false), T in (Float64, Float32, ComplexF64, ComplexF32, Real)
@@ -32,6 +33,9 @@ end
3233
MF = mylu(A, p)
3334
BF = baselu(A, p)
3435
testlu(A, MF, BF)
36+
A′ = permutedims(A)
37+
MF′ = mylu(A′', p)
38+
testlu(A′', MF′, BF)
3539
i = rand(1:s) # test `MF.info`
3640
A[:, i] .= 0
3741
MF = mylu(A, p, check=false)

0 commit comments

Comments
 (0)