Skip to content

Commit 0e635a2

Browse files
refactor butterfly into new struct
1 parent 11f3224 commit 0e635a2

File tree

2 files changed

+64
-75
lines changed

2 files changed

+64
-75
lines changed

src/butterflylu.jl

Lines changed: 60 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,49 @@ using RecursiveFactorization
55
using SparseBandedMatrices
66

77
@inline exphalf(x) = exp(x) * oftype(x, 0.5)
8-
function 🦋!(wv, ::Val{SEED} = Val(888)) where {SEED}
8+
function generate_rand_butterfly_vals!(wv, ::Val{SEED} = Val(888)) where {SEED}
99
T = eltype(wv)
1010
mrng = VectorizedRNG.MutableXoshift(SEED)
1111
GC.@preserve mrng begin rand!(exphalf, VectorizedRNG.Xoshift(mrng), wv, static(0),
1212
T(-0.05), T(0.1)) end
1313
end
1414

1515
function 🦋generate_random!(A, ::Val{SEED} = Val(888)) where {SEED}
16-
Usz = 2 * size(A, 1)
17-
Vsz = 2 * size(A, 2)
18-
uv = similar(A, Usz + Vsz)
19-
🦋!(uv, Val(SEED))
20-
(uv,)
16+
uv = similar(A, 4 * size(A, 1))
17+
generate_rand_butterfly_vals!(uv, Val(SEED))
18+
uv
2119
end
22-
23-
function 🦋workspace(A, b, B::Matrix{T}, U::Adjoint{T, Matrix{T}}, V::Matrix{T}, thread, ::Val{SEED} = Val(888)) where {T, SEED}
24-
M = size(A, 1)
25-
if (M % 4 != 0)
26-
A = pad!(A)
20+
struct 🦋workspace{T}
21+
A::Matrix{T}
22+
b::Vector{T}
23+
B::Matrix{T}
24+
ws::Vector{T}
25+
U::Matrix{T}
26+
V::Matrix{T}
27+
out::Vector{T}
28+
function 🦋workspace(A, b, ::Val{SEED} = Val(888)) where {SEED}
29+
M = size(A, 1)
30+
out = similar(b, M)
31+
if (M % 4 != 0)
32+
A = pad!(A)
33+
xn = 4 - M % 4
34+
b = [b; rand(xn)]
35+
end
36+
B = similar(A)
37+
U, V = (similar(A), similar(A))
38+
ws = 🦋generate_random!(A)
39+
new{eltype(A)}(A, b, B, ws, U, V, out)
2740
end
28-
B = similar(A)
29-
ws = 🦋generate_random!(copyto!(B, A))
41+
end
42+
43+
function 🦋lu!(workspace::🦋workspace, M, thread)
44+
(;A, b, B, ws, U, V, out) = workspace
3045
🦋mul!(copyto!(B, A), ws)
31-
U, V = materializeUV(B, ws)
46+
materializeUV(U, V, ws)
3247
F = RecursiveFactorization.lu!(B, thread)
33-
out = similar(b, M)
34-
35-
U, V, F, out
48+
sol = V * (F \ (U' * b))
49+
out .= @view sol[1:M]
50+
out
3651
end
3752

3853
const butterfly_workspace = 🦋workspace;
@@ -71,7 +86,7 @@ function 🦋mul_level!(A, u, v)
7186
end
7287
end
7388

74-
function 🦋mul!(A, (uv,))
89+
function 🦋mul!(A, uv)
7590
M, N = size(A)
7691
@assert M == N
7792
Mh = M >>> 1
@@ -106,6 +121,13 @@ function diagnegbottom(x)
106121
Diagonal(y), Diagonal(z)
107122
end
108123

124+
function 🦋!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
125+
setdiagonal!(C, [A.diag; -B.diag], true)
126+
setdiagonal!(C, A.diag, true)
127+
setdiagonal!(C, B.diag, false)
128+
C
129+
end
130+
109131
function 🦋2!(C, A::Diagonal, B::Diagonal)
110132
@assert size(A) == size(B)
111133
A1 = size(A, 1)
@@ -120,61 +142,35 @@ function 🦋2!(C, A::Diagonal, B::Diagonal)
120142
C
121143
end
122144

123-
function 🦋!(A::Matrix, C::SparseBandedMatrix, X::Diagonal, Y::Diagonal)
124-
@assert size(X) == size(Y)
125-
if (size(X, 1) + size(Y, 1) != size(A, 1))
126-
x = size(A, 1) - size(X, 1) - size(Y, 1)
127-
setdiagonal!(C, [X.diag; rand(x); -Y.diag], true)
128-
setdiagonal!(C, X.diag, true)
129-
setdiagonal!(C, Y.diag, false)
130-
else
131-
setdiagonal!(C, [X.diag; -Y.diag], true)
132-
setdiagonal!(C, X.diag, true)
133-
setdiagonal!(C, Y.diag, false)
134-
end
135-
136-
C
137-
end
138-
139-
function 🦋2!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
140-
setdiagonal!(C, [A.diag; -B.diag], true)
141-
setdiagonal!(C, A.diag, true)
142-
setdiagonal!(C, B.diag, false)
143-
C
144-
end
145-
146-
function materializeUV(A, (uv,))
147-
M, N = size(A)
148-
Mh = M >>> 1
149-
Nh = N >>> 1
145+
function materializeUV(U, V, uv)
146+
M = size(U, 1)
147+
Mh = M >>> 1
150148

151149
U₁u, U₁l = diagnegbottom(@view(uv[1:Mh])) #Mh
152-
U₂u, U₂l = diagnegbottom(@view(uv[(1 + Mh + Nh):(M + Nh)])) #M2
153-
V₁u, V₁l = diagnegbottom(@view(uv[(Mh + 1):(Mh + Nh)])) #Nh
154-
V₂u, V₂l = diagnegbottom(@view(uv[(1 + 2 * Mh + Nh):(2 * Mh + N)])) #N2
155-
Uu, Ul = diagnegbottom(@view(uv[(1 + M + N):(2 * M + N)])) #M
156-
Vu, Vl = diagnegbottom(@view(uv[(1 + 2 * M + N):(2 * M + 2 * N)])) #N
150+
U₂u, U₂l = diagnegbottom(@view(uv[(1 + 2 * Mh):(M + Mh)])) #Mh
151+
V₁u, V₁l = diagnegbottom(@view(uv[(Mh + 1):(2 * Mh)])) #Mh
152+
V₂u, V₂l = diagnegbottom(@view(uv[(1 + 3 * Mh):(2 * Mh + M)])) #Mh
153+
Uu, Ul = diagnegbottom(@view(uv[(1 + 2 * M):(3 * M)])) #M
154+
Vu, Vl = diagnegbottom(@view(uv[(1 + 3 * M):(4 * M)])) #M
157155

158-
Bu2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
156+
Bu2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M)
159157

160-
🦋2!(view(Bu2, 1 : Mh, 1 : Nh), U₁u, U₁l)
161-
🦋2!(view(Bu2, Mh + 1: M, Nh + 1: N), U₂u, U₂l)
158+
🦋2!(view(Bu2, 1 : Mh, 1 : Mh), U₁u, U₁l)
159+
🦋2!(view(Bu2, Mh + 1: M, Mh + 1: M), U₂u, U₂l)
162160

163-
Bu1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
164-
🦋!(A, Bu1, Uu, Ul)
161+
Bu1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M)
162+
🦋!(Bu1, Uu, Ul)
165163

166-
Bv2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
164+
Bv2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M)
167165

168-
🦋2!(view(Bv2, 1 : Mh, 1 : Nh), V₁u, V₁l)
169-
🦋2!(view(Bv2, Mh + 1: M, Nh + 1: N), V₂u, V₂l)
166+
🦋2!(view(Bv2, 1 : Mh, 1 : Mh), V₁u, V₁l)
167+
🦋2!(view(Bv2, Mh + 1: M, Mh + 1: M), V₂u, V₂l)
170168

171-
Bv1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
172-
🦋!(A, Bv1, Vu, Vl)
169+
Bv1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M)
170+
🦋!(Bv1, Vu, Vl)
173171

174-
U = (Bu2 * Bu1)'
175-
V = Bv2 * Bv1
176-
177-
U, V
172+
mul!(U, Bu2, Bu1)
173+
mul!(V, Bv2, Bv1)
178174
end
179175

180176
function pad!(A)

test/runtests.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,8 @@ end
8181
for i in 790 : 810
8282
A = wilkinson(i)
8383
b = rand(i)
84-
U, V, F, out = RecursiveFactorization.🦋workspace(A, b, A, A', A, Val(true))
85-
M = size(A, 1)
86-
xn = 4 - M % 4
87-
if (M % 4 != 0)
88-
xn = 4 - M % 4
89-
b = [b; rand(xn)]
90-
end
91-
sol = V * (F \ (U * b))
92-
out .= @view sol[1:M]
93-
@test norm(A * out .- b[1:M]) <= 1e-10
84+
ws = RecursiveFactorization.🦋workspace(A, b)
85+
out = RecursiveFactorization.🦋lu!(ws, i, Val(true))
86+
@test norm(A * out .- b) <= 1e-10
9487
end
95-
end
88+
end

0 commit comments

Comments
 (0)