Commit 08a3dfa

Use avx micro
1 parent ead6809 commit 08a3dfa

4 files changed, +70 -37 lines changed

Project.toml

Lines changed: 2 additions & 0 deletions

@@ -5,8 +5,10 @@ version = "0.1.0"
 
 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 
 [compat]
+LoopVectorization = "0.7"
 julia = "1"
 
 [extras]
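
For orientation, the new LoopVectorization dependency provides the @avx macro, which rewrites an annotated loop nest into SIMD code. A minimal, self-contained sketch (the dot_avx name and the array sizes are illustrative, not part of this package):

    using LoopVectorization: @avx

    # Illustrative helper: a dot product whose loop is vectorized by @avx.
    function dot_avx(x::AbstractVector, y::AbstractVector)
        s = zero(eltype(x))
        @avx for i in 1:length(x)
            s += x[i] * y[i]  # reduction recognized and vectorized by @avx
        end
        return s
    end

    dot_avx(rand(128), rand(128))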

perf/lu.jl

Lines changed: 23 additions & 6 deletions

@@ -1,25 +1,42 @@
 using BenchmarkTools
 import LinearAlgebra, RecursiveFactorization
 
-BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5
+BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.08
 
 luflop(m, n) = n^3÷3 - n÷3 + m*n^2
 luflop(n) = luflop(n, n)
 
 bas_mflops = Float64[]
-rec_mflops = Float64[]
-ns = 50:50:800
+rec8_mflops = Float64[]
+rec16_mflops = Float64[]
+rec32_mflops = Float64[]
+ref_mflops = Float64[]
+ns = 4:32:500
 for n in ns
+    @info "$n × $n"
     A = rand(n, n)
     bt = @belapsed LinearAlgebra.lu!($(copy(A)))
-    rt = @belapsed RecursiveFactorization.lu!($(copy(A)))
     push!(bas_mflops, luflop(n)/bt/1e9)
-    push!(rec_mflops, luflop(n)/rt/1e9)
+
+    rt8 = @belapsed RecursiveFactorization.lu!($(copy(A)); blocksize=8)
+    push!(rec8_mflops, luflop(n)/rt8/1e9)
+
+    rt16 = @belapsed RecursiveFactorization.lu!($(copy(A)); blocksize=16)
+    push!(rec16_mflops, luflop(n)/rt16/1e9)
+
+    rt32 = @belapsed RecursiveFactorization.lu!($(copy(A)); blocksize=32)
+    push!(rec32_mflops, luflop(n)/rt32/1e9)
+
+    ref = @belapsed LinearAlgebra.generic_lufact!($(copy(A)))
+    push!(ref_mflops, luflop(n)/ref/1e9)
 end
 
 using Plots
 plt = plot(ns, bas_mflops, legend=:bottomright, lab="OpenBLAS", title="LU Factorization Benchmark", marker=:auto, dpi=150)
-plot!(plt, ns, rec_mflops, lab="RecursiveFactorization", marker=:auto)
+plot!(plt, ns, rec8_mflops, lab="RF8", marker=:auto)
+plot!(plt, ns, rec16_mflops, lab="RF16", marker=:auto)
+plot!(plt, ns, rec32_mflops, lab="RF32", marker=:auto)
+plot!(plt, ns, ref_mflops, lab="Reference", marker=:auto)
 xaxis!(plt, "size (N x N)")
 yaxis!(plt, "GFLOPS")
 savefig("lubench.png")
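
For reference, the script turns each timing into a rate as luflop(n) / time / 1e9, which is what the GFLOPS y-axis shows. A small worked example (the elapsed time below is illustrative):

    luflop(m, n) = n^3÷3 - n÷3 + m*n^2   # flop count for LU of an m×n matrix
    luflop(n) = luflop(n, n)

    t = 2.5e-3                            # illustrative elapsed seconds for n = 500
    gflops = luflop(500) / t / 1e9        # ≈ 66.7 GFLOPS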

src/lu.jl

Lines changed: 28 additions & 17 deletions

@@ -1,23 +1,21 @@
-using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular
+using LoopVectorization: @avx
+using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, mul!, checknonsingular
 
-function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true);
-            check::Bool = true, blocksize::Integer = 16)
-    lu!(copy(A), pivot; check = check, blocksize = blocksize)
+function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); kwargs...)
+    lu!(copy(A), pivot; kwargs...)
 end
 
-function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true);
-             check::Bool = true, blocksize::Integer = 16)
-    lu!(A, Vector{BlasInt}(undef, min(size(A)...)), pivot;
-        check = check, blocksize = blocksize)
+function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true); kwargs...)
+    lu!(A, Vector{BlasInt}(undef, min(size(A)...)), pivot; kwargs...)
 end
 
 function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
              pivot::Union{Val{false}, Val{true}} = Val(true);
-             check::Bool=true, blocksize::Integer=16) where T
+             check::Bool=true, blocksize::Integer=16, threshold::Integer=192) where T
     info = Ref(zero(BlasInt))
     m, n = size(A)
     mnmin = min(m, n)
-    if T <: BlasFloat && A isa StridedArray
+    if A isa StridedArray && mnmin > threshold
         reckernel!(A, pivot, m, mnmin, ipiv, info, blocksize)
         if m < n # fat matrix
             # [AL AR]

@@ -34,7 +32,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
 end
 
 function nsplit(::Type{T}, n) where T
-    k = 128 ÷ sizeof(T)
+    k = 512 ÷ (isbitstype(T) ? sizeof(T) : 8)
     k_2 = k ÷ 2
     return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
 end

@@ -44,7 +42,9 @@ Base.@propagate_inbounds function apply_permutation!(P, A)
         i′ = P[i]
         i′ == i && continue
         @simd for j in axes(A, 2)
-            A[i, j], A[i′, j] = A[i′, j], A[i, j]
+            tmp = A[i, j]
+            A[i, j] = A[i′, j]
+            A[i′, j] = tmp
         end
     end
     nothing

@@ -98,7 +98,8 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
         # Schur complement:
         # We have A22 = L21 U12 + A′22, hence
         # A′22 = A22 - L21 U12
-        BLAS.gemm!('N', 'N', -one(T), A21, A12, one(T), A22)
+        #mul!(A22, A21, A12, -one(T), one(T))
+        schur_complement!(A22, A21, A12)
         # record info
         previnfo = info[]
         # P2 A22 = L22 U22

@@ -107,13 +108,23 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
         Pivot && apply_permutation!(P2, A21)
 
         info[] != previnfo && (info[] += n1)
-        @simd for i in 1:n2
+        @avx for i in 1:n2
             P2[i] += n1
         end
         return nothing
     end # inbounds
 end
 
+function schur_complement!(𝐂, 𝐀, 𝐁)
+    @avx for m ∈ 1:size(𝐀,1), n ∈ 1:size(𝐁,2)
+        𝐂ₘₙ = zero(eltype(𝐂))
+        for k ∈ 1:size(𝐀,2)
+            𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
+        end
+        𝐂[m,n] = 𝐂ₘₙ + 𝐂[m,n]
+    end
+end
+
 #=
 Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
 License is MIT: https://julialang.org/license

@@ -147,15 +158,15 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
             end
             # Scale first column
             Akkinv = inv(A[k,k])
-            @simd for i = k+1:m
+            @avx for i = k+1:m
                 A[i,k] *= Akkinv
             end
         elseif info[] == 0
             info[] = k
         end
         # Update the rest
-        for j = k+1:n
-            @simd for i = k+1:m
+        @avx for j = k+1:n
+            for i = k+1:m
                 A[i,j] -= A[i,k]*A[k,j]
             end
         end
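
The new schur_complement! kernel updates C in place as C ← C − A·B, the same update the replaced BLAS.gemm! call (and the commented-out five-argument mul!) performed. A quick equivalence check, assuming the unexported helper is reachable as RecursiveFactorization.schur_complement! (matrix sizes are illustrative):

    using LinearAlgebra: mul!
    import RecursiveFactorization

    A21 = rand(6, 4); A12 = rand(4, 5)
    C1 = rand(6, 5); C2 = copy(C1)

    RecursiveFactorization.schur_complement!(C1, A21, A12)  # C1 .-= A21 * A12
    mul!(C2, A21, A12, -1.0, 1.0)                           # same update via 5-arg mul!
    maximum(abs, C1 .- C2)                                   # expected ≈ 0 up to roundoff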

test/runtests.jl

Lines changed: 17 additions & 14 deletions

@@ -11,24 +11,27 @@ const mylu = RecursiveFactorization.lu
 
 function testlu(A, MF, BF)
     @test MF.info == BF.info
-    @test norm(MF.L*MF.U - A[MF.p, :], Inf) < sqrt(eps(real(first(A))))
+    @test norm(MF.L*MF.U - A[MF.p, :], Inf) < 100sqrt(eps(real(one(float(first(A))))))
     nothing
 end
 
 @testset "Test LU factorization" begin
-    for p in (Val(true), Val(false)), T in (Float64, Float32, ComplexF64, ComplexF32, Real)
-        siz = (50, 100)
-        if isconcretetype(T)
-            A = rand(T, siz...)
-        else
-            _A = rand(50, 100)
-            A = Matrix{T}(undef, siz...)
-            copyto!(A, _A)
-        end
-        MF = mylu(A, p)
-        BF = baselu(A, p)
-        testlu(A, MF, BF)
-        for i in 50:7:100 # test `MF.info`
+    for _p in (true, false), T in (Float64, Float32, ComplexF64, ComplexF32, Real)
+        p = Val(_p)
+        for s in [1:10; 50:80:200; 300]
+            siz = (s, s+2)
+            @info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
+            if isconcretetype(T)
+                A = rand(T, siz...)
+            else
+                _A = rand(siz...)
+                A = Matrix{T}(undef, siz...)
+                copyto!(A, _A)
+            end
+            MF = mylu(A, p)
+            BF = baselu(A, p)
+            testlu(A, MF, BF)
+            i = rand(1:s) # test `MF.info`
             A[:, i] .= 0
             MF = mylu(A, p, check=false)
             BF = baselu(A, p, check=false)