
Commit a790145

Merge pull request #6 from YingboMa/avx
Use avx micro
2 parents ead6809 + ec8e94e commit a790145

File tree

5 files changed (+118 −59 lines)

.travis.yml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ language: julia
 os:
   - linux
 julia:
-  - 1.0
+  - 1.1
   - nightly
 #matrix:
 #  allow_failures:

Project.toml

Lines changed: 4 additions & 2 deletions
@@ -1,13 +1,15 @@
 name = "RecursiveFactorization"
 uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
 authors = ["Yingbo Ma <[email protected]>"]
-version = "0.1.0"
+version = "0.1.1"

 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"

 [compat]
-julia = "1"
+LoopVectorization = "0.7"
+julia = "1.1"

 [extras]
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
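Not part of the commit, but for anyone tracking this change from the Julia side: the bumped release and its new dependency can be pulled in through Pkg. A minimal sketch (assumes the registered 0.1.1 release that this Project.toml describes):

using Pkg
# Hypothetical usage mirroring the version bump above.
Pkg.add(PackageSpec(name = "RecursiveFactorization", version = "0.1.1"))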

perf/lu.jl

Lines changed: 37 additions & 8 deletions
@@ -1,26 +1,55 @@
 using BenchmarkTools
-import LinearAlgebra, RecursiveFactorization
+using LinearAlgebra, RecursiveFactorization

-BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5
+BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.08

-luflop(m, n) = n^3÷3 - n÷3 + m*n^2
-luflop(n) = luflop(n, n)
+function luflop(m, n=m; innerflop=2)
+    sum(1:min(m, n)) do k
+        invflop = 1
+        scaleflop = isempty(k+1:m) ? 0 : sum(k+1:m)
+        updateflop = isempty(k+1:n) ? 0 : sum(k+1:n) do j
+            isempty(k+1:m) ? 0 : sum(k+1:m) do i
+                innerflop
+            end
+        end
+        invflop + scaleflop + updateflop
+    end
+end

 bas_mflops = Float64[]
 rec_mflops = Float64[]
-ns = 50:50:800
+ref_mflops = Float64[]
+ns = 4:8:500
 for n in ns
+    @info "$n × $n"
     A = rand(n, n)
     bt = @belapsed LinearAlgebra.lu!($(copy(A)))
-    rt = @belapsed RecursiveFactorization.lu!($(copy(A)))
     push!(bas_mflops, luflop(n)/bt/1e9)
+
+    rt = @belapsed RecursiveFactorization.lu!($(copy(A)))
     push!(rec_mflops, luflop(n)/rt/1e9)
+
+    ref = @belapsed LinearAlgebra.generic_lufact!($(copy(A)))
+    push!(ref_mflops, luflop(n)/ref/1e9)
 end

-using Plots
-plt = plot(ns, bas_mflops, legend=:bottomright, lab="OpenBLAS", title="LU Factorization Benchmark", marker=:auto, dpi=150)
+using DataFrames, VegaLite
+df = DataFrame(Size = ns, RecursiveFactorization = rec_mflops, OpenBLAS = bas_mflops, Reference = ref_mflops)
+df = stack(df, [:RecursiveFactorization, :OpenBLAS, :Reference], variable_name = :Library, value_name = :GFLOPS)
+plt = df |> @vlplot(
+    :line, color = :Library,
+    x = {:Size}, y = {:GFLOPS},
+    width = 2400, height = 600
+)
+save(joinpath(homedir(), "Pictures", "lu_float64.png"), plt)
+
+#=
+using Plots
+plt = plot(ns, bas_mflops, legend=:bottomright, lab="OpenBLAS", title="LU Factorization Benchmark", marker=:auto, dpi=300)
 plot!(plt, ns, rec_mflops, lab="RecursiveFactorization", marker=:auto)
+plot!(plt, ns, ref_mflops, lab="Reference", marker=:auto)
 xaxis!(plt, "size (N x N)")
 yaxis!(plt, "GFLOPS")
 savefig("lubench.png")
 savefig("lubench.pdf")
+=#
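A minimal sketch (not part of the diff) of how the GFLOPS numbers in this benchmark are obtained from a timing, assuming the `luflop` definition from perf/lu.jl above is in scope:

using BenchmarkTools, LinearAlgebra, RecursiveFactorization

# `luflop` as defined above is assumed to be in scope.
n  = 256
A  = rand(n, n)
rt = @belapsed RecursiveFactorization.lu!($(copy(A)))  # seconds per factorization
gflops = luflop(n) / rt / 1e9                          # what gets pushed into rec_mflops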

src/lu.jl

Lines changed: 59 additions & 34 deletions
@@ -1,40 +1,52 @@
-using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular
+using LoopVectorization
+using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, checknonsingular, BLAS, LinearAlgebra

-function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true);
-            check::Bool = true, blocksize::Integer = 16)
-    lu!(copy(A), pivot; check = check, blocksize = blocksize)
+function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); kwargs...)
+    return lu!(copy(A), pivot; kwargs...)
 end

-function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true);
-             check::Bool = true, blocksize::Integer = 16)
-    lu!(A, Vector{BlasInt}(undef, min(size(A)...)), pivot;
-        check = check, blocksize = blocksize)
+function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true); check=true, kwargs...)
+    m, n = size(A)
+    minmn = min(m, n)
+    F = if minmn < 10 # avx introduces small performance degradation
+        LinearAlgebra.generic_lufact!(A, pivot; check=check)
+    else
+        lu!(A, Vector{BlasInt}(undef, minmn), pivot; check=check, kwargs...)
+    end
+    return F
 end

+# Use a function here to make sure it gets optimized away
+# OpenBLAS' TRSM isn't very good, we use a higher threshold for recursion
+pick_threshold() = BLAS.vendor() === :mkl ? 48 : 192
+
 function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
              pivot::Union{Val{false}, Val{true}} = Val(true);
-             check::Bool=true, blocksize::Integer=16) where T
-    info = Ref(zero(BlasInt))
+             check::Bool=true,
+             # the performance is not sensitive wrt blocksize, and 16 is a good default
+             blocksize::Integer=16,
+             threshold::Integer=pick_threshold()) where T
+    info = zero(BlasInt)
     m, n = size(A)
     mnmin = min(m, n)
-    if T <: BlasFloat && A isa StridedArray
-        reckernel!(A, pivot, m, mnmin, ipiv, info, blocksize)
+    if A isa StridedArray && mnmin > threshold
+        info = reckernel!(A, pivot, m, mnmin, ipiv, info, blocksize)
         if m < n # fat matrix
             # [AL AR]
             AL = @view A[:, 1:m]
             AR = @view A[:, m+1:n]
             apply_permutation!(ipiv, AR)
             ldiv!(UnitLowerTriangular(AL), AR)
         end
-    else # generic fallback
-        _generic_lufact!(A, pivot, ipiv, info)
+    else # generic fallback
+        info = _generic_lufact!(A, pivot, ipiv, info)
     end
-    check && checknonsingular(info[])
-    LU{T, typeof(A)}(A, ipiv, info[])
+    check && checknonsingular(info)
+    LU{T, typeof(A)}(A, ipiv, info)
 end

 function nsplit(::Type{T}, n) where T
-    k = 128 ÷ sizeof(T)
+    k = 512 ÷ (isbitstype(T) ? sizeof(T) : 8)
     k_2 = k ÷ 2
     return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
 end
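The reworked entry points now choose between three code paths: very small matrices go straight to `LinearAlgebra.generic_lufact!`, mid-sized ones use the `@avx` `_generic_lufact!`, and only matrices above the vendor-dependent `threshold` enter the recursive kernel. A hedged restatement of that dispatch (the `which_path` helper is hypothetical, not in the package, and it ignores the `StridedArray` check):

using LinearAlgebra: BLAS

pick_threshold() = BLAS.vendor() === :mkl ? 48 : 192

which_path(minmn; threshold = pick_threshold()) =
    minmn < 10        ? :LinearAlgebra_generic :
    minmn ≤ threshold ? :avx_generic_lufact    : :recursive_kernel

which_path(8), which_path(100), which_path(500)  # with OpenBLAS: (:LinearAlgebra_generic, :avx_generic_lufact, :recursive_kernel)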
@@ -44,17 +56,19 @@ Base.@propagate_inbounds function apply_permutation!(P, A)
         i′ = P[i]
         i′ == i && continue
         @simd for j in axes(A, 2)
-            A[i, j], A[i′, j] = A[i′, j], A[i, j]
+            tmp = A[i, j]
+            A[i, j] = A[i′, j]
+            A[i′, j] = tmp
         end
     end
     nothing
 end

-function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize)::Nothing where {T,Pivot}
+function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize)::BlasInt where {T,Pivot}
     @inbounds begin
         if n <= max(blocksize, 1)
-            _generic_lufact!(A, pivot, ipiv, info)
-            return nothing
+            info = _generic_lufact!(A, pivot, ipiv, info)
+            return info
         end
         n1 = nsplit(T, n)
         n2 = n - n1
@@ -88,7 +102,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
         #   [ A11 ]   [ L11 ]
         # P [     ] = [     ] U11
         #   [ A21 ]   [ L21 ]
-        reckernel!(AL, pivot, m, n1, P1, info, blocksize)
+        info = reckernel!(AL, pivot, m, n1, P1, info, blocksize)
         # [ A12 ]    [ P1 ] [ A12 ]
         # [     ] <- [    ] [     ]
         # [ A22 ]    [ 0  ] [ A22 ]
@@ -98,22 +112,33 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
         # Schur complement:
         # We have A22 = L21 U12 + A′22, hence
         # A′22 = A22 - L21 U12
-        BLAS.gemm!('N', 'N', -one(T), A21, A12, one(T), A22)
+        #mul!(A22, A21, A12, -one(T), one(T))
+        schur_complement!(A22, A21, A12)
         # record info
-        previnfo = info[]
+        previnfo = info
         # P2 A22 = L22 U22
-        reckernel!(A22, pivot, m2, n2, P2, info, blocksize)
+        info = reckernel!(A22, pivot, m2, n2, P2, info, blocksize)
         # A21 <- P2 A21
         Pivot && apply_permutation!(P2, A21)

-        info[] != previnfo && (info[] += n1)
-        @simd for i in 1:n2
+        info != previnfo && (info += n1)
+        @avx for i in 1:n2
             P2[i] += n1
         end
-        return nothing
+        return info
     end # inbounds
 end

+function schur_complement!(𝐂, 𝐀, 𝐁)
+    @avx for m ∈ 1:size(𝐀,1), n ∈ 1:size(𝐁,2)
+        𝐂ₘₙ = zero(eltype(𝐂))
+        for k ∈ 1:size(𝐀,2)
+            𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
+        end
+        𝐂[m,n] = 𝐂ₘₙ + 𝐂[m,n]
+    end
+end
+
 #=
 Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
 License is MIT: https://julialang.org/license
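The new `schur_complement!` replaces the `BLAS.gemm!` rank-k update with a LoopVectorization `@avx` loop. As a rough sanity check (not part of the commit; variable names here are illustrative), it should agree with the commented-out `mul!(A22, A21, A12, -1, 1)` form up to roundoff:

using LinearAlgebra, LoopVectorization

# Same kernel as above, restated with ASCII names for a standalone check.
function schur_complement!(C, A, B)
    @avx for m ∈ 1:size(A, 1), n ∈ 1:size(B, 2)
        Cmn = zero(eltype(C))
        for k ∈ 1:size(A, 2)
            Cmn -= A[m, k] * B[k, n]
        end
        C[m, n] = Cmn + C[m, n]
    end
end

A21, A12 = rand(40, 8), rand(8, 40)
A22   = rand(40, 40)
C_ref = A22 .- A21 * A12          # reference: A22 - L21*U12
schur_complement!(A22, A21, A12)
A22 ≈ C_ref                       # should be true up to roundoff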
@@ -147,19 +172,19 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
                 end
                 # Scale first column
                 Akkinv = inv(A[k,k])
-                @simd for i = k+1:m
+                @avx for i = k+1:m
                     A[i,k] *= Akkinv
                 end
-            elseif info[] == 0
-                info[] = k
+            elseif info == 0
+                info = k
             end
             # Update the rest
-            for j = k+1:n
-                @simd for i = k+1:m
+            @avx for j = k+1:n
+                for i = k+1:m
                     A[i,j] -= A[i,k]*A[k,j]
                 end
             end
         end
     end
-    return nothing
+    return info
 end
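For orientation, a hedged usage sketch (not part of the commit): with these changes, pivoted LU on a large strided matrix goes through the recursive `@avx` kernel, and the new `threshold` keyword controls where recursion takes over.

using LinearAlgebra, RecursiveFactorization

A = rand(300, 300)
F = RecursiveFactorization.lu!(copy(A))                              # recursion kicks in above `threshold`
G = RecursiveFactorization.lu!(copy(A), Val(true); threshold = 64)   # assumption: force recursion earlier
norm(F.L * F.U - A[F.p, :], Inf)                                     # residual should be small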

test/runtests.jl

Lines changed: 17 additions & 14 deletions
@@ -11,24 +11,27 @@ const mylu = RecursiveFactorization.lu
 function testlu(A, MF, BF)
     @test MF.info == BF.info
-    @test norm(MF.L*MF.U - A[MF.p, :], Inf) < sqrt(eps(real(first(A))))
+    @test norm(MF.L*MF.U - A[MF.p, :], Inf) < 100sqrt(eps(real(one(float(first(A))))))
     nothing
 end

 @testset "Test LU factorization" begin
-    for p in (Val(true), Val(false)), T in (Float64, Float32, ComplexF64, ComplexF32, Real)
-        siz = (50, 100)
-        if isconcretetype(T)
-            A = rand(T, siz...)
-        else
-            _A = rand(50, 100)
-            A = Matrix{T}(undef, siz...)
-            copyto!(A, _A)
-        end
-        MF = mylu(A, p)
-        BF = baselu(A, p)
-        testlu(A, MF, BF)
-        for i in 50:7:100 # test `MF.info`
+    for _p in (true, false), T in (Float64, Float32, ComplexF64, ComplexF32, Real)
+        p = Val(_p)
+        for s in [1:10; 50:80:200; 300]
+            siz = (s, s+2)
+            @info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
+            if isconcretetype(T)
+                A = rand(T, siz...)
+            else
+                _A = rand(siz...)
+                A = Matrix{T}(undef, siz...)
+                copyto!(A, _A)
+            end
+            MF = mylu(A, p)
+            BF = baselu(A, p)
+            testlu(A, MF, BF)
+            i = rand(1:s) # test `MF.info`
             A[:, i] .= 0
             MF = mylu(A, p, check=false)
             BF = baselu(A, p, check=false)
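As a rough illustration of the loosened tolerance in `testlu` (a sketch under an assumed size and element type, not part of the test suite): the residual is compared against 100·sqrt(eps) of the real element type.

using LinearAlgebra, RecursiveFactorization

A  = rand(Float32, 60, 62)
MF = RecursiveFactorization.lu(A, Val(true))
norm(MF.L * MF.U - A[MF.p, :], Inf) < 100sqrt(eps(Float32))  # the bound used by testlu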
