diff --git a/src/butterflylu.jl b/src/butterflylu.jl index 6dc1483..dd5f76d 100644 --- a/src/butterflylu.jl +++ b/src/butterflylu.jl @@ -5,7 +5,7 @@ using RecursiveFactorization using SparseBandedMatrices @inline exphalf(x) = exp(x) * oftype(x, 0.5) -function 🦋!(wv, ::Val{SEED} = Val(888)) where {SEED} +function generate_rand_butterfly_vals!(wv, ::Val{SEED} = Val(888)) where {SEED} T = eltype(wv) mrng = VectorizedRNG.MutableXoshift(SEED) GC.@preserve mrng begin rand!(exphalf, VectorizedRNG.Xoshift(mrng), wv, static(0), @@ -13,26 +13,39 @@ function 🦋!(wv, ::Val{SEED} = Val(888)) where {SEED} end function 🦋generate_random!(A, ::Val{SEED} = Val(888)) where {SEED} - Usz = 2 * size(A, 1) - Vsz = 2 * size(A, 2) - uv = similar(A, Usz + Vsz) - 🦋!(uv, Val(SEED)) - (uv,) + uv = similar(A, 4 * size(A, 1)) + generate_rand_butterfly_vals!(uv, Val(SEED)) + uv end - -function 🦋workspace(A, b, B::Matrix{T}, U::Adjoint{T, Matrix{T}}, V::Matrix{T}, thread, ::Val{SEED} = Val(888)) where {T, SEED} - M = size(A, 1) - if (M % 4 != 0) - A = pad!(A) +struct 🦋workspace{T} + A::Matrix{T} + b::Vector{T} + ws::Vector{T} + U::Matrix{T} + V::Matrix{T} + out::Vector{T} + function 🦋workspace(A, b, ::Val{SEED} = Val(888)) where {SEED} + M = size(A, 1) + out = similar(b, M) + if (M % 4 != 0) + A = pad!(A) + xn = 4 - M % 4 + b = [b; rand(xn)] + end + U, V = (similar(A), similar(A)) + ws = 🦋generate_random!(A) + materializeUV(U, V, ws) + new{eltype(A)}(A, b, ws, U, V, out) end - B = similar(A) - ws = 🦋generate_random!(copyto!(B, A)) - 🦋mul!(copyto!(B, A), ws) - U, V = materializeUV(B, ws) - F = RecursiveFactorization.lu!(B, thread) - out = similar(b, M) - - U, V, F, out +end + +function 🦋lu!(workspace::🦋workspace, M, thread) + (;A, b, ws, U, V, out) = workspace + 🦋mul!(A, ws) + F = RecursiveFactorization.lu!(A, Val(false), thread) + sol = V * (F \ (U' * b)) + out .= @view sol[1:M] + out end const butterfly_workspace = 🦋workspace; @@ -40,14 +53,14 @@ const butterfly_workspace = 🦋workspace; function 🦋mul_level!(A, u, v) M, N = size(A) @assert M == length(u) && N == length(v) - Mh = M >>> 1 - Nh = N >>> 1 - @turbo for n in 1 : Nh - for m in 1 : Mh + M_half = M >>> 1 + N_half = N >>> 1 + @turbo for n in 1 : N_half + for m in 1 : M_half A11 = A[m, n] - A21 = A[m + Mh, n] - A12 = A[m, n + Nh] - A22 = A[m + Mh, n + Nh] + A21 = A[m + M_half, n] + A12 = A[m, n + N_half] + A22 = A[m + M_half, n + N_half] T1 = A11 + A12 T2 = A21 + A22 @@ -59,32 +72,32 @@ function 🦋mul_level!(A, u, v) C22 = T3 - T4 u1 = u[m] - u2 = u[m + Mh] + u2 = u[m + M_half] v1 = v[n] - v2 = v[n + Nh] + v2 = v[n + N_half] A[m, n] = u1 * C11 * v1 - A[m + Mh, n] = u2 * C21 * v1 - A[m, n + Nh] = u1 * C12 * v2 - A[m + Mh, n + Nh] = u2 * C22 * v2 + A[m + M_half, n] = u2 * C21 * v1 + A[m, n + N_half] = u1 * C12 * v2 + A[m + M_half, n + N_half] = u2 * C22 * v2 end end end -function 🦋mul!(A, (uv,)) +function 🦋mul!(A, uv) M, N = size(A) @assert M == N - Mh = M >>> 1 + M_half = M >>> 1 - U₁ = @view(uv[1:Mh]) - V₁ = @view(uv[(Mh + 1):(M)]) - U₂ = @view(uv[(1 + M):(M + Mh)]) - V₂ = @view(uv[(1 + M + Mh):(2 * M)]) + U₁ = @view(uv[1:M_half]) + V₁ = @view(uv[(M_half + 1):(M)]) + U₂ = @view(uv[(1 + M):(M + M_half)]) + V₂ = @view(uv[(1 + M + M_half):(2 * M)]) - 🦋mul_level!(@view(A[1:Mh, 1:Mh]), U₁, V₁) - 🦋mul_level!(@view(A[Mh + 1:M, 1:Mh]), U₂, V₁) - 🦋mul_level!(@view(A[1:Mh, Mh + 1:M]), U₁, V₂) - 🦋mul_level!(@view(A[Mh + 1:M, Mh + 1:M]), U₂, V₂) + 🦋mul_level!(@view(A[1:M_half, 1:M_half]), U₁, V₁) + 🦋mul_level!(@view(A[M_half + 1:M, 1:M_half]), U₂, V₁) + 🦋mul_level!(@view(A[1:M_half, M_half + 1:M]), U₁, V₂) + 🦋mul_level!(@view(A[M_half + 1:M, M_half + 1:M]), U₂, V₂) U = @view(uv[(1 + 2 * M):(3 * M)]) V = @view(uv[(1 + 3 * M):(4 * M)]) @@ -106,7 +119,14 @@ function diagnegbottom(x) Diagonal(y), Diagonal(z) end -function 🦋2!(C, A::Diagonal, B::Diagonal) +function 🦋!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal) + setdiagonal!(C, [A.diag; -B.diag], true) + setdiagonal!(C, A.diag, true) + setdiagonal!(C, B.diag, false) + C +end + +function 🦋!(C, A::Diagonal, B::Diagonal) @assert size(A) == size(B) A1 = size(A, 1) @@ -120,61 +140,35 @@ function 🦋2!(C, A::Diagonal, B::Diagonal) C end -function 🦋!(A::Matrix, C::SparseBandedMatrix, X::Diagonal, Y::Diagonal) - @assert size(X) == size(Y) - if (size(X, 1) + size(Y, 1) != size(A, 1)) - x = size(A, 1) - size(X, 1) - size(Y, 1) - setdiagonal!(C, [X.diag; rand(x); -Y.diag], true) - setdiagonal!(C, X.diag, true) - setdiagonal!(C, Y.diag, false) - else - setdiagonal!(C, [X.diag; -Y.diag], true) - setdiagonal!(C, X.diag, true) - setdiagonal!(C, Y.diag, false) - end - - C -end - -function 🦋2!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal) - setdiagonal!(C, [A.diag; -B.diag], true) - setdiagonal!(C, A.diag, true) - setdiagonal!(C, B.diag, false) - C -end - -function materializeUV(A, (uv,)) - M, N = size(A) - Mh = M >>> 1 - Nh = N >>> 1 +function materializeUV(U, V, uv) + M = size(U, 1) + M_half = M >>> 1 - U₁u, U₁l = diagnegbottom(@view(uv[1:Mh])) #Mh - U₂u, U₂l = diagnegbottom(@view(uv[(1 + Mh + Nh):(M + Nh)])) #M2 - V₁u, V₁l = diagnegbottom(@view(uv[(Mh + 1):(Mh + Nh)])) #Nh - V₂u, V₂l = diagnegbottom(@view(uv[(1 + 2 * Mh + Nh):(2 * Mh + N)])) #N2 - Uu, Ul = diagnegbottom(@view(uv[(1 + M + N):(2 * M + N)])) #M - Vu, Vl = diagnegbottom(@view(uv[(1 + 2 * M + N):(2 * M + 2 * N)])) #N + U₁u, U₁l = diagnegbottom(@view(uv[1:M_half])) #M_half + U₂u, U₂l = diagnegbottom(@view(uv[(1 + 2 * M_half):(M + M_half)])) #M_half + V₁u, V₁l = diagnegbottom(@view(uv[(M_half + 1):(2 * M_half)])) #M_half + V₂u, V₂l = diagnegbottom(@view(uv[(1 + 3 * M_half):(2 * M_half + M)])) #M_half + Uu, Ul = diagnegbottom(@view(uv[(1 + 2 * M):(3 * M)])) #M + Vu, Vl = diagnegbottom(@view(uv[(1 + 3 * M):(4 * M)])) #M - Bu2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N) + Bu2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M) - 🦋2!(view(Bu2, 1 : Mh, 1 : Nh), U₁u, U₁l) - 🦋2!(view(Bu2, Mh + 1: M, Nh + 1: N), U₂u, U₂l) + 🦋!(view(Bu2, 1 : M_half, 1 : M_half), U₁u, U₁l) + 🦋!(view(Bu2, M_half + 1: M, M_half + 1: M), U₂u, U₂l) - Bu1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N) - 🦋!(A, Bu1, Uu, Ul) + Bu1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M) + 🦋!(Bu1, Uu, Ul) - Bv2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N) + Bv2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M) - 🦋2!(view(Bv2, 1 : Mh, 1 : Nh), V₁u, V₁l) - 🦋2!(view(Bv2, Mh + 1: M, Nh + 1: N), V₂u, V₂l) + 🦋!(view(Bv2, 1 : M_half, 1 : M_half), V₁u, V₁l) + 🦋!(view(Bv2, M_half + 1: M, M_half + 1: M), V₂u, V₂l) - Bv1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N) - 🦋!(A, Bv1, Vu, Vl) + Bv1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M) + 🦋!(Bv1, Vu, Vl) - U = (Bu2 * Bu1)' - V = Bv2 * Bv1 - - U, V + mul!(U, Bu2, Bu1) + mul!(V, Bv2, Bv1) end function pad!(A) diff --git a/test/runtests.jl b/test/runtests.jl index e89285e..9c3789e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,15 +81,9 @@ end for i in 790 : 810 A = wilkinson(i) b = rand(i) - U, V, F, out = RecursiveFactorization.🦋workspace(A, b, A, A', A, Val(true)) - M = size(A, 1) - xn = 4 - M % 4 - if (M % 4 != 0) - xn = 4 - M % 4 - b = [b; rand(xn)] - end - sol = V * (F \ (U * b)) - out .= @view sol[1:M] - @test norm(A * out .- b[1:M]) <= 1e-10 + ws = RecursiveFactorization.🦋workspace(copy(A), copy(b)) + out = RecursiveFactorization.🦋lu!(ws, i, Val(true)) + @test norm(A * out .- b) <= 1e-10 end -end \ No newline at end of file +end +