@@ -2,6 +2,193 @@ using VectorizedRNG
22using LinearAlgebra: Diagonal, I
33using LoopVectorization
44using RecursiveFactorization
5+ using SparseArrays
6+
"""
    SparseBandedMatrix{T} <: AbstractMatrix{T}

Matrix that stores only an explicit set of diagonals.

# Fields
- `size`: the `(rows, cols)` tuple.
- `indices`: sorted, duplicate-free diagonal identifiers. The accessors in this file
  map element `(i, j)` to the identifier `rows - i + j`.
- `diags`: `diags[k]` holds the stored values of the diagonal `indices[k]`.

# Constructors
- `SparseBandedMatrix{T}(undef, N, M)`: an `N × M` matrix with no stored diagonals.
- `SparseBandedMatrix{T}(ind_vals, diag_vals, N, M)`: an `N × M` matrix whose diagonal
  `ind_vals[k]` holds `diag_vals[k]`. The pairs are sorted by identifier.

# Throws
- `ArgumentError` if `ind_vals` and `diag_vals` differ in length, or if `ind_vals`
  contains the same identifier twice.
"""
struct SparseBandedMatrix{T} <: AbstractMatrix{T}
    size::Tuple{Int,Int}
    indices::Vector{Int}
    diags::Vector{Vector{T}}

    function SparseBandedMatrix{T}(::UndefInitializer, N, M) where {T}
        return new{T}((N, M), Int[], Vector{T}[])
    end

    function SparseBandedMatrix{T}(ind_vals, diag_vals, N, M) where {T}
        # Explicit checks instead of `@assert`: assertions may be compiled out, and a
        # length mismatch would otherwise drop diagonals silently or fail obscurely.
        if length(ind_vals) != length(diag_vals)
            throw(ArgumentError(
                "got $(length(ind_vals)) diagonal indices but " *
                "$(length(diag_vals)) diagonals",
            ))
        end
        perm = sortperm(ind_vals)
        indices = ind_vals[perm]
        # After sorting, duplicates are adjacent.
        for k in 1:(length(indices) - 1)
            if indices[k] == indices[k + 1]
                throw(ArgumentError("duplicate diagonal index $(indices[k])"))
            end
        end
        return new{T}((N, M), indices, diag_vals[perm])
    end
end
28+
"""
    size(M::SparseBandedMatrix)

Return the `(rows, cols)` tuple stored in the `size` field of `M`.
"""
Base.size(M::SparseBandedMatrix) = M.size
32+
"""
    getindex(M::SparseBandedMatrix{T}, i::Int, j::Int, I::Int...)

Return the element at `(i, j)`, or `zero(T)` when its diagonal is not stored.

The diagonal is looked up under the identifier `size(M, 1) - i + j`; within a stored
diagonal the element sits at position `min(i, j)`.
"""
function Base.getindex(M::SparseBandedMatrix{T}, i::Int, j::Int, I::Int...) where T
    @boundscheck checkbounds(M, i, j, I...)
    nrows = size(M, 1)
    key = nrows - i + j
    # `indices` is kept sorted, so a binary search finds the candidate slot.
    slot = searchsortedfirst(M.indices, key)
    stored = slot <= length(M.indices) && M.indices[slot] == key
    return stored ? M.diags[slot][min(i, j)] : zero(T)
end
47+
"""
    setindex!(M::SparseBandedMatrix{T}, val, i::Int, j::Int, I::Int...)

Store `val` (converted to `T` when it is not already a `T`) at `(i, j)` and return `val`.

If the diagonal with identifier `size(M, 1) - i + j` is not stored yet, it is first
inserted as a zero-filled vector.
"""
function Base.setindex!(M::SparseBandedMatrix{T}, val, i::Int, j::Int, I::Int...) where T
    @boundscheck checkbounds(M, i, j, I...)
    nrows = size(M, 1)
    key = nrows - i + j
    slot = searchsortedfirst(M.indices, key)
    absent = slot > length(M.indices) || M.indices[slot] != key
    if absent
        # Inserting at `slot` keeps `indices` sorted and `diags` aligned with it.
        insert!(M.indices, slot, key)
        insert!(M.diags, slot, zeros(T, nrows - abs(key - nrows)))
    end
    M.diags[slot][min(i, j)] = val isa T ? val : convert(T, val)::T
    return val
end
64+
"""
    setdiagonal!(M::SparseBandedMatrix{T}, diagvals, lower::Bool)

Store `diagvals` as one whole diagonal of `M` and return `diagvals`.

The diagonal is selected by the length of `diagvals`. With `rows = size(M, 1)` and
`n = length(diagvals)`, the values are stored under identifier `n` when `lower` is
`true` and under `2 * rows - n` otherwise.

If that diagonal is already stored, its leading entries are overwritten element by
element. Otherwise a new `Vector{T}` copy of `diagvals` is inserted, so later writes
into `M` never reach the caller's vector.

# Throws
- `ErrorException` if `length(diagvals) > size(M, 1)`.
"""
function setdiagonal!(M::SparseBandedMatrix{T}, diagvals, lower::Bool) where T
    rows = size(M, 1)
    n = length(diagvals)
    if n > rows
        error(" size of diagonal is too big for the matrix")
    end
    wanted_ind = lower ? n : 2 * rows - n

    ind = searchsortedfirst(M.indices, wanted_ind)
    if ind > length(M.indices) || M.indices[ind] != wanted_ind
        insert!(M.indices, ind, wanted_ind)
        # `Vector{T}(...)` always allocates, so the stored diagonal is not aliased.
        insert!(M.diags, ind, Vector{T}(diagvals))
    else
        # The previous `1:eachindex(diagvals)` loop threw a MethodError here;
        # `copyto!` performs the intended element-wise, converting overwrite.
        copyto!(M.diags[ind], diagvals)
    end
    return diagvals
end
87+
88+ using LinearAlgebra
89+ using . Threads
90+
"""
    mul!(C::Matrix{T}, A::SparseBandedMatrix{T}, B::Matrix{T}, a::Number, b::Number)

Update `C` in place to `A * B * a + C * b` and return `C`.

When `b` is zero the previous contents of `C` are discarded rather than scaled, so
NaN or Inf entries already in `C` do not leak into the result.

# Throws
- `DimensionMismatch` if the sizes of `A`, `B` and `C` are incompatible.
"""
function LinearAlgebra.mul!(
    C::Matrix{T},
    A::SparseBandedMatrix{T},
    B::Matrix{T},
    a::Number,
    b::Number,
) where T
    # Real checks rather than `@assert`: the loops below run under `@inbounds`.
    size(A, 2) == size(B, 1) || throw(DimensionMismatch(
        "A has $(size(A, 2)) columns but B has $(size(B, 1)) rows",
    ))
    size(A, 1) == size(C, 1) || throw(DimensionMismatch(
        "A has $(size(A, 1)) rows but C has $(size(C, 1)) rows",
    ))
    size(B, 2) == size(C, 2) || throw(DimensionMismatch(
        "B has $(size(B, 2)) columns but C has $(size(C, 2)) columns",
    ))

    if iszero(b)
        fill!(C, zero(T))
    else
        C .*= b
    end

    rows, cols = size(A)
    # NOTE(review): the index arithmetic below mixes `rows` and `cols`; it looks like
    # it assumes a square `A` — confirm before using with rectangular matrices.
    @inbounds for (ind, location) in enumerate(A.indices)
        diag = A.diags[ind]
        # Within one diagonal every `i` maps to a different row of `C`, so the
        # threaded iterations write to disjoint rows.
        Threads.@threads for i in eachindex(diag)
            val = diag[i] * a
            # Position (index_i, index_j) in A of the i-th entry of this diagonal.
            if location < rows
                index_i = rows - location + i
                index_j = i
            else
                index_i = i
                index_j = location - cols + i
            end
            # C[index_i, j] += A[index_i, index_j] * a * B[index_j, j]
            @simd for j in axes(B, 2)
                C[index_i, j] = fma(val, B[index_j, j], C[index_i, j])
            end
        end
    end
    return C
end
121+
"""
    mul!(C::Matrix{T}, A::Matrix{T}, B::SparseBandedMatrix{T}, a::Number, b::Number)

Update `C` in place to `A * B * a + C * b` (dense `A`, banded `B`) and return `C`.

When `b` is zero the previous contents of `C` are discarded rather than scaled, so
NaN or Inf entries already in `C` do not leak into the result.

# Throws
- `DimensionMismatch` if the sizes of `A`, `B` and `C` are incompatible.
"""
function LinearAlgebra.mul!(
    C::Matrix{T},
    A::Matrix{T},
    B::SparseBandedMatrix{T},
    a::Number,
    b::Number,
) where T
    # Real checks rather than `@assert`: the loops below run under `@inbounds`.
    size(A, 2) == size(B, 1) || throw(DimensionMismatch(
        "A has $(size(A, 2)) columns but B has $(size(B, 1)) rows",
    ))
    size(A, 1) == size(C, 1) || throw(DimensionMismatch(
        "A has $(size(A, 1)) rows but C has $(size(C, 1)) rows",
    ))
    size(B, 2) == size(C, 2) || throw(DimensionMismatch(
        "B has $(size(B, 2)) columns but C has $(size(C, 2)) columns",
    ))

    if iszero(b)
        fill!(C, zero(T))
    else
        C .*= b
    end

    rows, cols = size(B)
    # NOTE(review): the index arithmetic below mixes `rows` and `cols`; it looks like
    # it assumes a square `B` — confirm before using with rectangular matrices.
    @inbounds for (ind, location) in enumerate(B.indices)
        diag = B.diags[ind]
        # Within one diagonal every `i` maps to a different column of `C`, so the
        # threaded iterations write to disjoint columns.
        Threads.@threads for i in eachindex(diag)
            val = diag[i] * a
            # Position (index_i, index_j) in B of the i-th entry of this diagonal.
            if location < rows
                index_i = rows - location + i
                index_j = i
            else
                index_i = i
                index_j = location - cols + i
            end
            # C[j, index_j] += A[j, index_i] * B[index_i, index_j] * a
            @simd for j in axes(A, 1)
                C[j, index_j] = fma(val, A[j, index_i], C[j, index_j])
            end
        end
    end
    return C
end
148+
"""
    mul!(C::Matrix{T}, A::SparseBandedMatrix{T}, B::SparseBandedMatrix{T},
         a::Number, b::Number)

Update the dense matrix `C` in place to `A * B * a + C * b` for two banded operands
and return `C`.

When `b` is zero the previous contents of `C` are discarded rather than scaled, so
NaN or Inf entries already in `C` do not leak into the result.

# Throws
- `DimensionMismatch` if the sizes of `A`, `B` and `C` are incompatible.
"""
function LinearAlgebra.mul!(
    C::Matrix{T},
    A::SparseBandedMatrix{T},
    B::SparseBandedMatrix{T},
    a::Number,
    b::Number,
) where T
    # Real checks rather than `@assert`: the loops below run under `@inbounds`.
    size(A, 2) == size(B, 1) || throw(DimensionMismatch(
        "A has $(size(A, 2)) columns but B has $(size(B, 1)) rows",
    ))
    size(A, 1) == size(C, 1) || throw(DimensionMismatch(
        "A has $(size(A, 1)) rows but C has $(size(C, 1)) rows",
    ))
    size(B, 2) == size(C, 2) || throw(DimensionMismatch(
        "B has $(size(B, 2)) columns but C has $(size(C, 2)) columns",
    ))

    if iszero(b)
        fill!(C, zero(T))
    else
        C .*= b
    end

    rows_a, cols_a = size(A)
    rows_b, cols_b = size(B)
    # NOTE(review): the index arithmetic below mixes row and column counts; it looks
    # like it assumes square operands — confirm before using rectangular matrices.
    @inbounds for (ind_a, location_a) in enumerate(A.indices)
        diag_a = A.diags[ind_a]
        # Within one diagonal of A every `i` maps to a different row of `C`, so the
        # threaded iterations write to disjoint rows.
        Threads.@threads for i in eachindex(diag_a)
            val_a = diag_a[i] * a
            # Position (index_ia, index_ja) in A of the i-th entry of this diagonal.
            if location_a < rows_a
                index_ia = rows_a - location_a + i
                index_ja = i
            else
                index_ia = i
                index_ja = location_a - cols_a + i
            end
            # Only diagonals of B that have an entry in row `index_ja` contribute;
            # their identifiers lie in `min_loc:max_loc`.
            min_loc = rows_b - index_ja + 1
            max_loc = 2 * rows_b - index_ja
            for (ind_b, location_b) in enumerate(B.indices)
                if min_loc <= location_b <= rows_b
                    # Main or lower diagonal: entry k sits at (rows_b - loc + k, k).
                    j = index_ja - rows_b + location_b
                    index_jb = j
                elseif rows_b < location_b <= max_loc
                    # Upper diagonal: entry k sits at (k, loc - cols_b + k).
                    j = index_ja
                    index_jb = location_b - cols_b + j
                else
                    continue
                end
                val_b = B.diags[ind_b][j]
                C[index_ia, index_jb] = muladd(val_a, val_b, C[index_ia, index_jb])
            end
        end
    end
    return C
end
5192
"""
    exphalf(x)

Return `exp(x)` multiplied by `0.5` converted to the type of `x`.
"""
@inline function exphalf(x)
    ex = exp(x)
    return ex * oftype(x, 0.5)
end
7194function 🦋! (wv, :: Val{SEED} = Val (888 )) where {SEED}
32219const butterfly_workspace = 🦋workspace;
33220
34221function 🦋mul_level! (A, u, v)
35- # for now, assume...
36222 M, N = size (A)
37223 Ml = M >>> 1
38224 Nl = N >>> 1
@@ -97,8 +283,53 @@ function diagnegbottom(x)
97283 end
98284 Diagonal (y), Diagonal (z)
99285end
100- 🦋(A, B) = [A B
101- A - B]
286+
287+ # 🦋(A, B) = [A B
288+ # A -B]
289+
290+ # Bu2 = [🦋(U₁u, U₁l) 0*I
291+ # 0*I 🦋(U₂u, U₂l)]
292+ # U1u U1l 0 0
293+ # U1u -U1l 0 0
294+ #=
295+ function 🦋!(C, A, B)
296+ A1, A2 = size(A)
297+ B1, B2 = size(B)
298+ @assert A1 == B1
299+ for j in 1 : A2, i in 1 : A1
300+ C[i, j] = A[i, j]
301+ C[i + A1, j] = A[i, j]
302+ end
303+ for j in A2 + 1 : A2 + B2, i in 1 : A1
304+ C[i, j] = B[i, j - A2]
305+ C[i + A1, j] = -B[i, j - A2]
306+ end
307+ C
308+ end
309+ =#
"""
    🦋!(C, A::Diagonal, B::Diagonal)

Write the diagonal entries of the block layout `[A B; A -B]` into `C` and return `C`.

Only the four positions per diagonal index are assigned; every other entry of `C` is
left as it was. `A` and `B` must have the same size.
"""
function 🦋!(C, A::Diagonal, B::Diagonal)
    @assert size(A) == size(B)
    n = size(A, 1)

    for k in 1:n
        shifted = k + n
        a = A[k, k]
        b = B[k, k]
        C[k, k] = a
        C[shifted, k] = a
        C[k, shifted] = b
        C[shifted, shifted] = -b
    end

    return C
end
323+
"""
    🦋!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)

Fill `C` with three diagonals through `setdiagonal!` and return `C`.

The calls store, in order: `[A.diag; -B.diag]` as a lower diagonal, `A.diag` as a
lower diagonal, and `B.diag` as an upper diagonal. `A` and `B` must have the same size.
"""
function 🦋!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
    @assert size(A) == size(B)

    avals = A.diag
    bvals = B.diag
    stacked = vcat(avals, -bvals)

    setdiagonal!(C, stacked, true)
    setdiagonal!(C, avals, true)
    setdiagonal!(C, bvals, false)
    return C
end
332+
102333
103334function materializeUV (A, (uv,))
104335 M, N = size (A)
@@ -112,13 +343,31 @@ function materializeUV(A, (uv,))
112343 Uu, Ul = diagnegbottom (@view (uv[(1 + 2 * Mh + 2 * Nh): (2 * Mh + 2 * Nh + M)]))
113344 Vu, Vl = diagnegbottom (@view (uv[(1 + 2 * Mh + 2 * Nh + M): (2 * Mh + 2 * Nh + M + N)]))
114345
115- Bu2 = [🦋(U₁u, U₁l) 0 * I
116- 0 * I 🦋(U₂u, U₂l)]
117- Bu1 = 🦋(Uu, Ul)
346+ # WRITE OUT MERGINGS EXPLICITLY
347+ # Bu2 = [🦋(U₁u, U₁l) 0*I
348+ # 0*I 🦋(U₂u, U₂l)]
349+ # show size(Bu2)[1] #808
350+ # @show size(🦋(V₁u, V₁l))[1] #404
351+
352+ # Bu2 = spzeros(M, N)
353+ Bu2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
354+
355+ 🦋! (view (Bu2, 1 : M ÷ 2 , 1 : N ÷ 2 ), U₁u, U₁l)
356+ 🦋! (view (Bu2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), U₂u, U₂l)
357+
358+ # Bu1 = spzeros(M, N)
359+ Bu1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
360+ 🦋! (Bu1, Uu, Ul)
361+
362+ # Bv2 = spzeros(M, N)
363+ Bv2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
364+
365+ 🦋! (view (Bv2, 1 : M ÷ 2 , 1 : N ÷ 2 ), V₁u, V₁l)
366+ 🦋! (view (Bv2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), V₂u, V₂l)
118367
119- Bv2 = [🦋(V₁u, V₁l) 0 * I
120- 0 * I 🦋(V₂u, V₂l)]
121- Bv1 = 🦋( Vu, Vl)
368+ # Bv1 = spzeros(M, N)
369+ Bv1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
370+ 🦋 ! (Bv1, Vu, Vl)
122371
123372 (Bu2 * Bu1)' , Bv2 * Bv1, A
124373end
@@ -136,7 +385,7 @@ function pad!(A)
136385 @inbounds A_new[j, i] = 0
137386 end
138387
139- for i in M + 1 : M + xn, j in N + 1 : N + xn
388+ for j in N + 1 : N + xn, i in M + 1 : M + xn
140389 @inbounds A_new[i,j] = i == j
141390 end
142391 A_new
0 commit comments