@@ -2,6 +2,193 @@ using VectorizedRNG
2
2
using LinearAlgebra: Diagonal, I
3
3
using LoopVectorization
4
4
using RecursiveFactorization
5
+ using SparseArrays
6
+
7
+ struct SparseBandedMatrix{T} <: AbstractMatrix{T}
8
+ size :: Tuple{Int, Int}
9
+ indices :: Vector{Int}
10
+ diags :: Vector{Vector{T}}
11
+ function SparseBandedMatrix {T} (:: UndefInitializer , N, M) where T
12
+ size = (N, M)
13
+ indices = Int[]
14
+ diags = Vector{T}[]
15
+ new (size, indices, diags)
16
+ end
17
+ function SparseBandedMatrix {T} (ind_vals, diag_vals, N, M) where T
18
+ size = (N, M)
19
+ perm = sortperm (ind_vals)
20
+ indices = ind_vals[perm]
21
+ for i in 1 : length (indices) - 1
22
+ @assert indices[i] != indices[i + 1 ]
23
+ end
24
+ diags = diag_vals[perm]
25
+ new (size, indices, diags)
26
+ end
27
+ end
28
+
29
+ function Base. size (M :: SparseBandedMatrix )
30
+ M. size
31
+ end
32
+
33
+ function Base. getindex (M :: SparseBandedMatrix{T} , i :: Int , j :: Int , I :: Int... ) where T
34
+ @boundscheck checkbounds (M, i, j, I... )
35
+ rows, cols = size (M)
36
+ wanted_ind = rows - i + j
37
+ ind = searchsortedfirst (M. indices, wanted_ind)
38
+ if (ind <= length (M. indices) && M. indices[ind] == wanted_ind)
39
+ if (i > j)
40
+ return M. diags[ind][j]
41
+ else
42
+ return M. diags[ind][i]
43
+ end
44
+ end
45
+ zero (T)
46
+ end
47
+
48
+ function Base. setindex! (M :: SparseBandedMatrix{T} , val, i :: Int , j :: Int , I :: Int... ) where T # TODO IF VAL ISNT OF TYPE T
49
+ @boundscheck checkbounds (M, i, j, I... )
50
+ rows, cols = size (M)
51
+ wanted_ind = rows - i + j
52
+ ind = searchsortedfirst (M. indices, wanted_ind)
53
+ if (ind > length (M. indices) || M. indices[ind] != wanted_ind)
54
+ insert! (M. indices, ind, wanted_ind)
55
+ insert! (M. diags, ind, zeros (T, rows - abs (wanted_ind - rows)))
56
+ end
57
+ if (i > j)
58
+ M. diags[ind][j] = val isa T ? val : convert (T, val):: T
59
+ else
60
+ M. diags[ind][i] = val isa T ? val : convert (T, val):: T
61
+ end
62
+ val
63
+ end
64
+
65
+ function setdiagonal! (M :: SparseBandedMatrix{T} , diagvals, lower :: Bool ) where T
66
+ rows, cols = size (M)
67
+ if length (diagvals) > rows
68
+ error (" size of diagonal is too big for the matrix" )
69
+ end
70
+ if lower
71
+ wanted_ind = length (diagvals)
72
+ else
73
+ wanted_ind = 2 * rows - length (diagvals)
74
+ end
75
+
76
+ ind = searchsortedfirst (M. indices, wanted_ind)
77
+ if (ind > length (M. indices) || M. indices[ind] != wanted_ind)
78
+ insert! (M. indices, ind, wanted_ind)
79
+ insert! (M. diags, ind, diagvals isa Vector{T} ? diagvals : convert (Vector{T}, diagvals):: Vector{T} )
80
+ else
81
+ for i in 1 : eachindex (diagvals)
82
+ M. diags[ind][i] = diagvals[i] isa T ? diagvals[i] : convert (T, diagvals[i]):: T
83
+ end
84
+ end
85
+ diagvals
86
+ end
87
+
88
+ using LinearAlgebra
89
+ using . Threads
90
+
91
+ # C = Cb + aAB
92
+ function LinearAlgebra. mul! (C :: Matrix{T} , A:: SparseBandedMatrix{T} , B :: Matrix{T} , a :: Number , b :: Number ) where T
93
+ @assert size (A, 2 ) == size (B, 1 )
94
+ @assert size (A, 1 ) == size (C, 1 )
95
+ @assert size (B, 2 ) == size (C, 2 )
96
+ C.*= b
97
+
98
+ rows, cols = size (A)
99
+ @inbounds for (ind, location) in enumerate (A. indices)
100
+ @threads for i in 1 : length (A. diags[ind])
101
+ # value: diag[i]
102
+ # index in array:
103
+ # if ind < rows(A), then index = (rows - loc + i, i)
104
+ # else index = (i, loc - cols + i)
105
+ val = A. diags[ind][i] * a
106
+ if location < rows
107
+ index_i = rows - location + i
108
+ index_j = i
109
+ else
110
+ index_i = i
111
+ index_j = location - cols + i
112
+ end
113
+ # A[index_i, index_j] * B[index_j, j] = C[index_i, j]
114
+ @simd for j in 1 : size (B, 2 )
115
+ C[index_i, j] = fma (val, B[index_j, j], C[index_i, j])
116
+ end
117
+ end
118
+ end
119
+ C
120
+ end
121
+
122
+ # C = Cb + aBA
123
+ function LinearAlgebra. mul! (C :: Matrix{T} , A:: Matrix{T} , B :: SparseBandedMatrix{T} , a :: Number , b :: Number ) where T
124
+ @assert size (A, 2 ) == size (B, 1 )
125
+ @assert size (A, 1 ) == size (C, 1 )
126
+ @assert size (B, 2 ) == size (C, 2 )
127
+
128
+ C.*= b
129
+
130
+ rows, cols = size (B)
131
+ @inbounds for (ind, location) in enumerate (B. indices)
132
+ @threads for i in eachindex (B. diags[ind])
133
+ val = B. diags[ind][i] * a
134
+ if location < rows
135
+ index_i = rows - location + i
136
+ index_j = i
137
+ else
138
+ index_i = i
139
+ index_j = location - cols + i
140
+ end
141
+ @simd for j in 1 : size (A, 1 )
142
+ C[j, index_j] = fma (val, A[j, index_i], C[j, index_j])
143
+ end
144
+ end
145
+ end
146
+ C
147
+ end
148
+
149
+ function LinearAlgebra. mul! (C :: Matrix{T} , A:: SparseBandedMatrix{T} , B :: SparseBandedMatrix{T} , a :: Number , b :: Number ) where T
150
+ @assert size (A, 2 ) == size (B, 1 )
151
+ @assert size (A, 1 ) == size (C, 1 )
152
+ @assert size (B, 2 ) == size (C, 2 )
153
+
154
+ C.*= b
155
+
156
+ rows_a, cols_a = size (A)
157
+ rows_b, cols_b = size (B)
158
+ @inbounds for (ind_a, location_a) in enumerate (A. indices)
159
+ @threads for i in eachindex (A. diags[ind_a])
160
+ val_a = A. diags[ind_a][i] * a
161
+ if location_a < rows_a
162
+ index_ia = rows_a - location_a + i
163
+ index_ja = i
164
+ else
165
+ index_ia = i
166
+ index_ja = location_a - cols_a + i
167
+ end
168
+ min_loc = rows_b - index_ja + 1
169
+ max_loc = 2 * rows_b - index_ja
170
+ for (ind_b, location_b) in enumerate (B. indices)
171
+ # index_ib = index_ja
172
+ # if ind < rows(A), then index = (rows - loc + i, i)
173
+ # rows - loc + j = index_ja, j = index_ja - rows + loc
174
+ # else index = (i, loc - cols + i)
175
+ # if location < rows(B), then
176
+ if location_b <= rows_b && location_b >= min_loc
177
+ j = index_ja - rows_b + location_b
178
+ index_jb = j
179
+ val_b = B. diags[ind_b][j]
180
+ C[index_ia, index_jb] = muladd (val_a, val_b, C[index_ia, index_jb])
181
+ elseif location_b > rows_b && location_b <= max_loc
182
+ j = index_ja
183
+ index_jb = location_b - cols_b + j
184
+ val_b = B. diags[ind_b][j]
185
+ C[index_ia, index_jb] = muladd (val_a, val_b, C[index_ia, index_jb])
186
+ end
187
+ end
188
+ end
189
+ end
190
+ C
191
+ end
5
192
6
193
@inline exphalf (x) = exp (x) * oftype (x, 0.5 )
7
194
function 🦋! (wv, :: Val{SEED} = Val (888 )) where {SEED}
32
219
const butterfly_workspace = 🦋workspace;
33
220
34
221
function 🦋mul_level! (A, u, v)
35
- # for now, assume...
36
222
M, N = size (A)
37
223
Ml = M >>> 1
38
224
Nl = N >>> 1
@@ -97,8 +283,53 @@ function diagnegbottom(x)
97
283
end
98
284
Diagonal (y), Diagonal (z)
99
285
end
100
- 🦋(A, B) = [A B
101
- A - B]
286
+
287
+ # 🦋(A, B) = [A B
288
+ # A -B]
289
+
290
+ # Bu2 = [🦋(U₁u, U₁l) 0*I
291
+ # 0*I 🦋(U₂u, U₂l)]
292
+ # U1u U1l 0 0
293
+ # U1u -U1l 0 0
294
+ #=
295
+ function 🦋!(C, A, B)
296
+ A1, A2 = size(A)
297
+ B1, B2 = size(B)
298
+ @assert A1 == B1
299
+ for j in 1 : A2, i in 1 : A1
300
+ C[i, j] = A[i, j]
301
+ C[i + A1, j] = A[i, j]
302
+ end
303
+ for j in A2 + 1 : A2 + B2, i in 1 : A1
304
+ C[i, j] = B[i, j - A2]
305
+ C[i + A1, j] = -B[i, j - A2]
306
+ end
307
+ C
308
+ end
309
+ =#
310
+ function 🦋! (C, A:: Diagonal , B:: Diagonal )
311
+ @assert size (A) == size (B)
312
+ A1 = size (A, 1 )
313
+
314
+ for i in 1 : A1
315
+ C[i, i] = A[i, i]
316
+ C[i + A1, i] = A[i, i]
317
+ C[i, i + A1] = B[i, i]
318
+ C[i + A1, i + A1] = - B[i, i]
319
+ end
320
+
321
+ C
322
+ end
323
+
324
+ function 🦋! (C:: SparseBandedMatrix , A:: Diagonal , B:: Diagonal )
325
+ @assert size (A) == size (B)
326
+
327
+ setdiagonal! (C, [A. diag; - B. diag], true )
328
+ setdiagonal! (C, A. diag, true )
329
+ setdiagonal! (C, B. diag, false )
330
+ C
331
+ end
332
+
102
333
103
334
function materializeUV (A, (uv,))
104
335
M, N = size (A)
@@ -112,13 +343,31 @@ function materializeUV(A, (uv,))
112
343
Uu, Ul = diagnegbottom (@view (uv[(1 + 2 * Mh + 2 * Nh): (2 * Mh + 2 * Nh + M)]))
113
344
Vu, Vl = diagnegbottom (@view (uv[(1 + 2 * Mh + 2 * Nh + M): (2 * Mh + 2 * Nh + M + N)]))
114
345
115
- Bu2 = [🦋(U₁u, U₁l) 0 * I
116
- 0 * I 🦋(U₂u, U₂l)]
117
- Bu1 = 🦋(Uu, Ul)
346
+ # WRITE OUT MERGINGS EXPLICITLY
347
+ # Bu2 = [🦋(U₁u, U₁l) 0*I
348
+ # 0*I 🦋(U₂u, U₂l)]
349
+ # show size(Bu2)[1] #808
350
+ # @show size(🦋(V₁u, V₁l))[1] #404
351
+
352
+ # Bu2 = spzeros(M, N)
353
+ Bu2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
354
+
355
+ 🦋! (view (Bu2, 1 : M ÷ 2 , 1 : N ÷ 2 ), U₁u, U₁l)
356
+ 🦋! (view (Bu2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), U₂u, U₂l)
357
+
358
+ # Bu1 = spzeros(M, N)
359
+ Bu1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
360
+ 🦋! (Bu1, Uu, Ul)
361
+
362
+ # Bv2 = spzeros(M, N)
363
+ Bv2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
364
+
365
+ 🦋! (view (Bv2, 1 : M ÷ 2 , 1 : N ÷ 2 ), V₁u, V₁l)
366
+ 🦋! (view (Bv2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), V₂u, V₂l)
118
367
119
- Bv2 = [🦋(V₁u, V₁l) 0 * I
120
- 0 * I 🦋(V₂u, V₂l)]
121
- Bv1 = 🦋( Vu, Vl)
368
+ # Bv1 = spzeros(M, N)
369
+ Bv1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
370
+ 🦋 ! (Bv1, Vu, Vl)
122
371
123
372
(Bu2 * Bu1)' , Bv2 * Bv1, A
124
373
end
@@ -136,7 +385,7 @@ function pad!(A)
136
385
@inbounds A_new[j, i] = 0
137
386
end
138
387
139
- for i in M + 1 : M + xn, j in N + 1 : N + xn
388
+ for j in N + 1 : N + xn, i in M + 1 : M + xn
140
389
@inbounds A_new[i,j] = i == j
141
390
end
142
391
A_new
0 commit comments