Skip to content

Commit 0c655a8

Browse files
Merge pull request #103 from Shreyas-Ekanathan/master
Clean up ButterflyFactorization
2 parents 9c716e6 + 28feecc commit 0c655a8

File tree

2 files changed

+88
-100
lines changed

2 files changed

+88
-100
lines changed

src/butterflylu.jl

Lines changed: 83 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,62 @@ using RecursiveFactorization
55
using SparseBandedMatrices
66

77
@inline exphalf(x) = exp(x) * oftype(x, 0.5)
8-
function 🦋!(wv, ::Val{SEED} = Val(888)) where {SEED}
8+
function generate_rand_butterfly_vals!(wv, ::Val{SEED} = Val(888)) where {SEED}
99
T = eltype(wv)
1010
mrng = VectorizedRNG.MutableXoshift(SEED)
1111
GC.@preserve mrng begin rand!(exphalf, VectorizedRNG.Xoshift(mrng), wv, static(0),
1212
T(-0.05), T(0.1)) end
1313
end
1414

1515
function 🦋generate_random!(A, ::Val{SEED} = Val(888)) where {SEED}
16-
Usz = 2 * size(A, 1)
17-
Vsz = 2 * size(A, 2)
18-
uv = similar(A, Usz + Vsz)
19-
🦋!(uv, Val(SEED))
20-
(uv,)
16+
uv = similar(A, 4 * size(A, 1))
17+
generate_rand_butterfly_vals!(uv, Val(SEED))
18+
uv
2119
end
22-
23-
function 🦋workspace(A, b, B::Matrix{T}, U::Adjoint{T, Matrix{T}}, V::Matrix{T}, thread, ::Val{SEED} = Val(888)) where {T, SEED}
24-
M = size(A, 1)
25-
if (M % 4 != 0)
26-
A = pad!(A)
20+
struct 🦋workspace{T}
21+
A::Matrix{T}
22+
b::Vector{T}
23+
ws::Vector{T}
24+
U::Matrix{T}
25+
V::Matrix{T}
26+
out::Vector{T}
27+
function 🦋workspace(A, b, ::Val{SEED} = Val(888)) where {SEED}
28+
M = size(A, 1)
29+
out = similar(b, M)
30+
if (M % 4 != 0)
31+
A = pad!(A)
32+
xn = 4 - M % 4
33+
b = [b; rand(xn)]
34+
end
35+
U, V = (similar(A), similar(A))
36+
ws = 🦋generate_random!(A)
37+
materializeUV(U, V, ws)
38+
new{eltype(A)}(A, b, ws, U, V, out)
2739
end
28-
B = similar(A)
29-
ws = 🦋generate_random!(copyto!(B, A))
30-
🦋mul!(copyto!(B, A), ws)
31-
U, V = materializeUV(B, ws)
32-
F = RecursiveFactorization.lu!(B, thread)
33-
out = similar(b, M)
34-
35-
U, V, F, out
40+
end
41+
42+
function 🦋lu!(workspace::🦋workspace, M, thread)
43+
(;A, b, ws, U, V, out) = workspace
44+
🦋mul!(A, ws)
45+
F = RecursiveFactorization.lu!(A, Val(false), thread)
46+
sol = V * (F \ (U' * b))
47+
out .= @view sol[1:M]
48+
out
3649
end
3750

3851
const butterfly_workspace = 🦋workspace;
3952

4053
function 🦋mul_level!(A, u, v)
4154
M, N = size(A)
4255
@assert M == length(u) && N == length(v)
43-
Mh = M >>> 1
44-
Nh = N >>> 1
45-
@turbo for n in 1 : Nh
46-
for m in 1 : Mh
56+
M_half = M >>> 1
57+
N_half = N >>> 1
58+
@turbo for n in 1 : N_half
59+
for m in 1 : M_half
4760
A11 = A[m, n]
48-
A21 = A[m + Mh, n]
49-
A12 = A[m, n + Nh]
50-
A22 = A[m + Mh, n + Nh]
61+
A21 = A[m + M_half, n]
62+
A12 = A[m, n + N_half]
63+
A22 = A[m + M_half, n + N_half]
5164

5265
T1 = A11 + A12
5366
T2 = A21 + A22
@@ -59,32 +72,32 @@ function 🦋mul_level!(A, u, v)
5972
C22 = T3 - T4
6073

6174
u1 = u[m]
62-
u2 = u[m + Mh]
75+
u2 = u[m + M_half]
6376
v1 = v[n]
64-
v2 = v[n + Nh]
77+
v2 = v[n + N_half]
6578

6679
A[m, n] = u1 * C11 * v1
67-
A[m + Mh, n] = u2 * C21 * v1
68-
A[m, n + Nh] = u1 * C12 * v2
69-
A[m + Mh, n + Nh] = u2 * C22 * v2
80+
A[m + M_half, n] = u2 * C21 * v1
81+
A[m, n + N_half] = u1 * C12 * v2
82+
A[m + M_half, n + N_half] = u2 * C22 * v2
7083
end
7184
end
7285
end
7386

74-
function 🦋mul!(A, (uv,))
87+
function 🦋mul!(A, uv)
7588
M, N = size(A)
7689
@assert M == N
77-
Mh = M >>> 1
90+
M_half = M >>> 1
7891

79-
U₁ = @view(uv[1:Mh])
80-
V₁ = @view(uv[(Mh + 1):(M)])
81-
U₂ = @view(uv[(1 + M):(M + Mh)])
82-
V₂ = @view(uv[(1 + M + Mh):(2 * M)])
92+
U₁ = @view(uv[1:M_half])
93+
V₁ = @view(uv[(M_half + 1):(M)])
94+
U₂ = @view(uv[(1 + M):(M + M_half)])
95+
V₂ = @view(uv[(1 + M + M_half):(2 * M)])
8396

84-
🦋mul_level!(@view(A[1:Mh, 1:Mh]), U₁, V₁)
85-
🦋mul_level!(@view(A[Mh + 1:M, 1:Mh]), U₂, V₁)
86-
🦋mul_level!(@view(A[1:Mh, Mh + 1:M]), U₁, V₂)
87-
🦋mul_level!(@view(A[Mh + 1:M, Mh + 1:M]), U₂, V₂)
97+
🦋mul_level!(@view(A[1:M_half, 1:M_half]), U₁, V₁)
98+
🦋mul_level!(@view(A[M_half + 1:M, 1:M_half]), U₂, V₁)
99+
🦋mul_level!(@view(A[1:M_half, M_half + 1:M]), U₁, V₂)
100+
🦋mul_level!(@view(A[M_half + 1:M, M_half + 1:M]), U₂, V₂)
88101

89102
U = @view(uv[(1 + 2 * M):(3 * M)])
90103
V = @view(uv[(1 + 3 * M):(4 * M)])
@@ -106,7 +119,14 @@ function diagnegbottom(x)
106119
Diagonal(y), Diagonal(z)
107120
end
108121

109-
function 🦋2!(C, A::Diagonal, B::Diagonal)
122+
function 🦋!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
123+
setdiagonal!(C, [A.diag; -B.diag], true)
124+
setdiagonal!(C, A.diag, true)
125+
setdiagonal!(C, B.diag, false)
126+
C
127+
end
128+
129+
function 🦋!(C, A::Diagonal, B::Diagonal)
110130
@assert size(A) == size(B)
111131
A1 = size(A, 1)
112132

@@ -120,61 +140,35 @@ function 🦋2!(C, A::Diagonal, B::Diagonal)
120140
C
121141
end
122142

123-
function 🦋!(A::Matrix, C::SparseBandedMatrix, X::Diagonal, Y::Diagonal)
124-
@assert size(X) == size(Y)
125-
if (size(X, 1) + size(Y, 1) != size(A, 1))
126-
x = size(A, 1) - size(X, 1) - size(Y, 1)
127-
setdiagonal!(C, [X.diag; rand(x); -Y.diag], true)
128-
setdiagonal!(C, X.diag, true)
129-
setdiagonal!(C, Y.diag, false)
130-
else
131-
setdiagonal!(C, [X.diag; -Y.diag], true)
132-
setdiagonal!(C, X.diag, true)
133-
setdiagonal!(C, Y.diag, false)
134-
end
135-
136-
C
137-
end
138-
139-
function 🦋2!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
140-
setdiagonal!(C, [A.diag; -B.diag], true)
141-
setdiagonal!(C, A.diag, true)
142-
setdiagonal!(C, B.diag, false)
143-
C
144-
end
145-
146-
function materializeUV(A, (uv,))
147-
M, N = size(A)
148-
Mh = M >>> 1
149-
Nh = N >>> 1
143+
function materializeUV(U, V, uv)
144+
M = size(U, 1)
145+
M_half = M >>> 1
150146

151-
U₁u, U₁l = diagnegbottom(@view(uv[1:Mh])) #Mh
152-
U₂u, U₂l = diagnegbottom(@view(uv[(1 + Mh + Nh):(M + Nh)])) #M2
153-
V₁u, V₁l = diagnegbottom(@view(uv[(Mh + 1):(Mh + Nh)])) #Nh
154-
V₂u, V₂l = diagnegbottom(@view(uv[(1 + 2 * Mh + Nh):(2 * Mh + N)])) #N2
155-
Uu, Ul = diagnegbottom(@view(uv[(1 + M + N):(2 * M + N)])) #M
156-
Vu, Vl = diagnegbottom(@view(uv[(1 + 2 * M + N):(2 * M + 2 * N)])) #N
147+
U₁u, U₁l = diagnegbottom(@view(uv[1:M_half])) #M_half
148+
U₂u, U₂l = diagnegbottom(@view(uv[(1 + 2 * M_half):(M + M_half)])) #M_half
149+
V₁u, V₁l = diagnegbottom(@view(uv[(M_half + 1):(2 * M_half)])) #M_half
150+
V₂u, V₂l = diagnegbottom(@view(uv[(1 + 3 * M_half):(2 * M_half + M)])) #M_half
151+
Uu, Ul = diagnegbottom(@view(uv[(1 + 2 * M):(3 * M)])) #M
152+
Vu, Vl = diagnegbottom(@view(uv[(1 + 3 * M):(4 * M)])) #M
157153

158-
Bu2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
154+
Bu2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M)
159155

160-
🦋2!(view(Bu2, 1 : Mh, 1 : Nh), U₁u, U₁l)
161-
🦋2!(view(Bu2, Mh + 1: M, Nh + 1: N), U₂u, U₂l)
156+
🦋!(view(Bu2, 1 : M_half, 1 : M_half), U₁u, U₁l)
157+
🦋!(view(Bu2, M_half + 1: M, M_half + 1: M), U₂u, U₂l)
162158

163-
Bu1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
164-
🦋!(A, Bu1, Uu, Ul)
159+
Bu1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M)
160+
🦋!(Bu1, Uu, Ul)
165161

166-
Bv2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
162+
Bv2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M)
167163

168-
🦋2!(view(Bv2, 1 : Mh, 1 : Nh), V₁u, V₁l)
169-
🦋2!(view(Bv2, Mh + 1: M, Nh + 1: N), V₂u, V₂l)
164+
🦋!(view(Bv2, 1 : M_half, 1 : M_half), V₁u, V₁l)
165+
🦋!(view(Bv2, M_half + 1: M, M_half + 1: M), V₂u, V₂l)
170166

171-
Bv1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
172-
🦋!(A, Bv1, Vu, Vl)
167+
Bv1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, M)
168+
🦋!(Bv1, Vu, Vl)
173169

174-
U = (Bu2 * Bu1)'
175-
V = Bv2 * Bv1
176-
177-
U, V
170+
mul!(U, Bu2, Bu1)
171+
mul!(V, Bv2, Bv1)
178172
end
179173

180174
function pad!(A)

test/runtests.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,9 @@ end
8181
for i in 790 : 810
8282
A = wilkinson(i)
8383
b = rand(i)
84-
U, V, F, out = RecursiveFactorization.🦋workspace(A, b, A, A', A, Val(true))
85-
M = size(A, 1)
86-
xn = 4 - M % 4
87-
if (M % 4 != 0)
88-
xn = 4 - M % 4
89-
b = [b; rand(xn)]
90-
end
91-
sol = V * (F \ (U * b))
92-
out .= @view sol[1:M]
93-
@test norm(A * out .- b[1:M]) <= 1e-10
84+
ws = RecursiveFactorization.🦋workspace(copy(A), copy(b))
85+
out = RecursiveFactorization.🦋lu!(ws, i, Val(true))
86+
@test norm(A * out .- b) <= 1e-10
9487
end
95-
end
88+
end
89+

0 commit comments

Comments
 (0)