Skip to content

Commit 45d4af5

Browse files
committed
Format and add ldiv support for NotIPIV
1 parent afec32d commit 45d4af5

File tree

4 files changed

+102
-88
lines changed

4 files changed

+102
-88
lines changed

perf/lu.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,22 @@ else
5252
BLAS.vendor() === :mkl ? :MKL : :OpenBLAS
5353
end
5454
df = DataFrame(Size = ns,
55-
Reference = ref_mflops)
55+
Reference = ref_mflops)
5656
setproperty!(df, blaslib, bas_mflops)
5757
setproperty!(df, Symbol("RF with default threshold"), rec_mflops)
5858
setproperty!(df, Symbol("RF fully recursive"), rec4_mflops)
5959
setproperty!(df, Symbol("RF fully iterative"), rec800_mflops)
6060
df = stack(df,
61-
[Symbol("RF with default threshold"),
62-
Symbol("RF fully recursive"),
63-
Symbol("RF fully iterative"),
64-
blaslib,
65-
:Reference], variable_name = :Library, value_name = :GFLOPS)
61+
[Symbol("RF with default threshold"),
62+
Symbol("RF fully recursive"),
63+
Symbol("RF fully iterative"),
64+
blaslib,
65+
:Reference], variable_name = :Library, value_name = :GFLOPS)
6666
plt = df |> @vlplot(:line, color={:Library, scale = {scheme = "category10"}},
67-
x={:Size}, y={:GFLOPS},
68-
width=1000, height=600)
67+
x={:Size}, y={:GFLOPS},
68+
width=1000, height=600)
6969
save(joinpath(homedir(), "Pictures",
70-
"lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)
70+
"lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)
7171

7272
#=
7373
using Plot

src/RecursiveFactorization.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ include("./lu.jl")
77

88
import PrecompileTools
99

10-
PrecompileTools.@compile_workload begin lu!(rand(2, 2)) end
10+
PrecompileTools.@compile_workload begin
11+
lu!(rand(2, 2))
12+
end
1113

1214
end # module

src/lu.jl

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using LoopVectorization
22
using TriangularSolve: ldiv!
33
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4-
LinearAlgebra, Adjoint, Transpose
4+
LinearAlgebra, Adjoint, Transpose, UpperTriangular
55
using StrideArraysCore
66
using Polyester: @batch
77

@@ -41,16 +41,22 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
4141

4242
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!)
4343
function LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
44-
B::StridedVecOrMat)
44+
B::StridedVecOrMat)
4545
return B
4646
end
4747
end
4848
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!)
4949
function LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
50-
B::StridedVecOrMat)
50+
B::StridedVecOrMat)
5151
return B
5252
end
5353
end
54+
if CUSTOMIZABLE_PIVOT
55+
function LinearAlgebra.ldiv!(A::LU{T, <:StridedMatrix, <:NotIPIV},
56+
B::StridedVecOrMat{T}) where {T <: BlasFloat}
57+
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), B))
58+
end
59+
end
5460

5561
function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
5662
m, n = size(A)
@@ -80,11 +86,11 @@ recurse(_) = false
8086
_ptrarray(ipiv) = PtrArray(ipiv)
8187
_ptrarray(ipiv::NotIPIV) = ipiv
8288
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
83-
pivot = Val(true), thread = Val(true);
84-
check::Bool = true,
85-
# the performance is not sensitive wrt blocksize, and 8 is a good default
86-
blocksize::Integer = length(A) 40_000 ? 8 : 16,
87-
threshold::Integer = pick_threshold()) where {T}
89+
pivot = Val(true), thread = Val(true);
90+
check::Bool = true,
91+
# the performance is not sensitive wrt blocksize, and 8 is a good default
92+
blocksize::Integer = length(A) 40_000 ? 8 : 16,
93+
threshold::Integer = pick_threshold()) where {T}
8894
pivot = normalize_pivot(pivot)
8995
info = zero(BlasInt)
9096
m, n = size(A)
@@ -94,10 +100,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
94100
end
95101
if recurse(A) && mnmin > threshold
96102
if T <: Union{Float32, Float64}
97-
GC.@preserve ipiv A begin info = recurse!(view(PtrArray(A), axes(A)...), pivot,
98-
m, n, mnmin,
99-
_ptrarray(ipiv), info, blocksize,
100-
thread) end
103+
GC.@preserve ipiv A begin
104+
info = recurse!(view(PtrArray(A), axes(A)...), pivot,
105+
m, n, mnmin,
106+
_ptrarray(ipiv), info, blocksize,
107+
thread)
108+
end
101109
else
102110
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
103111
end
@@ -109,7 +117,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
109117
end
110118

111119
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
112-
::Val{true}) where {Pivot}
120+
::Val{true}) where {Pivot}
113121
if length(A) * _sizeof(eltype(A)) >
114122
0.92 * LoopVectorization.VectorizationBase.cache_size(Val(2))
115123
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true))
@@ -118,11 +126,11 @@ end
118126
end
119127
end
120128
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
121-
::Val{false}) where {Pivot}
129+
::Val{false}) where {Pivot}
122130
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
123131
end
124132
@inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
125-
::Val{Thread}) where {Pivot, Thread}
133+
::Val{Thread}) where {Pivot, Thread}
126134
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::Int
127135
@inbounds if m < n # fat matrix
128136
# [AL AR]
@@ -166,7 +174,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
166174
nothing
167175
end
168176
function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize,
169-
thread)::BlasInt where {T, Pivot}
177+
thread)::BlasInt where {T, Pivot}
170178
@inbounds begin
171179
if n <= max(blocksize, 1)
172180
info = _generic_lufact!(A, Val(Pivot), ipiv, info)
@@ -262,44 +270,46 @@ end
262270
function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
263271
m, n = size(A)
264272
minmn = length(ipiv)
265-
@inbounds begin for k in 1:minmn
266-
# find index max
267-
kp = k
268-
if Pivot
269-
amax = abs(zero(eltype(A)))
270-
for i in k:m
271-
absi = abs(A[i, k])
272-
if absi > amax
273-
kp = i
274-
amax = absi
273+
@inbounds begin
274+
for k in 1:minmn
275+
# find index max
276+
kp = k
277+
if Pivot
278+
amax = abs(zero(eltype(A)))
279+
for i in k:m
280+
absi = abs(A[i, k])
281+
if absi > amax
282+
kp = i
283+
amax = absi
284+
end
275285
end
286+
ipiv[k] = kp
276287
end
277-
ipiv[k] = kp
278-
end
279-
if !iszero(A[kp, k])
280-
if k != kp
281-
# Interchange
282-
@simd for i in 1:n
283-
tmp = A[k, i]
284-
A[k, i] = A[kp, i]
285-
A[kp, i] = tmp
288+
if !iszero(A[kp, k])
289+
if k != kp
290+
# Interchange
291+
@simd for i in 1:n
292+
tmp = A[k, i]
293+
A[k, i] = A[kp, i]
294+
A[kp, i] = tmp
295+
end
286296
end
297+
# Scale first column
298+
Akkinv = inv(A[k, k])
299+
@turbo check_empty=true warn_check_args=false for i in (k + 1):m
300+
A[i, k] *= Akkinv
301+
end
302+
elseif info == 0
303+
info = k
287304
end
288-
# Scale first column
289-
Akkinv = inv(A[k, k])
290-
@turbo check_empty=true warn_check_args=false for i in (k + 1):m
291-
A[i, k] *= Akkinv
292-
end
293-
elseif info == 0
294-
info = k
295-
end
296-
k == minmn && break
297-
# Update the rest
298-
@turbo warn_check_args=false for j in (k + 1):n
299-
for i in (k + 1):m
300-
A[i, j] -= A[i, k] * A[k, j]
305+
k == minmn && break
306+
# Update the rest
307+
@turbo warn_check_args=false for j in (k + 1):n
308+
for i in (k + 1):m
309+
A[i, j] -= A[i, k] * A[k, j]
310+
end
301311
end
302312
end
303-
end end
313+
end
304314
return info
305315
end

test/runtests.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,37 @@ function testlu(A, MF, BF)
1616
end
1717
testlu(A::Union{Transpose, Adjoint}, MF, BF) = testlu(parent(A), parent(MF), BF)
1818

19-
@testset "Test LU factorization" begin for _p in (true, false),
20-
T in (Float64, Float32, ComplexF64, ComplexF32,
21-
Real)
19+
@testset "Test LU factorization" begin
20+
for _p in (true, false),
21+
T in (Float64, Float32, ComplexF64, ComplexF32,
22+
Real)
2223

23-
p = Val(_p)
24-
for (i, s) in enumerate([1:10; 50:80:200; 300])
25-
iseven(i) && (p = RecursiveFactorization.to_stdlib_pivot(p))
26-
siz = (s, s + 2)
27-
@info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
28-
if isconcretetype(T)
29-
A = rand(T, siz...)
30-
else
31-
_A = rand(siz...)
32-
A = Matrix{T}(undef, siz...)
33-
copyto!(A, _A)
24+
p = Val(_p)
25+
for (i, s) in enumerate([1:10; 50:80:200; 300])
26+
iseven(i) && (p = RecursiveFactorization.to_stdlib_pivot(p))
27+
siz = (s, s + 2)
28+
@info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
29+
if isconcretetype(T)
30+
A = rand(T, siz...)
31+
else
32+
_A = rand(siz...)
33+
A = Matrix{T}(undef, siz...)
34+
copyto!(A, _A)
35+
end
36+
MF = mylu(A, p)
37+
BF = baselu(A, p)
38+
testlu(A, MF, BF)
39+
testlu(A, mylu(A, p, Val(false)), BF)
40+
A′ = permutedims(A)
41+
MF′ = mylu(A′', p)
42+
testlu(A′', MF′, BF)
43+
testlu(A′', mylu(A′', p, Val(false)), BF)
44+
i = rand(1:s) # test `MF.info`
45+
A[:, i] .= 0
46+
MF = mylu(A, p, check = false)
47+
BF = baselu(A, p, check = false)
48+
testlu(A, MF, BF)
49+
testlu(A, mylu(A, p, Val(false), check = false), BF)
3450
end
35-
MF = mylu(A, p)
36-
BF = baselu(A, p)
37-
testlu(A, MF, BF)
38-
testlu(A, mylu(A, p, Val(false)), BF)
39-
A′ = permutedims(A)
40-
MF′ = mylu(A′', p)
41-
testlu(A′', MF′, BF)
42-
testlu(A′', mylu(A′', p, Val(false)), BF)
43-
i = rand(1:s) # test `MF.info`
44-
A[:, i] .= 0
45-
MF = mylu(A, p, check = false)
46-
BF = baselu(A, p, check = false)
47-
testlu(A, MF, BF)
48-
testlu(A, mylu(A, p, Val(false), check = false), BF)
4951
end
50-
end end
52+
end

0 commit comments

Comments
 (0)