Skip to content

Commit b1b2e6b

Browse files
fix threading and other small things
1 parent c83fb06 commit b1b2e6b

File tree

2 files changed

+18 -38 lines changed

2 files changed

+18 -38 lines changed

src/butterflylu.jl

Lines changed: 14 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ function 🦋generate_random!(A, ::Val{SEED} = Val(888)) where {SEED}
2020
(uv,)
2121
end
2222

23-
function 🦋workspace(A, B::Matrix{T}, U::Adjoint{T, Matrix{T}}, V::Matrix{T}, ::Val{SEED} = Val(888)) where {T, SEED}
23+
function 🦋workspace(A, b, B::Matrix{T}, U::Adjoint{T, Matrix{T}}, V::Matrix{T}, thread, ::Val{SEED} = Val(888)) where {T, SEED}
2424
M = size(A, 1)
2525
if (M % 4 != 0)
2626
A = pad!(A)
@@ -29,9 +29,10 @@ function 🦋workspace(A, B::Matrix{T}, U::Adjoint{T, Matrix{T}}, V::Matrix{T},
2929
ws = 🦋generate_random!(copyto!(B, A))
3030
🦋mul!(copyto!(B, A), ws)
3131
U, V = materializeUV(B, ws)
32-
F = RecursiveFactorization.lu!(B, Val(false))
32+
F = RecursiveFactorization.lu!(B, thread)
33+
out = similar(b, M)
3334

34-
U, V, F
35+
U, V, F, out
3536
end
3637

3738
const butterfly_workspace = 🦋workspace;
@@ -41,14 +42,12 @@ function 🦋mul_level!(A, u, v)
4142
@assert M == length(u) && N == length(v)
4243
Mh = M >>> 1
4344
Nh = N >>> 1
44-
M2 = M - Mh
45-
N2 = N - Nh
4645
@turbo for n in 1 : Nh
4746
for m in 1 : Mh
4847
A11 = A[m, n]
49-
A21 = A[m + M2, n]
50-
A12 = A[m, n + N2]
51-
A22 = A[m + M2, n + N2]
48+
A21 = A[m + Mh, n]
49+
A12 = A[m, n + Nh]
50+
A22 = A[m + Mh, n + Nh]
5251

5352
T1 = A11 + A12
5453
T2 = A21 + A22
@@ -60,36 +59,16 @@ function 🦋mul_level!(A, u, v)
6059
C22 = T3 - T4
6160

6261
u1 = u[m]
63-
u2 = u[m + M2]
62+
u2 = u[m + Mh]
6463
v1 = v[n]
65-
v2 = v[n + N2]
64+
v2 = v[n + Nh]
6665

6766
A[m, n] = u1 * C11 * v1
68-
A[m + M2, n] = u2 * C21 * v1
69-
A[m, n + N2] = u1 * C12 * v2
70-
A[m + M2, n + N2] = u2 * C22 * v2
67+
A[m + Mh, n] = u2 * C21 * v1
68+
A[m, n + Nh] = u1 * C12 * v2
69+
A[m + Mh, n + Nh] = u2 * C22 * v2
7170
end
7271
end
73-
#=
74-
if (N % 2 == 1) # N odd
75-
n = N2
76-
for m in 1:M
77-
A[m, n] = u[m] * A[m, n] * v[n]
78-
end
79-
end
80-
81-
if (M % 2 == 1) # M odd
82-
m = M2
83-
for n in 1:N
84-
A[m, n] = u[m] * A[m, n] * v[n]
85-
end
86-
end
87-
88-
if (M % 2 == 1) && (N % 2 == 1)
89-
m = M2
90-
n = N2
91-
A[m, n] /= (u[m] * v[n])
92-
end =#
9372
end
9473

9574
function 🦋mul!(A, (uv,))
@@ -98,8 +77,8 @@ function 🦋mul!(A, (uv,))
9877
Mh = M >>> 1
9978

10079
U₁ = @view(uv[1:Mh])
101-
V₁ = @view(uv[(Mh + 1):(2 * Mh)])
102-
U₂ = @view(uv[(1 + 2 * Mh):(M + Mh)])
80+
V₁ = @view(uv[(Mh + 1):(M)])
81+
U₂ = @view(uv[(1 + M):(M + Mh)])
10382
V₂ = @view(uv[(1 + M + Mh):(2 * M)])
10483

10584
🦋mul_level!(@view(A[1:Mh, 1:Mh]), U₁, V₁)

test/runtests.jl

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -81,14 +81,15 @@ end
8181
for i in 790 : 810
8282
A = wilkinson(i)
8383
b = rand(i)
84-
U, V, F = RecursiveFactorization.🦋workspace(A, A, A', A)
84+
U, V, F, out = RecursiveFactorization.🦋workspace(A, b, A, A', A, Val(true))
8585
M = size(A, 1)
8686
xn = 4 - M % 4
8787
if (M % 4 != 0)
8888
xn = 4 - M % 4
8989
b = [b; rand(xn)]
9090
end
91-
x = V * (F \ (U * b))
92-
@test norm(A * x[1:M] .- b[1:M]) <= 1e-10
91+
sol = V * (F \ (U * b))
92+
out .= @view sol[1:M]
93+
@test norm(A * out .- b[1:M]) <= 1e-10
9394
end
9495
end

0 commit comments

Comments
 (0)