@@ -5,49 +5,62 @@ using RecursiveFactorization
5
5
using SparseBandedMatrices
6
6
7
7
@inline exphalf (x) = exp (x) * oftype (x, 0.5 )
8
- function 🦋 ! (wv, :: Val{SEED} = Val (888 )) where {SEED}
8
+ function generate_rand_butterfly_vals ! (wv, :: Val{SEED} = Val (888 )) where {SEED}
9
9
T = eltype (wv)
10
10
mrng = VectorizedRNG. MutableXoshift (SEED)
11
11
GC. @preserve mrng begin rand! (exphalf, VectorizedRNG. Xoshift (mrng), wv, static (0 ),
12
12
T (- 0.05 ), T (0.1 )) end
13
13
end
14
14
15
15
function 🦋generate_random! (A, :: Val{SEED} = Val (888 )) where {SEED}
16
- Usz = 2 * size (A, 1 )
17
- Vsz = 2 * size (A, 2 )
18
- uv = similar (A, Usz + Vsz)
19
- 🦋! (uv, Val (SEED))
20
- (uv,)
16
+ uv = similar (A, 4 * size (A, 1 ))
17
+ generate_rand_butterfly_vals! (uv, Val (SEED))
18
+ uv
21
19
end
22
-
23
- function 🦋workspace (A, b, B:: Matrix{T} , U:: Adjoint{T, Matrix{T}} , V:: Matrix{T} , thread, :: Val{SEED} = Val (888 )) where {T, SEED}
24
- M = size (A, 1 )
25
- if (M % 4 != 0 )
26
- A = pad! (A)
20
+ struct 🦋workspace{T}
21
+ A:: Matrix{T}
22
+ b:: Vector{T}
23
+ ws:: Vector{T}
24
+ U:: Matrix{T}
25
+ V:: Matrix{T}
26
+ out:: Vector{T}
27
+ function 🦋workspace (A, b, :: Val{SEED} = Val (888 )) where {SEED}
28
+ M = size (A, 1 )
29
+ out = similar (b, M)
30
+ if (M % 4 != 0 )
31
+ A = pad! (A)
32
+ xn = 4 - M % 4
33
+ b = [b; rand (xn)]
34
+ end
35
+ U, V = (similar (A), similar (A))
36
+ ws = 🦋generate_random! (A)
37
+ materializeUV (U, V, ws)
38
+ new {eltype(A)} (A, b, ws, U, V, out)
27
39
end
28
- B = similar (A)
29
- ws = 🦋generate_random! (copyto! (B, A))
30
- 🦋mul! (copyto! (B, A), ws)
31
- U, V = materializeUV (B, ws)
32
- F = RecursiveFactorization. lu! (B, thread)
33
- out = similar (b, M)
34
-
35
- U, V, F, out
40
+ end
41
+
42
+ function 🦋lu! (workspace: :🦋workspace, M, thread)
43
+ (;A, b, ws, U, V, out) = workspace
44
+ 🦋mul! (A, ws)
45
+ F = RecursiveFactorization. lu! (A, Val (false ), thread)
46
+ sol = V * (F \ (U' * b))
47
+ out .= @view sol[1 : M]
48
+ out
36
49
end
37
50
38
51
const butterfly_workspace = 🦋workspace;
39
52
40
53
function 🦋mul_level! (A, u, v)
41
54
M, N = size (A)
42
55
@assert M == length (u) && N == length (v)
43
- Mh = M >>> 1
44
- Nh = N >>> 1
45
- @turbo for n in 1 : Nh
46
- for m in 1 : Mh
56
+ M_half = M >>> 1
57
+ N_half = N >>> 1
58
+ @turbo for n in 1 : N_half
59
+ for m in 1 : M_half
47
60
A11 = A[m, n]
48
- A21 = A[m + Mh , n]
49
- A12 = A[m, n + Nh ]
50
- A22 = A[m + Mh , n + Nh ]
61
+ A21 = A[m + M_half , n]
62
+ A12 = A[m, n + N_half ]
63
+ A22 = A[m + M_half , n + N_half ]
51
64
52
65
T1 = A11 + A12
53
66
T2 = A21 + A22
@@ -59,32 +72,32 @@ function 🦋mul_level!(A, u, v)
59
72
C22 = T3 - T4
60
73
61
74
u1 = u[m]
62
- u2 = u[m + Mh ]
75
+ u2 = u[m + M_half ]
63
76
v1 = v[n]
64
- v2 = v[n + Nh ]
77
+ v2 = v[n + N_half ]
65
78
66
79
A[m, n] = u1 * C11 * v1
67
- A[m + Mh , n] = u2 * C21 * v1
68
- A[m, n + Nh ] = u1 * C12 * v2
69
- A[m + Mh , n + Nh ] = u2 * C22 * v2
80
+ A[m + M_half , n] = u2 * C21 * v1
81
+ A[m, n + N_half ] = u1 * C12 * v2
82
+ A[m + M_half , n + N_half ] = u2 * C22 * v2
70
83
end
71
84
end
72
85
end
73
86
74
- function 🦋mul! (A, (uv,) )
87
+ function 🦋mul! (A, uv )
75
88
M, N = size (A)
76
89
@assert M == N
77
- Mh = M >>> 1
90
+ M_half = M >>> 1
78
91
79
- U₁ = @view (uv[1 : Mh ])
80
- V₁ = @view (uv[(Mh + 1 ): (M)])
81
- U₂ = @view (uv[(1 + M): (M + Mh )])
82
- V₂ = @view (uv[(1 + M + Mh ): (2 * M)])
92
+ U₁ = @view (uv[1 : M_half ])
93
+ V₁ = @view (uv[(M_half + 1 ): (M)])
94
+ U₂ = @view (uv[(1 + M): (M + M_half )])
95
+ V₂ = @view (uv[(1 + M + M_half ): (2 * M)])
83
96
84
- 🦋mul_level! (@view (A[1 : Mh , 1 : Mh ]), U₁, V₁)
85
- 🦋mul_level! (@view (A[Mh + 1 : M, 1 : Mh ]), U₂, V₁)
86
- 🦋mul_level! (@view (A[1 : Mh, Mh + 1 : M]), U₁, V₂)
87
- 🦋mul_level! (@view (A[Mh + 1 : M, Mh + 1 : M]), U₂, V₂)
97
+ 🦋mul_level! (@view (A[1 : M_half , 1 : M_half ]), U₁, V₁)
98
+ 🦋mul_level! (@view (A[M_half + 1 : M, 1 : M_half ]), U₂, V₁)
99
+ 🦋mul_level! (@view (A[1 : M_half, M_half + 1 : M]), U₁, V₂)
100
+ 🦋mul_level! (@view (A[M_half + 1 : M, M_half + 1 : M]), U₂, V₂)
88
101
89
102
U = @view (uv[(1 + 2 * M): (3 * M)])
90
103
V = @view (uv[(1 + 3 * M): (4 * M)])
@@ -106,7 +119,14 @@ function diagnegbottom(x)
106
119
Diagonal (y), Diagonal (z)
107
120
end
108
121
109
- function 🦋2 !(C, A:: Diagonal , B:: Diagonal )
122
+ function 🦋! (C:: SparseBandedMatrix , A:: Diagonal , B:: Diagonal )
123
+ setdiagonal! (C, [A. diag; - B. diag], true )
124
+ setdiagonal! (C, A. diag, true )
125
+ setdiagonal! (C, B. diag, false )
126
+ C
127
+ end
128
+
129
+ function 🦋! (C, A:: Diagonal , B:: Diagonal )
110
130
@assert size (A) == size (B)
111
131
A1 = size (A, 1 )
112
132
@@ -120,61 +140,35 @@ function 🦋2!(C, A::Diagonal, B::Diagonal)
120
140
C
121
141
end
122
142
123
- function 🦋! (A:: Matrix , C:: SparseBandedMatrix , X:: Diagonal , Y:: Diagonal )
124
- @assert size (X) == size (Y)
125
- if (size (X, 1 ) + size (Y, 1 ) != size (A, 1 ))
126
- x = size (A, 1 ) - size (X, 1 ) - size (Y, 1 )
127
- setdiagonal! (C, [X. diag; rand (x); - Y. diag], true )
128
- setdiagonal! (C, X. diag, true )
129
- setdiagonal! (C, Y. diag, false )
130
- else
131
- setdiagonal! (C, [X. diag; - Y. diag], true )
132
- setdiagonal! (C, X. diag, true )
133
- setdiagonal! (C, Y. diag, false )
134
- end
135
-
136
- C
137
- end
138
-
139
- function 🦋2 !(C:: SparseBandedMatrix , A:: Diagonal , B:: Diagonal )
140
- setdiagonal! (C, [A. diag; - B. diag], true )
141
- setdiagonal! (C, A. diag, true )
142
- setdiagonal! (C, B. diag, false )
143
- C
144
- end
145
-
146
- function materializeUV (A, (uv,))
147
- M, N = size (A)
148
- Mh = M >>> 1
149
- Nh = N >>> 1
143
+ function materializeUV (U, V, uv)
144
+ M = size (U, 1 )
145
+ M_half = M >>> 1
150
146
151
- U₁u, U₁l = diagnegbottom (@view (uv[1 : Mh ])) # Mh
152
- U₂u, U₂l = diagnegbottom (@view (uv[(1 + Mh + Nh ): (M + Nh )])) # M2
153
- V₁u, V₁l = diagnegbottom (@view (uv[(Mh + 1 ): (Mh + Nh )])) # Nh
154
- V₂u, V₂l = diagnegbottom (@view (uv[(1 + 2 * Mh + Nh ): (2 * Mh + N )])) # N2
155
- Uu, Ul = diagnegbottom (@view (uv[(1 + M + N ): (2 * M + N )])) # M
156
- Vu, Vl = diagnegbottom (@view (uv[(1 + 2 * M + N ): (2 * M + 2 * N )])) # N
147
+ U₁u, U₁l = diagnegbottom (@view (uv[1 : M_half ])) # M_half
148
+ U₂u, U₂l = diagnegbottom (@view (uv[(1 + 2 * M_half ): (M + M_half )])) # M_half
149
+ V₁u, V₁l = diagnegbottom (@view (uv[(M_half + 1 ): (2 * M_half )])) # M_half
150
+ V₂u, V₂l = diagnegbottom (@view (uv[(1 + 3 * M_half ): (2 * M_half + M )])) # M_half
151
+ Uu, Ul = diagnegbottom (@view (uv[(1 + 2 * M ): (3 * M)])) # M
152
+ Vu, Vl = diagnegbottom (@view (uv[(1 + 3 * M): (4 * M)])) # M
157
153
158
- Bu2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N )
154
+ Bu2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, M )
159
155
160
- 🦋2 !(view (Bu2, 1 : Mh , 1 : Nh ), U₁u, U₁l)
161
- 🦋2 !(view (Bu2, Mh + 1 : M, Nh + 1 : N ), U₂u, U₂l)
156
+ 🦋! (view (Bu2, 1 : M_half , 1 : M_half ), U₁u, U₁l)
157
+ 🦋! (view (Bu2, M_half + 1 : M, M_half + 1 : M ), U₂u, U₂l)
162
158
163
- Bu1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N )
164
- 🦋! (A, Bu1, Uu, Ul)
159
+ Bu1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, M )
160
+ 🦋! (Bu1, Uu, Ul)
165
161
166
- Bv2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N )
162
+ Bv2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, M )
167
163
168
- 🦋2 !(view (Bv2, 1 : Mh , 1 : Nh ), V₁u, V₁l)
169
- 🦋2 !(view (Bv2, Mh + 1 : M, Nh + 1 : N ), V₂u, V₂l)
164
+ 🦋! (view (Bv2, 1 : M_half , 1 : M_half ), V₁u, V₁l)
165
+ 🦋! (view (Bv2, M_half + 1 : M, M_half + 1 : M ), V₂u, V₂l)
170
166
171
- Bv1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N )
172
- 🦋! (A, Bv1, Vu, Vl)
167
+ Bv1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, M )
168
+ 🦋! (Bv1, Vu, Vl)
173
169
174
- U = (Bu2 * Bu1)'
175
- V = Bv2 * Bv1
176
-
177
- U, V
170
+ mul! (U, Bu2, Bu1)
171
+ mul! (V, Bv2, Bv1)
178
172
end
179
173
180
174
function pad! (A)
0 commit comments