Skip to content

Commit e378eba

Browse files
authored
Merge pull request #83 from JuliaLinearAlgebra/myb/nopiv
ldiv support for NotIPIV
2 parents afec32d + 74f4f35 commit e378eba

File tree

4 files changed

+122
-93
lines changed

4 files changed

+122
-93
lines changed

perf/lu.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,22 @@ else
5252
BLAS.vendor() === :mkl ? :MKL : :OpenBLAS
5353
end
5454
df = DataFrame(Size = ns,
55-
Reference = ref_mflops)
55+
Reference = ref_mflops)
5656
setproperty!(df, blaslib, bas_mflops)
5757
setproperty!(df, Symbol("RF with default threshold"), rec_mflops)
5858
setproperty!(df, Symbol("RF fully recursive"), rec4_mflops)
5959
setproperty!(df, Symbol("RF fully iterative"), rec800_mflops)
6060
df = stack(df,
61-
[Symbol("RF with default threshold"),
62-
Symbol("RF fully recursive"),
63-
Symbol("RF fully iterative"),
64-
blaslib,
65-
:Reference], variable_name = :Library, value_name = :GFLOPS)
61+
[Symbol("RF with default threshold"),
62+
Symbol("RF fully recursive"),
63+
Symbol("RF fully iterative"),
64+
blaslib,
65+
:Reference], variable_name = :Library, value_name = :GFLOPS)
6666
plt = df |> @vlplot(:line, color={:Library, scale = {scheme = "category10"}},
67-
x={:Size}, y={:GFLOPS},
68-
width=1000, height=600)
67+
x={:Size}, y={:GFLOPS},
68+
width=1000, height=600)
6969
save(joinpath(homedir(), "Pictures",
70-
"lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)
70+
"lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)
7171

7272
#=
7373
using Plot

src/RecursiveFactorization.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ include("./lu.jl")
77

88
import PrecompileTools
99

10-
PrecompileTools.@compile_workload begin lu!(rand(2, 2)) end
10+
PrecompileTools.@compile_workload begin
11+
lu!(rand(2, 2))
12+
end
1113

1214
end # module

src/lu.jl

Lines changed: 61 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using LoopVectorization
22
using TriangularSolve: ldiv!
33
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4-
LinearAlgebra, Adjoint, Transpose
4+
LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
55
using StrideArraysCore
66
using Polyester: @batch
77

@@ -41,16 +41,23 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
4141

4242
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!)
4343
function LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
44-
B::StridedVecOrMat)
44+
B::StridedVecOrMat)
4545
return B
4646
end
4747
end
4848
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!)
49-
function LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
50-
B::StridedVecOrMat)
49+
function LinearAlgebra._ipiv_rows!(::(LU{T, <:AbstractMatrix{T}, NotIPIV} where {T}),
50+
::OrdinalRange,
51+
B::StridedVecOrMat)
5152
return B
5253
end
5354
end
55+
if CUSTOMIZABLE_PIVOT
56+
function LinearAlgebra.ldiv!(A::LU{T, <:StridedMatrix, <:NotIPIV},
57+
B::StridedVecOrMat{T}) where {T <: BlasFloat}
58+
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), B))
59+
end
60+
end
5461

5562
function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
5663
m, n = size(A)
@@ -80,11 +87,11 @@ recurse(_) = false
8087
_ptrarray(ipiv) = PtrArray(ipiv)
8188
_ptrarray(ipiv::NotIPIV) = ipiv
8289
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
83-
pivot = Val(true), thread = Val(true);
84-
check::Bool = true,
85-
# the performance is not sensitive wrt blocksize, and 8 is a good default
86-
blocksize::Integer = length(A) ≥ 40_000 ? 8 : 16,
87-
threshold::Integer = pick_threshold()) where {T}
90+
pivot = Val(true), thread = Val(true);
91+
check::Bool = true,
92+
# the performance is not sensitive wrt blocksize, and 8 is a good default
93+
blocksize::Integer = length(A) ≥ 40_000 ? 8 : 16,
94+
threshold::Integer = pick_threshold()) where {T}
8895
pivot = normalize_pivot(pivot)
8996
info = zero(BlasInt)
9097
m, n = size(A)
@@ -94,10 +101,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
94101
end
95102
if recurse(A) && mnmin > threshold
96103
if T <: Union{Float32, Float64}
97-
GC.@preserve ipiv A begin info = recurse!(view(PtrArray(A), axes(A)...), pivot,
98-
m, n, mnmin,
99-
_ptrarray(ipiv), info, blocksize,
100-
thread) end
104+
GC.@preserve ipiv A begin
105+
info = recurse!(view(PtrArray(A), axes(A)...), pivot,
106+
m, n, mnmin,
107+
_ptrarray(ipiv), info, blocksize,
108+
thread)
109+
end
101110
else
102111
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
103112
end
@@ -109,7 +118,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
109118
end
110119

111120
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
112-
::Val{true}) where {Pivot}
121+
::Val{true}) where {Pivot}
113122
if length(A) * _sizeof(eltype(A)) >
114123
0.92 * LoopVectorization.VectorizationBase.cache_size(Val(2))
115124
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true))
@@ -118,11 +127,11 @@ end
118127
end
119128
end
120129
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
121-
::Val{false}) where {Pivot}
130+
::Val{false}) where {Pivot}
122131
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
123132
end
124133
@inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
125-
::Val{Thread}) where {Pivot, Thread}
134+
::Val{Thread}) where {Pivot, Thread}
126135
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::Int
127136
@inbounds if m < n # fat matrix
128137
# [AL AR]
@@ -166,7 +175,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
166175
nothing
167176
end
168177
function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize,
169-
thread)::BlasInt where {T, Pivot}
178+
thread)::BlasInt where {T, Pivot}
170179
@inbounds begin
171180
if n <= max(blocksize, 1)
172181
info = _generic_lufact!(A, Val(Pivot), ipiv, info)
@@ -262,44 +271,46 @@ end
262271
function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
263272
m, n = size(A)
264273
minmn = length(ipiv)
265-
@inbounds begin for k in 1:minmn
266-
# find index max
267-
kp = k
268-
if Pivot
269-
amax = abs(zero(eltype(A)))
270-
for i in k:m
271-
absi = abs(A[i, k])
272-
if absi > amax
273-
kp = i
274-
amax = absi
274+
@inbounds begin
275+
for k in 1:minmn
276+
# find index max
277+
kp = k
278+
if Pivot
279+
amax = abs(zero(eltype(A)))
280+
for i in k:m
281+
absi = abs(A[i, k])
282+
if absi > amax
283+
kp = i
284+
amax = absi
285+
end
275286
end
287+
ipiv[k] = kp
276288
end
277-
ipiv[k] = kp
278-
end
279-
if !iszero(A[kp, k])
280-
if k != kp
281-
# Interchange
282-
@simd for i in 1:n
283-
tmp = A[k, i]
284-
A[k, i] = A[kp, i]
285-
A[kp, i] = tmp
289+
if !iszero(A[kp, k])
290+
if k != kp
291+
# Interchange
292+
@simd for i in 1:n
293+
tmp = A[k, i]
294+
A[k, i] = A[kp, i]
295+
A[kp, i] = tmp
296+
end
286297
end
298+
# Scale first column
299+
Akkinv = inv(A[k, k])
300+
@turbo check_empty=true warn_check_args=false for i in (k + 1):m
301+
A[i, k] *= Akkinv
302+
end
303+
elseif info == 0
304+
info = k
287305
end
288-
# Scale first column
289-
Akkinv = inv(A[k, k])
290-
@turbo check_empty=true warn_check_args=false for i in (k + 1):m
291-
A[i, k] *= Akkinv
292-
end
293-
elseif info == 0
294-
info = k
295-
end
296-
k == minmn && break
297-
# Update the rest
298-
@turbo warn_check_args=false for j in (k + 1):n
299-
for i in (k + 1):m
300-
A[i, j] -= A[i, k] * A[k, j]
306+
k == minmn && break
307+
# Update the rest
308+
@turbo warn_check_args=false for j in (k + 1):n
309+
for i in (k + 1):m
310+
A[i, j] -= A[i, k] * A[k, j]
311+
end
301312
end
302313
end
303-
end end
314+
end
304315
return info
305316
end

test/runtests.jl

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,66 @@
11
using Test
22
import RecursiveFactorization
33
import LinearAlgebra
4-
using LinearAlgebra: norm, Adjoint, Transpose
4+
using LinearAlgebra: norm, Adjoint, Transpose, ldiv!
55
using Random
66

77
Random.seed!(12)
88

99
const baselu = LinearAlgebra.lu
1010
const mylu = RecursiveFactorization.lu
1111

12-
function testlu(A, MF, BF)
12+
function testlu(A, MF, BF, p)
1313
@test MF.info == BF.info
14-
@test norm(MF.L * MF.U - A[MF.p, :], Inf) < 200sqrt(eps(real(one(float(first(A))))))
14+
if !iszero(MF.info)
15+
return nothing
16+
end
17+
E = 20size(A, 1) * eps(real(one(float(first(A)))))
18+
@test norm(MF.L * MF.U - A[MF.p, :], Inf) < (p ? E : 10sqrt(E))
19+
if ==(size(A)...)
20+
b = ldiv!(MF, A[:, end])
21+
if all(isfinite, b)
22+
n = size(A, 2)
23+
rhs = [i == n for i in 1:n]
24+
@test b ≈ rhs atol=p ? 100E : 100sqrt(E)
25+
end
26+
end
1527
nothing
1628
end
17-
testlu(A::Union{Transpose, Adjoint}, MF, BF) = testlu(parent(A), parent(MF), BF)
29+
testlu(A::Union{Transpose, Adjoint}, MF, BF, p) = testlu(parent(A), parent(MF), BF, p)
1830

19-
@testset "Test LU factorization" begin for _p in (true, false),
20-
T in (Float64, Float32, ComplexF64, ComplexF32,
21-
Real)
31+
@testset "Test LU factorization" begin
32+
for _p in (true, false),
33+
T in (Float64, Float32, ComplexF64, ComplexF32,
34+
Real)
2235

23-
p = Val(_p)
24-
for (i, s) in enumerate([1:10; 50:80:200; 300])
25-
iseven(i) && (p = RecursiveFactorization.to_stdlib_pivot(p))
26-
siz = (s, s + 2)
27-
@info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
28-
if isconcretetype(T)
29-
A = rand(T, siz...)
30-
else
31-
_A = rand(siz...)
32-
A = Matrix{T}(undef, siz...)
33-
copyto!(A, _A)
36+
p = Val(_p)
37+
for (i, s) in enumerate([1:10; 50:80:200; 300])
38+
iseven(i) && (p = RecursiveFactorization.to_stdlib_pivot(p))
39+
for m in (s, s + 2)
40+
siz = (s, m)
41+
@info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
42+
if isconcretetype(T)
43+
A = rand(T, siz...)
44+
else
45+
_A = rand(siz...)
46+
A = Matrix{T}(undef, siz...)
47+
copyto!(A, _A)
48+
end
49+
MF = mylu(A, p)
50+
BF = baselu(A, p)
51+
testlu(A, MF, BF, _p)
52+
testlu(A, mylu(A, p, Val(false)), BF, false)
53+
A′ = permutedims(A)
54+
MF′ = mylu(A′', p)
55+
testlu(A′', MF′, BF, _p)
56+
testlu(A′', mylu(A′', p, Val(false)), BF, false)
57+
i = rand(1:s) # test `MF.info`
58+
A[:, i] .= 0
59+
MF = mylu(A, p, check = false)
60+
BF = baselu(A, p, check = false)
61+
testlu(A, MF, BF, _p)
62+
testlu(A, mylu(A, p, Val(false), check = false), BF, false)
63+
end
3464
end
35-
MF = mylu(A, p)
36-
BF = baselu(A, p)
37-
testlu(A, MF, BF)
38-
testlu(A, mylu(A, p, Val(false)), BF)
39-
A′ = permutedims(A)
40-
MF′ = mylu(A′', p)
41-
testlu(A′', MF′, BF)
42-
testlu(A′', mylu(A′', p, Val(false)), BF)
43-
i = rand(1:s) # test `MF.info`
44-
A[:, i] .= 0
45-
MF = mylu(A, p, check = false)
46-
BF = baselu(A, p, check = false)
47-
testlu(A, MF, BF)
48-
testlu(A, mylu(A, p, Val(false), check = false), BF)
4965
end
50-
end end
66+
end

0 commit comments

Comments
 (0)