Skip to content
Open
8 changes: 5 additions & 3 deletions benchmarks/lu.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using BenchmarkTools, Random, VectorizationBase
using LinearAlgebra, LinearSolve, MKL_jll
using RecursiveFactorization

nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads())
BLAS.set_num_threads(nc)
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5

function luflop(m, n = m; innerflop = 2)
Expand All @@ -24,10 +25,10 @@ algs = [
RFLUFactorization(),
MKLLUFactorization(),
FastLUFactorization(),
SimpleLUFactorization()
SimpleLUFactorization(),
ButterflyFactorization(Val(true))
]
res = [Float64[] for i in 1:length(algs)]

ns = 4:8:500
for i in 1:length(ns)
n = ns[i]
Expand Down Expand Up @@ -65,3 +66,4 @@ p

savefig("lubench.png")
savefig("lubench.pdf")

31 changes: 29 additions & 2 deletions ext/LinearSolveRecursiveFactorizationExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LinearSolveRecursiveFactorizationExt

using LinearSolve: LinearSolve, userecursivefactorization, LinearCache, @get_cacheval,
RFLUFactorization, RF32MixedLUFactorization, default_alias_A,
RFLUFactorization, ButterflyFactorization, RF32MixedLUFactorization, default_alias_A,
default_alias_b
using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization
using SciMLBase: SciMLBase, ReturnCode
Expand All @@ -19,7 +19,6 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization
end
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
cache.cacheval = (fact, ipiv)

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
Expand Down Expand Up @@ -105,4 +104,32 @@ function SciMLBase.solve!(
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
end

# Solve the linear system held in `cache` with `ButterflyFactorization`.
#
# When `cache.isfresh` is set, a workspace is built through
# `RecursiveFactorization.🦋workspace`, stored in `cache.cacheval` as
# `(A, B, U, V, F)`, and `cache.isfresh` is cleared. The solution is then formed as
# `V * (F \ (U * b))` and its first `M` entries are copied into `out`, which is
# handed to `SciMLBase.build_linear_solution`.
#
# NOTE(review): `out` is only assigned inside the `cache.isfresh` branch but is used
# after it; a call with `cache.isfresh == false` looks like it would reach an
# unassigned `out` — confirm.
# NOTE(review): `b` is re-read from `cache.b` on every call but is padded only inside
# the `cache.isfresh` branch — confirm the non-fresh path sees a compatible length.
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::ButterflyFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
b = cache.b
M, N = size(A)
# Slots 2-4 of the cached tuple are handed to the workspace builder below as B, U, V.
B, U, V = cache.cacheval[2], cache.cacheval[3], cache.cacheval[4]
if cache.isfresh
# NOTE(review): `@assert` may be disabled at higher optimization levels; an explicit
# check that throws would be the usual choice for validating input.
@assert M==N "A must be square"
U, V, F, out = RecursiveFactorization.🦋workspace(A, b, B, U, V, alg.thread)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this not just a struct and ! operation? It would be much easier to read. I assume this is all just in-place and non-allocating.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as in make U, V, F, out a struct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

cache.cacheval = (A, B, U, V, F)
cache.isfresh = false
# Appends `4 - M % 4` random entries to `b` when `M` is not a multiple of 4
# (presumably to match a workspace padded to a multiple of 4 — TODO confirm).
if (M % 4 != 0)
b = [b; rand(4 - M % 4)]
end
end
A, B, U, V, F = cache.cacheval
# Apply U, solve with the factorization F, then apply V.
sol = V * (F \ (U * b))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is allocating and not using TriangularSolve.jl?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we found that TriangularSolve.jl was slower than this method, so we left it as is

# Only the first M entries of `sol` are copied into `out`.
out .= @view sol[1:M]
SciMLBase.build_linear_solution(alg, out, nothing, cache)
end

# Build the initial cache tuple for `ButterflyFactorization`. The result has five
# slots: `A`, `A`, `A'`, `A`, and an LU factorization of a throwaway 1x1 random
# matrix produced by `RecursiveFactorization.lu!` with the algorithm's `thread` flag.
function LinearSolve.init_cacheval(alg::ButterflyFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::LinearSolve.OperatorAssumptions)
    adjoint_slot = A'
    placeholder_lu = RecursiveFactorization.lu!(rand(1, 1), alg.thread)
    return (A, A, adjoint_slot, A, placeholder_lu)
end

end

4 changes: 2 additions & 2 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ for kralg in (Krylov.lsmr!, Krylov.craigmr!)
end
for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
:GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
:RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
:RFLUFactorization, :ButterflyFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
:DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
:CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
:MKLLUFactorization, :MetalLUFactorization, :CUSOLVERRFFactorization)
Expand Down Expand Up @@ -464,7 +464,7 @@ cudss_loaded(A) = false
is_cusparse(A) = false

export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, ButterflyFactorization,
NormalCholeskyFactorization, NormalBunchKaufmanFactorization,
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
SparspakFactorization, DiagonalFactorization, CholeskyFactorization,
Expand Down
23 changes: 23 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,29 @@ function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror =
RFLUFactorization(pivot, thread; throwerror)
end

"""
`ButterflyFactorization()`

A fast pure Julia LU-factorization implementation
using RecursiveFactorization.jl. This method utilizes a butterfly
factorization approach rather than pivoting.
"""
struct ButterflyFactorization{T} <: AbstractDenseFactorization
    # Flag carried in the type domain; the solver passes it on to
    # RecursiveFactorization.jl as `alg.thread`.
    thread::Val{T}

    # Inner constructor: takes the `thread` flag as a `Val`.
    #
    # If `userecursivefactorization(nothing)` is false and `throwerror` is true, an
    # error asks the user to load RecursiveFactorization.jl; with
    # `throwerror = false` the object is built regardless.
    function ButterflyFactorization(::Val{T}; throwerror = true) where {T}
        if !userecursivefactorization(nothing)
            throwerror &&
                error("ButterflyFactorization requires that RecursiveFactorization.jl is loaded, i.e. `using RecursiveFactorization`")
        end
        # Initialize `thread` explicitly rather than relying on incomplete
        # initialization via a zero-argument `new`.
        return new{T}(Val{T}())
    end
end

# Keyword-only convenience constructor: forwards `thread` (default `Val(true)`) and
# `throwerror` (default `true`) to the positional `Val` constructor.
function ButterflyFactorization(; thread = Val(true), throwerror = true)
    return ButterflyFactorization(thread; throwerror = throwerror)
end


# There are no options like pivot here.
# But I'm not sure it makes sense as a GenericFactorization
# since it just uses `LAPACK.getrf!`.
Expand Down
35 changes: 35 additions & 0 deletions test/butterfly.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using LinearAlgebra, LinearSolve
using Test
using RecursiveFactorization

# Residual check on dense random square systems of sizes 490 through 510.
@testset "Random Matricies" begin
    for dim in 490:510
        mat = rand(dim, dim)
        rhs = rand(dim)
        problem = LinearProblem(mat, rhs)
        sol = solve(problem, ButterflyFactorization())
        @test norm(mat * sol .- rhs) <= 1e-4
    end
end

# Build the N×N test matrix with ones on the diagonal, ones in the last column,
# `-1` everywhere strictly below the diagonal, and zeros elsewhere.
function wilkinson(N)
    W = zeros(N, N)
    # Strictly lower-triangular part.
    for col in 1:(N - 1), row in (col + 1):N
        W[row, col] = -1
    end
    # Main diagonal.
    for k in 1:N
        W[k, k] = 1
    end
    # Last column.
    W[:, N] .= 1
    return W
end

# Residual check on matrices built by `wilkinson` for sizes 790 through 810.
@testset "Wilkinson" begin
    for dim in 790:810
        mat = wilkinson(dim)
        rhs = rand(dim)
        problem = LinearProblem(mat, rhs)
        sol = solve(problem, ButterflyFactorization())
        @test norm(mat * sol .- rhs) <= 1e-10
    end
end
Loading