1- using VectorizedRNG
2- using LinearAlgebra: Diagonal, I
3- using LoopVectorization
4- using RecursiveFactorization
5- using SparseArrays
1+ using LinearAlgebra, . Threads
62
73struct SparseBandedMatrix{T} <: AbstractMatrix{T}
84 size :: Tuple{Int, Int}
@@ -45,9 +41,9 @@ function Base.getindex(M :: SparseBandedMatrix{T}, i :: Int, j :: Int, I :: Int.
4541 zero (T)
4642end
4743
48- function Base. setindex! (M :: SparseBandedMatrix{T} , val, i :: Int , j :: Int , I :: Int... ) where T # TODO IF VAL ISNT OF TYPE T
44+ function Base. setindex! (M :: SparseBandedMatrix{T} , val, i :: Int , j :: Int , I :: Int... ) where T
4945 @boundscheck checkbounds (M, i, j, I... )
50- rows, cols = size (M)
46+ rows = size (M, 1 )
5147 wanted_ind = rows - i + j
5248 ind = searchsortedfirst (M. indices, wanted_ind)
5349 if (ind > length (M. indices) || M. indices[ind] != wanted_ind)
@@ -85,9 +81,6 @@ function Base.setindex!(M :: SparseBandedMatrix{T}, val, i :: Int, j :: Int, I :
8581 diagvals
8682end
8783
88- using LinearAlgebra
89- using . Threads
90-
9184# C = Cb + aAB
9285function LinearAlgebra. mul! (C :: Matrix{T} , A:: SparseBandedMatrix{T} , B :: Matrix{T} , a :: Number , b :: Number ) where T
9386 @assert size (A, 2 ) == size (B, 1 )
@@ -111,7 +104,7 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: SparseBandedMatrix{T}, B :: Matr
111104 index_j = location - cols + i
112105 end
113106 # A[index_i, index_j] * B[index_j, j] = C[index_i, j]
114- @simd for j in 1 : size (B, 2 )
107+ for j in 1 : size (B, 2 )
115108 C[index_i, j] = fma (val, B[index_j, j], C[index_i, j])
116109 end
117110 end
@@ -146,6 +139,50 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: Matrix{T}, B :: SparseBandedMatr
146139 C
147140end
148141
142+ function LinearAlgebra. mul! (C :: SparseBandedMatrix{T} , A:: SparseBandedMatrix{T} , B :: SparseBandedMatrix{T} , a :: Number , b :: Number ) where T
143+ @assert size (A, 2 ) == size (B, 1 )
144+ @assert size (A, 1 ) == size (C, 1 )
145+ @assert size (B, 2 ) == size (C, 2 )
146+
147+ C.*= b
148+
149+ rows_a, cols_a = size (A)
150+ rows_b, cols_b = size (B)
151+ @inbounds for (ind_a, location_a) in enumerate (A. indices)
152+ @threads for i in eachindex (A. diags[ind_a])
153+ val_a = A. diags[ind_a][i] * a
154+ if location_a < rows_a
155+ index_ia = rows_a - location_a + i
156+ index_ja = i
157+ else
158+ index_ia = i
159+ index_ja = location_a - cols_a + i
160+ end
161+ min_loc = rows_b - index_ja + 1
162+ max_loc = 2 * rows_b - index_ja
163+ for (ind_b, location_b) in enumerate (B. indices)
164+ # index_ib = index_ja
165+ # if ind < rows(A), then index = (rows - loc + i, i)
166+ # rows - loc + j = index_ja, j = index_ja - rows + loc
167+ # else index = (i, loc - cols + i)
168+ # if location < rows(B), then
169+ if location_b <= rows_b && location_b >= min_loc
170+ j = index_ja - rows_b + location_b
171+ index_jb = j
172+ val_b = B. diags[ind_b][j]
173+ C[index_ia, index_jb] = muladd (val_a, val_b, C[index_ia, index_jb])
174+ elseif location_b > rows_b && location_b <= max_loc
175+ j = index_ja
176+ index_jb = location_b - cols_b + j
177+ val_b = B. diags[ind_b][j]
178+ C[index_ia, index_jb] = muladd (val_a, val_b, C[index_ia, index_jb])
179+ end
180+ end
181+ end
182+ end
183+ C
184+ end
185+
149186function LinearAlgebra. mul! (C :: Matrix{T} , A:: SparseBandedMatrix{T} , B :: SparseBandedMatrix{T} , a :: Number , b :: Number ) where T
150187 @assert size (A, 2 ) == size (B, 1 )
151188 @assert size (A, 1 ) == size (C, 1 )
@@ -190,6 +227,12 @@ function LinearAlgebra.mul!(C :: Matrix{T}, A:: SparseBandedMatrix{T}, B :: Spar
190227 C
191228end
192229
230+ using VectorizedRNG
231+ using LinearAlgebra: Diagonal, I
232+ using LoopVectorization
233+ using RecursiveFactorization
234+ using SparseArrays
235+
193236@inline exphalf (x) = exp (x) * oftype (x, 0.5 )
194237function 🦋! (wv, :: Val{SEED} = Val (888 )) where {SEED}
195238 T = eltype (wv)
@@ -207,13 +250,13 @@ function 🦋generate_random!(A, ::Val{SEED} = Val(888)) where {SEED}
207250end
208251
209252function 🦋workspace (A, :: Val{SEED} = Val (888 )) where {SEED}
210- A = pad! (A)
211253 B = similar (A);
212254 ws = 🦋generate_random! (B)
213255 🦋mul! (copyto! (B, A), ws)
214- U, V, B = materializeUV (B, ws)
256+ U, V = materializeUV (B, ws)
215257 F = RecursiveFactorization. lu! (B, Val (false ))
216- A, U, V, F
258+
259+ U, V, F
217260end
218261
219262const butterfly_workspace = 🦋workspace;
@@ -284,30 +327,7 @@ function diagnegbottom(x)
284327 Diagonal (y), Diagonal (z)
285328end
286329
287- # 🦋(A, B) = [A B
288- # A -B]
289-
290- # Bu2 = [🦋(U₁u, U₁l) 0*I
291- # 0*I 🦋(U₂u, U₂l)]
292- # U1u U1l 0 0
293- # U1u -U1l 0 0
294- #=
295- function 🦋!(C, A, B)
296- A1, A2 = size(A)
297- B1, B2 = size(B)
298- @assert A1 == B1
299- for j in 1 : A2, i in 1 : A1
300- C[i, j] = A[i, j]
301- C[i + A1, j] = A[i, j]
302- end
303- for j in A2 + 1 : A2 + B2, i in 1 : A1
304- C[i, j] = B[i, j - A2]
305- C[i + A1, j] = -B[i, j - A2]
306- end
307- C
308- end
309- =#
310- function 🦋! (C, A:: Diagonal , B:: Diagonal )
330+ function 🦋2 !(C, A:: Diagonal , B:: Diagonal )
311331 @assert size (A) == size (B)
312332 A1 = size (A, 1 )
313333
@@ -321,19 +341,32 @@ function 🦋!(C, A::Diagonal, B::Diagonal)
321341 C
322342end
323343
324- function 🦋! (C:: SparseBandedMatrix , A:: Diagonal , B:: Diagonal )
325- @assert size (A) == size (B)
344+ function 🦋! (A:: Matrix , C:: SparseBandedMatrix , X:: Diagonal , Y:: Diagonal )
345+ @assert size (X) == size (Y)
346+ if (size (X, 1 ) + size (Y, 1 ) != size (A, 1 ))
347+ x = size (A, 1 ) - size (X, 1 ) - size (Y, 1 )
348+ setdiagonal! (C, [X. diag; rand (x); - Y. diag], true )
349+ setdiagonal! (C, X. diag, true )
350+ setdiagonal! (C, Y. diag, false )
351+ else
352+ setdiagonal! (C, [X. diag; - Y. diag], true )
353+ setdiagonal! (C, X. diag, true )
354+ setdiagonal! (C, Y. diag, false )
355+ end
326356
357+ C
358+ end
359+
360+ function 🦋2 !(C:: SparseBandedMatrix , A:: Diagonal , B:: Diagonal )
327361 setdiagonal! (C, [A. diag; - B. diag], true )
328362 setdiagonal! (C, A. diag, true )
329363 setdiagonal! (C, B. diag, false )
330364 C
331365end
332366
333-
334367function materializeUV (A, (uv,))
335368 M, N = size (A)
336- Mh = M >>> 1
369+ Mh = M >>> 1
337370 Nh = N >>> 1
338371
339372 U₁u, U₁l = diagnegbottom (@view (uv[1 : Mh]))
@@ -346,30 +379,46 @@ function materializeUV(A, (uv,))
346379 # WRITE OUT MERGINGS EXPLICITLY
347380 # Bu2 = [🦋(U₁u, U₁l) 0*I
348381 # 0*I 🦋(U₂u, U₂l)]
349- # show size(Bu2)[1] #808
350- # @show size(🦋(V₁u, V₁l))[1] #404
351382
352383 # Bu2 = spzeros(M, N)
384+
385+ mrng = VectorizedRNG. MutableXoshift (888 )
386+ T = typeof (uv[1 ])
387+
353388 Bu2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
354389
355- 🦋! (view (Bu2, 1 : M ÷ 2 , 1 : N ÷ 2 ), U₁u, U₁l)
356- 🦋! (view (Bu2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), U₂u, U₂l)
390+ 🦋2 !(view (Bu2, 1 : (M ÷ 4 ) * 2 , 1 : (N ÷ 4 ) * 2 ), U₁u, U₁l)
391+ 🦋2 !(view (Bu2, M - M ÷ 4 * 2 + 1 : M, N - N ÷ 4 * 2 + 1 : N), U₂u, U₂l)
392+ rand! (mrng, diag (view (Bu2, 1 : (M ÷ 4 ) * 2 , 1 : (N ÷ 4 ) * 2 )), static (0 ), T (- 0.05 ), T (0.1 ))
393+
357394
358395 # Bu1 = spzeros(M, N)
359396 Bu1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
360- 🦋! (Bu1, Uu, Ul)
397+ 🦋! (A, Bu1, Uu, Ul)
361398
362399 # Bv2 = spzeros(M, N)
363400 Bv2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
364401
365- 🦋! (view (Bv2, 1 : M ÷ 2 , 1 : N ÷ 2 ), V₁u, V₁l)
366- 🦋! (view (Bv2, M ÷ 2 + 1 : M, N ÷ 2 + 1 : N), V₂u, V₂l)
402+ 🦋2 !(view (Bv2, 1 : (M ÷ 4 ) * 2 , 1 : (N ÷ 4 ) * 2 ), V₁u, V₁l)
403+ 🦋2 !(view (Bv2, M - M ÷ 4 * 2 + 1 : M, N - N ÷ 4 * 2 + 1 : N), V₂u, V₂l)
404+ rand! (mrng, diag (view (Bv2, 1 : (M ÷ 4 ) * 2 , 1 : (N ÷ 4 ) * 2 )), static (0 ), T (- 0.05 ), T (0.1 ))
367405
368406 # Bv1 = spzeros(M, N)
369407 Bv1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N)
370- 🦋! (Bv1, Vu, Vl)
408+ 🦋! (A, Bv1, Vu, Vl)
371409
372- (Bu2 * Bu1)' , Bv2 * Bv1, A
410+ # U = similar(A)
411+ # U = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
412+
413+ # mul!(U, Bu2, Bu1, 1, 0)
414+
415+ # V = similar(A)
416+ # V = SparseBandedMatrix{typeof(uv[1])}(undef, M, N)
417+ # mul!(V, Bv2, Bv1, 1, 0)
418+ # U = sparse(U)
419+ # V = sparse(V)
420+
421+ (Bu2 * Bu1)' , Bv2 * Bv1
373422end
374423
375424function pad! (A)
@@ -389,4 +438,85 @@ function pad!(A)
389438 @inbounds A_new[i,j] = i == j
390439 end
391440 A_new
392- end
441+ end
442+
443+
444+
445+
446+
447+
448+
449+
450+
451+ #=
452+ using SparseArrays, BenchmarkTools, Random
453+
454+ function get_data1()
455+ dim = 5000
456+ x = rand(10:75)
457+ diag_vals = Vector{Vector{Float64}}(undef, x)
458+ diag_locs = randperm(dim * 2 - 1)[1:x]
459+ for j in 1:x
460+ diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
461+ end
462+
463+ x_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
464+ x_dense = copy(x_butterfly)
465+
466+ y = rand(dim, dim)
467+ z = zeros(dim, dim)
468+
469+ @show norm(x_dense*y - x_butterfly * y)
470+
471+ println("Timing dense multiplication.")
472+ println("(left-side mul)")
473+ @btime x_dense*y;
474+ println("(right-side mul)")
475+ @btime y*x_dense;
476+ println("\nTiming butterfly multiplication.")
477+ println("(left-side mul)")
478+ @btime x_butterfly*y;
479+ println("(right-side mul)")
480+ @btime y*x_butterfly;
481+
482+ nothing
483+ end
484+
485+ function get_data2()
486+ dim = 1000
487+ x = rand(10:40)
488+ diag_vals = Vector{Vector{Float64}}(undef, x)
489+ diag_locs = randperm(dim * 2 - 1)[1:x]
490+ for j in 1:x
491+ diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
492+ end
493+
494+ x_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
495+ x_dense = copy(x_butterfly)
496+ x_sparse = sparse(x_dense)
497+
498+ y = rand(10:40)
499+ diag_vals = Vector{Vector{Float64}}(undef, y)
500+ diag_locs = randperm(dim * 2 - 1)[1:y]
501+ for j in 1:y
502+ diag_vals[j] = rand(min(diag_locs[j], 2 * dim - diag_locs[j]))
503+ end
504+
505+ y_butterfly = SparseBandedMatrix{Float64}(diag_locs, diag_vals, dim, dim)
506+ y_dense = copy(y_butterfly)
507+ y_sparse = sparse(y_dense)
508+
509+ a = true
510+ b = false
511+ @assert isapprox(x_butterfly * y_butterfly, x_dense * y_dense)
512+ println("Timing butterfly multiplication.")
513+ @btime mul!(zeros(dim, dim), x_butterfly, y_butterfly, a, b);
514+ println("\nTiming sparse multiplication.")
515+ @btime mul!(zeros(dim, dim), x_sparse, y_sparse, a, b);
516+ println("\nTiming dense multiplication.")
517+ @btime mul!(zeros(dim, dim), x_dense, y_dense, a, b);
518+
519+ nothing
520+ end
521+ =#
522+
0 commit comments