diff --git a/src/butterflylu.jl b/src/butterflylu.jl index dd5f76d..6e5a691 100644 --- a/src/butterflylu.jl +++ b/src/butterflylu.jl @@ -24,27 +24,32 @@ struct 🦋workspace{T} U::Matrix{T} V::Matrix{T} out::Vector{T} + n::Int function 🦋workspace(A, b, ::Val{SEED} = Val(888)) where {SEED} - M = size(A, 1) - out = similar(b, M) - if (M % 4 != 0) + n = size(A, 1) + out = similar(b, n) + if (n % 4 != 0) A = pad!(A) - xn = 4 - M % 4 + xn = 4 - n % 4 b = [b; rand(xn)] end U, V = (similar(A), similar(A)) ws = 🦋generate_random!(A) materializeUV(U, V, ws) - new{eltype(A)}(A, b, ws, U, V, out) + new{eltype(A)}(A, b, ws, U, V, out, n) end end -function 🦋lu!(workspace::🦋workspace, M, thread) - (;A, b, ws, U, V, out) = workspace +function 🦋solve!(workspace::🦋workspace, thread) + (;A, b, ws, U, V, out, n) = workspace 🦋mul!(A, ws) F = RecursiveFactorization.lu!(A, Val(false), thread) - sol = V * (F \ (U' * b)) - out .= @view sol[1:M] + + mul!(b, U', b) + ldiv!(b, UnitLowerTriangular(F.factors), b, thread) + ldiv!(b, UpperTriangular(F.factors), b, thread) + mul!(b, V, b) + out .= @view b[1:n] out end diff --git a/test/runtests.jl b/test/runtests.jl index 9c3789e..8f6c2d4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,7 +82,7 @@ end A = wilkinson(i) b = rand(i) ws = RecursiveFactorization.🦋workspace(copy(A), copy(b)) - out = RecursiveFactorization.🦋lu!(ws, i, Val(true)) + out = RecursiveFactorization.🦋solve!(ws, Val(true)) @test norm(A * out .- b) <= 1e-10 end end