Skip to content

Commit 82cd318

Browse files
add SparseBandedMatrices and associated implementation
1 parent 5f558fc commit 82cd318

File tree

2 files changed

+261
-10
lines changed

2 files changed

+261
-10
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
99
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
1010
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
11+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1112
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
1213
TriangularSolve = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf"
1314
VectorizedRNG = "33b4df10-0173-11e9-2a0c-851a7edac40e"
@@ -17,6 +18,7 @@ LinearAlgebra = "1.5"
1718
LoopVectorization = "0.10,0.11, 0.12"
1819
Polyester = "0.3.2,0.4.1, 0.5, 0.6, 0.7"
1920
PrecompileTools = "1"
21+
SparseArrays = "1.11.0"
2022
StrideArraysCore = "0.5.5"
2123
TriangularSolve = "0.2"
2224
VectorizedRNG = "0.2.25"

src/butterflylu.jl

Lines changed: 259 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,193 @@ using VectorizedRNG
22
using LinearAlgebra: Diagonal, I
33
using LoopVectorization
44
using RecursiveFactorization
5+
using SparseArrays
6+
7+
# Matrix that stores only its explicitly populated diagonals.
#
# Fields:
#   size    - (rows, cols) reported by the matrix.
#   indices - sorted integer keys, one per stored diagonal. The accessors in
#             this file look up entry (i, j) under the key `rows - i + j`.
#   diags   - diags[k] holds the values of the diagonal whose key is indices[k].
struct SparseBandedMatrix{T} <: AbstractMatrix{T}
    size::Tuple{Int, Int}
    indices::Vector{Int}
    diags::Vector{Vector{T}}

    # Create an N-by-M matrix with no stored diagonals.
    function SparseBandedMatrix{T}(::UndefInitializer, N, M) where T
        return new((N, M), Int[], Vector{T}[])
    end

    # Create an N-by-M matrix from diagonal keys `ind_vals` and the matching
    # value vectors `diag_vals`. Both are reordered so that the keys are sorted.
    # Throws ArgumentError when the two collections differ in length or when a
    # key occurs more than once.
    function SparseBandedMatrix{T}(ind_vals, diag_vals, N, M) where T
        if length(ind_vals) != length(diag_vals)
            throw(ArgumentError(
                "got $(length(ind_vals)) diagonal indices but $(length(diag_vals)) diagonals"))
        end
        perm = sortperm(ind_vals)
        indices = ind_vals[perm]
        # The keys are sorted here, so duplicates would be adjacent. Explicit
        # validation is used because @assert may be compiled out.
        for k in 1:(length(indices) - 1)
            if indices[k] == indices[k + 1]
                throw(ArgumentError("duplicate diagonal index $(indices[k])"))
            end
        end
        diags = diag_vals[perm]
        return new((N, M), indices, diags)
    end
end
28+
29+
# Report the (rows, cols) tuple stored in the matrix.
Base.size(M::SparseBandedMatrix) = M.size
32+
33+
# Read entry (i, j). The diagonal containing (i, j) is looked up under the key
# `rows - i + j`; entries on diagonals that are not stored read as zero(T).
function Base.getindex(M::SparseBandedMatrix{T}, i::Int, j::Int, I::Int...) where T
    @boundscheck checkbounds(M, i, j, I...)
    nrows = size(M, 1)
    key = nrows - i + j
    slot = searchsortedfirst(M.indices, key)
    stored = slot <= length(M.indices) && M.indices[slot] == key
    stored || return zero(T)
    # Position inside the diagonal is the column below the main diagonal
    # (i > j) and the row otherwise, i.e. the smaller of the two indices.
    return M.diags[slot][min(i, j)]
end
47+
48+
# Write `val` to entry (i, j), converting it to T when necessary. If the
# diagonal containing (i, j) is not stored yet, a zero-filled diagonal of
# length `rows - abs(j - i)` is inserted first so that `indices` stays sorted.
# Returns `val`.
function Base.setindex!(M::SparseBandedMatrix{T}, val, i::Int, j::Int, I::Int...) where T
    @boundscheck checkbounds(M, i, j, I...)
    nrows = size(M, 1)
    key = nrows - i + j
    slot = searchsortedfirst(M.indices, key)
    present = slot <= length(M.indices) && M.indices[slot] == key
    if !present
        insert!(M.indices, slot, key)
        insert!(M.diags, slot, zeros(T, nrows - abs(key - nrows)))
    end
    # Position inside the diagonal: column below the main diagonal, row otherwise.
    pos = i > j ? j : i
    M.diags[slot][pos] = val isa T ? val : convert(T, val)::T
    return val
end
64+
65+
# Store `diagvals` as one whole diagonal of `M` and return `diagvals`.
#
# The target diagonal is selected by the length of `diagvals`: with
# `lower == true` the key is `length(diagvals)`, otherwise it is
# `2 * rows - length(diagvals)`.
#
# If that diagonal is not stored yet it is inserted. A `diagvals` that is
# already a `Vector{T}` is stored as-is (no copy), so the matrix then shares
# memory with the caller's vector. If the diagonal already exists, its leading
# entries are overwritten element by element.
#
# Raises an error when `diagvals` has more entries than `M` has rows.
function setdiagonal!(M::SparseBandedMatrix{T}, diagvals, lower::Bool) where T
    rows = size(M, 1)
    if length(diagvals) > rows
        error("size of diagonal is too big for the matrix")
    end
    wanted_ind = lower ? length(diagvals) : 2 * rows - length(diagvals)

    ind = searchsortedfirst(M.indices, wanted_ind)
    if ind > length(M.indices) || M.indices[ind] != wanted_ind
        # Convert before touching the matrix so a failed conversion cannot
        # leave `indices` and `diags` with different lengths.
        newdiag = diagvals isa Vector{T} ? diagvals : convert(Vector{T}, diagvals)::Vector{T}
        insert!(M.indices, ind, wanted_ind)
        insert!(M.diags, ind, newdiag)
    else
        stored = M.diags[ind]
        # Was `1 : eachindex(diagvals)`, which is not a valid range and threw
        # whenever this branch was taken.
        for i in eachindex(diagvals)
            stored[i] = diagvals[i] isa T ? diagvals[i] : convert(T, diagvals[i])::T
        end
    end
    return diagvals
end
87+
88+
using LinearAlgebra
89+
using .Threads
90+
91+
# C = Cb + aAB, with a banded left factor A and dense B and C.
#
# C is scaled by `b` first (or zero-filled when `b` is zero, so NaN/Inf already
# present in C cannot leak into the result). Every stored entry of A then adds
# its scaled row of B to the matching row of C.
#
# Throws DimensionMismatch when the shapes of A, B and C are incompatible.
# NOTE(review): the index arithmetic mixes `rows` and `cols` of A and looks like
# it assumes a square A - confirm before using it with rectangular matrices.
function LinearAlgebra.mul!(C::Matrix{T}, A::SparseBandedMatrix{T}, B::Matrix{T}, a::Number, b::Number) where T
    # Explicit checks instead of @assert: the loops below run under @inbounds,
    # so a shape error must never get past this point.
    if size(A, 2) != size(B, 1)
        throw(DimensionMismatch("A has $(size(A, 2)) columns but B has $(size(B, 1)) rows"))
    end
    if size(A, 1) != size(C, 1)
        throw(DimensionMismatch("A has $(size(A, 1)) rows but C has $(size(C, 1)) rows"))
    end
    if size(B, 2) != size(C, 2)
        throw(DimensionMismatch("B has $(size(B, 2)) columns but C has $(size(C, 2)) columns"))
    end

    if iszero(b)
        fill!(C, zero(T))
    else
        C .*= b
    end

    rows, cols = size(A)
    @inbounds for (ind, location) in enumerate(A.indices)
        diag = A.diags[ind]
        # Entries of one diagonal sit in distinct rows of A, so each iteration
        # of the threaded loop updates a different row of C.
        @threads for i in eachindex(diag)
            val = diag[i] * a
            # Recover the (row, column) position of diag[i] inside A.
            if location < rows
                index_i = rows - location + i
                index_j = i
            else
                index_i = i
                index_j = location - cols + i
            end
            # C[index_i, :] += A[index_i, index_j] * a * B[index_j, :]
            @simd for j in 1:size(B, 2)
                C[index_i, j] = fma(val, B[index_j, j], C[index_i, j])
            end
        end
    end
    return C
end
121+
122+
# C = Cb + aAB, with a dense left factor A and a banded right factor B.
#
# C is scaled by `b` first (or zero-filled when `b` is zero, so NaN/Inf already
# present in C cannot leak into the result). Every stored entry of B then adds
# its scaled column of A to the matching column of C.
#
# Throws DimensionMismatch when the shapes of A, B and C are incompatible.
# NOTE(review): the index arithmetic mixes `rows` and `cols` of B and looks like
# it assumes a square B - confirm before using it with rectangular matrices.
function LinearAlgebra.mul!(C::Matrix{T}, A::Matrix{T}, B::SparseBandedMatrix{T}, a::Number, b::Number) where T
    # Explicit checks instead of @assert: the loops below run under @inbounds,
    # so a shape error must never get past this point.
    if size(A, 2) != size(B, 1)
        throw(DimensionMismatch("A has $(size(A, 2)) columns but B has $(size(B, 1)) rows"))
    end
    if size(A, 1) != size(C, 1)
        throw(DimensionMismatch("A has $(size(A, 1)) rows but C has $(size(C, 1)) rows"))
    end
    if size(B, 2) != size(C, 2)
        throw(DimensionMismatch("B has $(size(B, 2)) columns but C has $(size(C, 2)) columns"))
    end

    if iszero(b)
        fill!(C, zero(T))
    else
        C .*= b
    end

    rows, cols = size(B)
    @inbounds for (ind, location) in enumerate(B.indices)
        diag = B.diags[ind]
        # Entries of one diagonal sit in distinct columns of B, so each
        # iteration of the threaded loop updates a different column of C.
        @threads for i in eachindex(diag)
            val = diag[i] * a
            # Recover the (row, column) position of diag[i] inside B.
            if location < rows
                index_i = rows - location + i
                index_j = i
            else
                index_i = i
                index_j = location - cols + i
            end
            # C[:, index_j] += A[:, index_i] * B[index_i, index_j] * a
            @simd for j in 1:size(A, 1)
                C[j, index_j] = fma(val, A[j, index_i], C[j, index_j])
            end
        end
    end
    return C
end
148+
149+
# C = Cb + aAB, with both factors banded and a dense result C.
#
# C is scaled by `b` first (or zero-filled when `b` is zero, so NaN/Inf already
# present in C cannot leak into the result). For every stored entry
# A[index_ia, index_ja], each stored diagonal of B that crosses row index_ja
# contributes one product to row index_ia of C.
#
# Throws DimensionMismatch when the shapes of A, B and C are incompatible.
# NOTE(review): the index arithmetic mixes row and column counts and looks like
# it assumes square A and B - confirm before using it with rectangular matrices.
function LinearAlgebra.mul!(C::Matrix{T}, A::SparseBandedMatrix{T}, B::SparseBandedMatrix{T}, a::Number, b::Number) where T
    # Explicit checks instead of @assert: the loops below run under @inbounds,
    # so a shape error must never get past this point.
    if size(A, 2) != size(B, 1)
        throw(DimensionMismatch("A has $(size(A, 2)) columns but B has $(size(B, 1)) rows"))
    end
    if size(A, 1) != size(C, 1)
        throw(DimensionMismatch("A has $(size(A, 1)) rows but C has $(size(C, 1)) rows"))
    end
    if size(B, 2) != size(C, 2)
        throw(DimensionMismatch("B has $(size(B, 2)) columns but C has $(size(C, 2)) columns"))
    end

    if iszero(b)
        fill!(C, zero(T))
    else
        C .*= b
    end

    rows_a, cols_a = size(A)
    rows_b, cols_b = size(B)
    @inbounds for (ind_a, location_a) in enumerate(A.indices)
        diag_a = A.diags[ind_a]
        # Entries of one diagonal of A sit in distinct rows, so each iteration
        # of the threaded loop updates a different row of C.
        @threads for i in eachindex(diag_a)
            val_a = diag_a[i] * a
            # Recover the (row, column) position of diag_a[i] inside A.
            if location_a < rows_a
                index_ia = rows_a - location_a + i
                index_ja = i
            else
                index_ia = i
                index_ja = location_a - cols_a + i
            end
            # Range of diagonal keys of B that have an entry in row index_ja.
            min_loc = rows_b - index_ja + 1
            max_loc = 2 * rows_b - index_ja
            for (ind_b, location_b) in enumerate(B.indices)
                if location_b <= rows_b && location_b >= min_loc
                    # Main or lower diagonal of B: entries are stored by column.
                    j = index_ja - rows_b + location_b
                    index_jb = j
                    val_b = B.diags[ind_b][j]
                    C[index_ia, index_jb] = muladd(val_a, val_b, C[index_ia, index_jb])
                elseif location_b > rows_b && location_b <= max_loc
                    # Upper diagonal of B: entries are stored by row.
                    j = index_ja
                    index_jb = location_b - cols_b + j
                    val_b = B.diags[ind_b][j]
                    C[index_ia, index_jb] = muladd(val_a, val_b, C[index_ia, index_jb])
                end
            end
        end
    end
    return C
end
5192

6193
# Return exp(x) scaled by one half, with 0.5 converted to the type of `x`.
@inline function exphalf(x)
    half = oftype(x, 0.5)
    return exp(x) * half
end
7194
function 🦋!(wv, ::Val{SEED} = Val(888)) where {SEED}
@@ -32,7 +219,6 @@ end
32219
const butterfly_workspace = 🦋workspace;
33220

34221
function 🦋mul_level!(A, u, v)
35-
# for now, assume...
36222
M, N = size(A)
37223
Ml = M >>> 1
38224
Nl = N >>> 1
@@ -97,8 +283,53 @@ function diagnegbottom(x)
97283
end
98284
Diagonal(y), Diagonal(z)
99285
end
100-
🦋(A, B) = [A B
101-
A -B]
286+
287+
#🦋(A, B) = [A B
288+
# A -B]
289+
290+
#Bu2 = [🦋(U₁u, U₁l) 0*I
291+
# 0*I 🦋(U₂u, U₂l)]
292+
# U1u U1l 0 0
293+
# U1u -U1l 0 0
294+
#=
295+
function 🦋!(C, A, B)
296+
A1, A2 = size(A)
297+
B1, B2 = size(B)
298+
@assert A1 == B1
299+
for j in 1 : A2, i in 1 : A1
300+
C[i, j] = A[i, j]
301+
C[i + A1, j] = A[i, j]
302+
end
303+
for j in A2 + 1 : A2 + B2, i in 1 : A1
304+
C[i, j] = B[i, j - A2]
305+
C[i + A1, j] = -B[i, j - A2]
306+
end
307+
C
308+
end
309+
=#
310+
# Write the non-zero entries of the block matrix
#     [A  B
#      A -B]
# into C for two equally sized diagonal matrices A and B, and return C.
# Only those four diagonals of C are assigned; other entries are left as is.
function 🦋!(C, A::Diagonal, B::Diagonal)
    @assert size(A) == size(B)
    n = size(A, 1)

    for k in 1:n
        shifted = k + n
        a = A[k, k]
        b = B[k, k]
        C[k, k] = a
        C[shifted, k] = a
        C[k, shifted] = b
        C[shifted, shifted] = -b
    end

    return C
end
323+
324+
# Banded-storage variant: hand three whole diagonals to `setdiagonal!` - the
# concatenation [A.diag; -B.diag] and A.diag as lower diagonals, and B.diag as
# an upper diagonal. Returns C.
function 🦋!(C::SparseBandedMatrix, A::Diagonal, B::Diagonal)
    @assert size(A) == size(B)

    stacked = vcat(A.diag, -B.diag)
    setdiagonal!(C, stacked, true)
    setdiagonal!(C, A.diag, true)
    setdiagonal!(C, B.diag, false)
    return C
end
332+
102333

103334
function materializeUV(A, (uv,))
104335
M, N = size(A)
@@ -112,13 +343,31 @@ function materializeUV(A, (uv,))
112343
Uu, Ul = diagnegbottom(@view(uv[(1 + 2 * Mh + 2 * Nh):(2 * Mh + 2 * Nh + M)]))
113344
Vu, Vl = diagnegbottom(@view(uv[(1 + 2 * Mh + 2 * Nh + M):(2 * Mh + 2 * Nh + M + N)]))
114345

115-
Bu2 = [🦋(U₁u, U₁l) 0*I
116-
0*I 🦋(U₂u, U₂l)]
117-
Bu1 = 🦋(Uu, Ul)
346+
#WRITE OUT MERGINGS EXPLICITLY
347+
#Bu2 = [🦋(U₁u, U₁l) 0*I
348+
# 0*I 🦋(U₂u, U₂l)]
349+
#show size(Bu2)[1] #808
350+
#@show size(🦋(V₁u, V₁l))[1] #404
351+
352+
#Bu2 = spzeros(M, N)
353+
Bu2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
354+
355+
🦋!(view(Bu2, 1 : M ÷ 2, 1 : N ÷ 2), U₁u, U₁l)
356+
🦋!(view(Bu2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), U₂u, U₂l)
357+
358+
#Bu1 = spzeros(M, N)
359+
Bu1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
360+
🦋!(Bu1, Uu, Ul)
361+
362+
#Bv2 = spzeros(M, N)
363+
Bv2 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
364+
365+
🦋!(view(Bv2, 1 : M ÷ 2, 1 : N ÷ 2), V₁u, V₁l)
366+
🦋!(view(Bv2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), V₂u, V₂l)
118367

119-
Bv2 = [🦋(V₁u, V₁l) 0*I
120-
0*I 🦋(V₂u, V₂l)]
121-
Bv1 = 🦋(Vu, Vl)
368+
#Bv1 = spzeros(M, N)
369+
Bv1 = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
370+
🦋!(Bv1, Vu, Vl)
122371

123372
(Bu2 * Bu1)', Bv2 * Bv1, A
124373
end
@@ -136,7 +385,7 @@ function pad!(A)
136385
@inbounds A_new[j, i] = 0
137386
end
138387

139-
for i in M + 1 : M + xn, j in N + 1 : N + xn
388+
for j in N + 1 : N + xn, i in M + 1 : M + xn
140389
@inbounds A_new[i,j] = i == j
141390
end
142391
A_new

0 commit comments

Comments
 (0)