8 changes: 5 additions & 3 deletions benchmarks/lu.jl
@@ -1,7 +1,8 @@
using BenchmarkTools, Random, VectorizationBase
using LinearAlgebra, LinearSolve, MKL_jll
using RecursiveFactorization

nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads())
BLAS.set_num_threads(nc)
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5

function luflop(m, n = m; innerflop = 2)
@@ -24,10 +25,10 @@ algs = [
RFLUFactorization(),
MKLLUFactorization(),
FastLUFactorization(),
-SimpleLUFactorization()
+SimpleLUFactorization(),
+ButterflyFactorization()
]
res = [Float64[] for i in 1:length(algs)]

ns = 4:8:500
for i in 1:length(ns)
n = ns[i]
@@ -65,3 +66,4 @@ p

savefig("lubench.png")
savefig("lubench.pdf")

102 changes: 101 additions & 1 deletion ext/LinearSolveRecursiveFactorizationExt.jl
@@ -19,7 +19,6 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization
end
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
cache.cacheval = (fact, ipiv)

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
@@ -105,4 +104,105 @@ function SciMLBase.solve!(
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
end

# Mixed precision RecursiveFactorization implementation
LinearSolve.default_alias_A(::RF32MixedLUFactorization, ::Any, ::Any) = false
LinearSolve.default_alias_b(::RF32MixedLUFactorization, ::Any, ::Any) = false

const PREALLOCATED_RF32_LU = begin
A = rand(Float32, 0, 0)
luinst = ArrayInterface.lu_instance(A)
(luinst, Vector{LinearAlgebra.BlasInt}(undef, 0))
end

function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::LinearSolve.OperatorAssumptions) where {P, T}
# Pre-allocate appropriate 32-bit arrays based on input type
m, n = size(A)
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
A_32 = similar(A, T32)
b_32 = similar(b, T32)
u_32 = similar(u, T32)
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(m, n))
# Return tuple with pre-allocated arrays
(luinst, ipiv, A_32, b_32, u_32)
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::RF32MixedLUFactorization{P, T};
kwargs...) where {P, T}
A = cache.A
A = convert(AbstractMatrix, A)

if cache.isfresh
# Get pre-allocated arrays from cacheval
luinst, ipiv, A_32, b_32, u_32 = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
# Compute 32-bit type on demand and copy A
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
A_32 .= T32.(A)

# Ensure ipiv is the right size
if length(ipiv) != min(size(A_32)...)
resize!(ipiv, min(size(A_32)...))
end

fact = RecursiveFactorization.lu!(A_32, ipiv, Val(P), Val(T), check = false)
cache.cacheval = (fact, ipiv, A_32, b_32, u_32)

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
end

cache.isfresh = false
end

# Get the factorization and pre-allocated arrays from the cache
fact_cached, ipiv, A_32, b_32, u_32 = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)

# Compute types on demand for conversions
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
Torig = eltype(cache.u)

# Copy b to pre-allocated 32-bit array
b_32 .= T32.(cache.b)

# Solve in 32-bit precision
ldiv!(u_32, fact_cached, b_32)

# Convert back to original precision
cache.u .= Torig.(u_32)

SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
end
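
# For reference, the mixed-precision pattern implemented above, in isolation:
# factor once in Float32 (the O(n^3) step), back-substitute in Float32, then
# promote the result. A minimal sketch assuming a well-conditioned system
# where ~sqrt(eps(Float32)) accuracy suffices; this illustrates the idea,
# it is not the cached LinearSolve path itself:
#
#     using LinearAlgebra
#     A = rand(256, 256) + 16I; b = rand(256)  # diagonally dominant test system
#     F = lu!(Float32.(A))                     # single-precision factorization
#     x = Float64.(F \ Float32.(b))            # single-precision solve, promoted
#     norm(A * x - b)                          # residual ~1e-5 rather than ~1e-12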

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::ButterflyFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
b = cache.b
M, N = size(A)
B, U, V = cache.cacheval[2], cache.cacheval[3], cache.cacheval[4]
if cache.isfresh
@assert M == N "A must be square"
U, V, F = RecursiveFactorization.🦋workspace(A, B, U, V)
cache.cacheval = (A, B, U, V, F)
cache.isfresh = false
end
A, B, U, V, F = cache.cacheval
# The butterfly workspace pads dimensions to a multiple of 4, so pad b to
# match on every solve, not only when the factorization is fresh.
if M % 4 != 0
b = [b; rand(4 - M % 4)]
end
sol = V * (F \ (U * b))
SciMLBase.build_linear_solution(alg, sol[1:M], nothing, cache)
end
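
# Why `sol = V * (F \ (U * b))` recovers x: if F factors B = U*A*V, then
# solving B*y = U*b gives U*A*(V*y) = U*b, i.e. A*(V*y) = b, so x = V*y.
# A sketch of that algebra with random orthogonal stand-ins for the
# structured butterfly transforms (an illustration only; the real U and V
# come from RecursiveFactorization's 🦋workspace):
#
#     using LinearAlgebra
#     n = 8; A = rand(n, n); b = rand(n)
#     U = Matrix(qr(randn(n, n)).Q)  # stand-in left transform
#     V = Matrix(qr(randn(n, n)).Q)  # stand-in right transform
#     F = lu(U * A * V, NoPivot())   # unpivoted LU of the transformed matrix
#     x = V * (F \ (U * b))
#     norm(A * x - b)                # small residual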

function LinearSolve.init_cacheval(alg::ButterflyFactorization, A, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
# Placeholder cacheval; the real (A, B, U, V, F) workspace is constructed
# on the first solve!, when cache.isfresh is true.
A, A, A', A, RecursiveFactorization.lu!(rand(1, 1), Val(false))
end

end

4 changes: 2 additions & 2 deletions src/LinearSolve.jl
@@ -422,7 +422,7 @@ for kralg in (Krylov.lsmr!, Krylov.craigmr!)
end
for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
:GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
-:RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
+:RFLUFactorization, :ButterflyFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
:DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
:CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
:MKLLUFactorization, :MetalLUFactorization, :CUSOLVERRFFactorization)
@@ -464,7 +464,7 @@ cudss_loaded(A) = false
is_cusparse(A) = false

export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
-GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
+GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, ButterflyFactorization,
NormalCholeskyFactorization, NormalBunchKaufmanFactorization,
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
SparspakFactorization, DiagonalFactorization, CholeskyFactorization,
22 changes: 22 additions & 0 deletions src/extension_algs.jl
@@ -254,6 +254,28 @@ function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror =
RFLUFactorization(pivot, thread; throwerror)
end

"""
`ButterflyFactorization()`

A fast, pure-Julia LU-factorization implementation
using RecursiveFactorization.jl. This approach uses a butterfly
transformation in place of pivoting.
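
A minimal usage sketch, mirroring the pattern in `test/butterfly.jl`
(RecursiveFactorization.jl must be loaded):

```julia
using LinearAlgebra, LinearSolve, RecursiveFactorization

A = rand(512, 512)
b = rand(512)
prob = LinearProblem(A, b)
sol = solve(prob, ButterflyFactorization())
norm(A * sol.u - b)  # should be small
```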
"""
struct ButterflyFactorization{T} <: AbstractDenseFactorization
function ButterflyFactorization(::Val{T}; throwerror = true) where {T}
if !userecursivefactorization(nothing)
throwerror &&
error("ButterflyFactorization requires that RecursiveFactorization.jl is loaded, i.e. `using RecursiveFactorization`")
end
new{T}()
end
end

function ButterflyFactorization(; thread = Val(true), throwerror = true)
ButterflyFactorization(thread; throwerror)
end


# There are no options like pivot here.
# But I'm not sure it makes sense as a GenericFactorization
# since it just uses `LAPACK.getrf!`.
35 changes: 35 additions & 0 deletions test/butterfly.jl
@@ -0,0 +1,35 @@
using LinearAlgebra, LinearSolve
using Test
using RecursiveFactorization

@testset "Random Matricies" begin
for i in 490 : 510
A = rand(i, i)
b = rand(i)
prob = LinearProblem(A, b)
x = solve(prob, ButterflyFactorization())
@test norm(A * x .- b) <= 1e-4
end
end

# Wilkinson's classic growth matrix: unit diagonal, ones in the last column,
# and -1 below the diagonal. It attains the worst-case 2^(N-1) element growth
# for LU with partial pivoting, which makes it a good stress test here.
function wilkinson(N)
A = zeros(N, N)
A[1:(N + 1):(N * N)] .= 1  # set the diagonal via linear indexing
A[:, end] .= 1
for n in 1:(N - 1)
for r in (n + 1):N
@inbounds A[r, n] = -1
end
end
A
end

@testset "Wilkinson" begin
for i in 790:810
A = wilkinson(i)
b = rand(i)
prob = LinearProblem(A, b)
x = solve(prob, ButterflyFactorization())
@test norm(A * x .- b) <= 1e-10
end
end