Skip to content

Commit 1b66984

Browse files
committed
Bump version
2 parents c26991b + 2120b7a commit 1b66984

File tree

10 files changed

+160
-51
lines changed

10 files changed

+160
-51
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.3"
4+
version = "0.6.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -39,7 +39,7 @@ GPUArraysCore = "0.1.0, 0.2"
3939
LinearAlgebra = "1.10"
4040
MacroTools = "0.5.13"
4141
MapBroadcast = "0.1.5"
42-
MatrixAlgebraKit = "0.1.2"
42+
MatrixAlgebraKit = "0.2"
4343
SparseArraysBase = "0.5"
4444
SplitApplyCombine = "1.2.3"
4545
TensorAlgebra = "0.3.2"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
66

77
[compat]
88
BlockArrays = "1"
9-
BlockSparseArrays = "0.5"
9+
BlockSparseArrays = "0.6"
1010
Documenter = "1"
1111
Literate = "2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55

66
[compat]
77
BlockArrays = "1"
8-
BlockSparseArrays = "0.5"
8+
BlockSparseArrays = "0.6"
99
Test = "1"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,6 @@ function blockrange(axis::AbstractUnitRange, r::Int)
290290
return error("Slicing with integer values isn't supported.")
291291
end
292292

293-
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
294-
for b in r
295-
@assert b blockaxes(axis, 1)
296-
end
297-
return r
298-
end
299-
300293
# This handles changing the blocking, for example:
301294
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
302295
# I = blockedrange([4, 4])
@@ -315,13 +308,20 @@ end
315308
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
316309
# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
317310
# a[I, I]
318-
function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}})
311+
function blockrange(axis::AbstractUnitRange, r::AbstractBlockVector{<:Block{1}})
319312
for b in r
320313
@assert b blockaxes(axis, 1)
321314
end
322315
return only(blockaxes(r))
323316
end
324317

318+
function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
319+
for b in r
320+
@assert b blockaxes(axis, 1)
321+
end
322+
return r
323+
end
324+
325325
using BlockArrays: BlockSlice
326326
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
327327
return blockrange(axis, r.block)

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,32 @@ using BlockArrays:
99
BlockSlice,
1010
BlockVector,
1111
block,
12+
blockedrange,
1213
blockindex,
14+
blocklengths,
1315
findblock,
1416
findblockindex,
1517
mortar
1618

19+
# Get the axes of each block of a block array.
20+
function eachblockaxes(a::AbstractArray)
21+
return map(axes, blocks(a))
22+
end
23+
24+
axis(a::AbstractVector) = axes(a, 1)
25+
26+
# Get the axis of each block of a blocked unit
27+
# range.
28+
function eachblockaxis(a::AbstractVector)
29+
return map(axis, blocks(a))
30+
end
31+
32+
# Take a collection of axes and mortar them
33+
# into a single blocked axis.
34+
function mortar_axis(axs)
35+
return blockedrange(length.(axs))
36+
end
37+
1738
# Custom `BlockedUnitRange` constructor that takes a unit range
1839
# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`.
1940
function blockedunitrange(a::AbstractUnitRange, blocklengths)

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ end
445445

446446
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)
447447
to_blocks_indices(I::BlockIndices{<:Vector{<:Block{1}}}) = Int.(I.blocks)
448-
to_blocks_indices(I::Base.Slice{<:BlockedOneTo}) = Base.OneTo(blocklength(I.indices))
448+
to_blocks_indices(I::Base.Slice) = Base.OneTo(blocklength(I.indices))
449449

450450
@interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(
451451
a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{BlockSliceCollection}}}

src/blocksparsearrayinterface/map.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,52 @@
1+
using BlockArrays: BlockRange, blockisequal
12
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
23
using GPUArraysCore: @allowscalar
34

5+
# Check if the block structures are the same.
6+
function same_block_structure(as::AbstractArray...)
7+
isempty(as) && return true
8+
return all(
9+
ntuple(ndims(first(as))) do dim
10+
ax = map(Base.Fix2(axes, dim), as)
11+
return blockisequal(ax...)
12+
end,
13+
)
14+
end
15+
16+
# Find the common stored blocks, assuming the block structures are the same.
17+
function union_eachblockstoredindex(as::AbstractArray...)
18+
return (map(eachblockstoredindex, as)...)
19+
end
20+
21+
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
22+
# TODO: This assumes element types are numbers, generalize this logic.
23+
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
24+
Is = if f_preserves_zeros
25+
union_eachblockstoredindex(a_dest, a_srcs...)
26+
else
27+
BlockRange(a_dest)
28+
end
29+
for I in Is
30+
# TODO: Use:
31+
# block_dest = @view a_dest[I]
32+
# or:
33+
# block_dest = @view! a_dest[I]
34+
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(I))...]
35+
# TODO: Use:
36+
# block_srcs = map(a_src -> @view(a_src[I]), a_srcs)
37+
block_srcs = map(a_srcs) do a_src
38+
return blocks_maybe_single(a_src)[Int.(Tuple(I))...]
39+
end
40+
# TODO: Use `map!!` to handle immutable blocks.
41+
map!(f, block_dest, block_srcs...)
42+
# Replace the entire block, handles initializing new blocks
43+
# or if blocks are immutable.
44+
# TODO: Use `a_dest[I] = block_dest`.
45+
blocks(a_dest)[Int.(Tuple(I))...] = block_dest
46+
end
47+
return a_dest
48+
end
49+
450
# TODO: Rewrite this so that it takes the blocking structure
551
# made by combining the blocking of the axes (i.e. the blocking that
652
# is used to determine `union_stored_blocked_cartesianindices(...)`).
@@ -16,6 +62,10 @@ using GPUArraysCore: @allowscalar
1662
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
1763
return a_dest
1864
end
65+
if same_block_structure(a_dest, a_srcs...)
66+
map_blockwise!(f, a_dest, a_srcs...)
67+
return a_dest
68+
end
1969
# TODO: This assumes element types are numbers, generalize this logic.
2070
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
2171
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)

src/factorizations/svd.jl

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,37 @@ struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
1212
alg::A
1313
end
1414

15-
# TODO: this is a hardcoded for now to get around this function not being defined in the
16-
# type domain
17-
function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
15+
function default_blocksparse_svd_algorithm(f, A; kwargs...)
1816
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
1917
error("unsupported type: $(blocktype(A))")
18+
# TODO: this is a hardcoded for now to get around this function not being defined in the
19+
# type domain
20+
# alg = MatrixAlgebraKit.default_algorithm(f, blocktype(A); kwargs...)
2021
alg = MatrixAlgebraKit.LAPACK_DivideAndConquer(; kwargs...)
2122
return BlockPermutedDiagonalAlgorithm(alg)
2223
end
2324

24-
# TODO: this should be replaced with a more general similar function that can handle setting
25-
# the blocktype and element type - something like S = similar(A, BlockType(...))
26-
function _similar_S(A::AbstractBlockSparseMatrix, s_axis)
25+
function MatrixAlgebraKit.default_algorithm(
26+
f::typeof(svd_compact!), A::AbstractBlockSparseMatrix; kwargs...
27+
)
28+
return default_blocksparse_svd_algorithm(f, A; kwargs...)
29+
end
30+
function MatrixAlgebraKit.default_algorithm(
31+
f::typeof(svd_full!), A::AbstractBlockSparseMatrix; kwargs...
32+
)
33+
return default_blocksparse_svd_algorithm(f, A; kwargs...)
34+
end
35+
36+
function similar_output(
37+
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
38+
)
39+
U = similar(A, axes(A, 1), S_axes[1])
2740
T = real(eltype(A))
28-
return BlockSparseArray{T,2,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis))
41+
# TODO: this should be replaced with a more general similar function that can handle setting
42+
# the blocktype and element type - something like S = similar(A, BlockType(...))
43+
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, S_axes)
44+
Vt = similar(A, S_axes[2], axes(A, 2))
45+
return U, S, Vt
2946
end
3047

3148
function MatrixAlgebraKit.initialize_output(
@@ -34,33 +51,36 @@ function MatrixAlgebraKit.initialize_output(
3451
bm, bn = blocksize(A)
3552
bmn = min(bm, bn)
3653

37-
brows = blocklengths(axes(A, 1))
38-
bcols = blocklengths(axes(A, 2))
39-
slengths = Vector{Int}(undef, bmn)
54+
brows = eachblockaxis(axes(A, 1))
55+
bcols = eachblockaxis(axes(A, 2))
56+
u_axes = similar(brows, bmn)
57+
v_axes = similar(brows, bmn)
4058

4159
# fill in values for blocks that are present
4260
bIs = collect(eachblockstoredindex(A))
4361
browIs = Int.(first.(Tuple.(bIs)))
4462
bcolIs = Int.(last.(Tuple.(bIs)))
4563
for bI in eachblockstoredindex(A)
4664
row, col = Int.(Tuple(bI))
47-
nrows = brows[row]
48-
ncols = bcols[col]
49-
slengths[col] = min(nrows, ncols)
65+
len = minimum(length, (brows[row], bcols[col]))
66+
u_axes[col] = brows[row][Base.OneTo(len)]
67+
v_axes[col] = bcols[col][Base.OneTo(len)]
5068
end
5169

5270
# fill in values for blocks that aren't present, pairing them in order of occurence
5371
# this is a convention, which at least gives the expected results for blockdiagonal
5472
emptyrows = setdiff(1:bm, browIs)
5573
emptycols = setdiff(1:bn, bcolIs)
5674
for (row, col) in zip(emptyrows, emptycols)
57-
slengths[col] = min(brows[row], bcols[col])
75+
len = minimum(length, (brows[row], bcols[col]))
76+
u_axes[col] = brows[row][Base.OneTo(len)]
77+
v_axes[col] = bcols[col][Base.OneTo(len)]
5878
end
5979

60-
s_axis = blockedrange(slengths)
61-
U = similar(A, axes(A, 1), s_axis)
62-
S = _similar_S(A, s_axis)
63-
Vt = similar(A, s_axis, axes(A, 2))
80+
u_axis = mortar_axis(u_axes)
81+
v_axis = mortar_axis(v_axes)
82+
S_axes = (u_axis, v_axis)
83+
U, S, Vt = similar_output(svd_compact!, A, S_axes, alg)
6484

6585
# allocate output
6686
for bI in eachblockstoredindex(A)
@@ -79,40 +99,47 @@ function MatrixAlgebraKit.initialize_output(
7999
return U, S, Vt
80100
end
81101

102+
function similar_output(
103+
::typeof(svd_full!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
104+
)
105+
U = similar(A, axes(A, 1), S_axes[1])
106+
T = real(eltype(A))
107+
S = similar(A, T, S_axes)
108+
Vt = similar(A, S_axes[2], axes(A, 2))
109+
return U, S, Vt
110+
end
111+
82112
function MatrixAlgebraKit.initialize_output(
83113
::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
84114
)
85115
bm, bn = blocksize(A)
86116

87-
brows = blocklengths(axes(A, 1))
88-
slengths = copy(brows)
117+
brows = eachblockaxis(axes(A, 1))
118+
u_axes = similar(brows)
89119

90120
# fill in values for blocks that are present
91121
bIs = collect(eachblockstoredindex(A))
92122
browIs = Int.(first.(Tuple.(bIs)))
93123
bcolIs = Int.(last.(Tuple.(bIs)))
94124
for bI in eachblockstoredindex(A)
95125
row, col = Int.(Tuple(bI))
96-
nrows = brows[row]
97-
slengths[col] = nrows
126+
u_axes[col] = brows[row]
98127
end
99128

100129
# fill in values for blocks that aren't present, pairing them in order of occurence
101130
# this is a convention, which at least gives the expected results for blockdiagonal
102131
emptyrows = setdiff(1:bm, browIs)
103132
emptycols = setdiff(1:bn, bcolIs)
104133
for (row, col) in zip(emptyrows, emptycols)
105-
slengths[col] = brows[row]
134+
u_axes[col] = brows[row]
106135
end
107136
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
108-
slengths[bn + i] = brows[emptyrows[k]]
137+
u_axes[bn + i] = brows[emptyrows[k]]
109138
end
110139

111-
s_axis = blockedrange(slengths)
112-
U = similar(A, axes(A, 1), s_axis)
113-
Tr = real(eltype(A))
114-
S = similar(A, Tr, (s_axis, axes(A, 2)))
115-
Vt = similar(A, axes(A, 2), axes(A, 2))
140+
u_axis = mortar_axis(u_axes)
141+
S_axes = (u_axis, axes(A, 2))
142+
U, S, Vt = similar_output(svd_full!, A, S_axes, alg)
116143

117144
# allocate output
118145
for bI in eachblockstoredindex(A)

src/factorizations/truncation.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ function MatrixAlgebraKit.findtruncated(
4545
return indexmask
4646
end
4747

48+
function similar_truncate(
49+
::typeof(svd_trunc!),
50+
(U, S, Vᴴ)::TBlockUSVᴴ,
51+
strategy::BlockPermutedDiagonalTruncationStrategy,
52+
indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy),
53+
)
54+
ax = axes(S, 1)
55+
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
56+
s_lengths = filter!(>(0), map(counter, blocks(ax)))
57+
s_axis = blockedrange(s_lengths)
58+
= similar(U, axes(U, 1), s_axis)
59+
= similar(S, s_axis, s_axis)
60+
Ṽᴴ = similar(Vᴴ, s_axis, axes(Vᴴ, 2))
61+
return Ũ, S̃, Ṽᴴ
62+
end
63+
4864
function MatrixAlgebraKit.truncate!(
4965
::typeof(svd_trunc!),
5066
(U, S, Vᴴ)::TBlockUSVᴴ,
@@ -54,13 +70,7 @@ function MatrixAlgebraKit.truncate!(
5470

5571
# first determine the block structure of the output to avoid having assumptions on the
5672
# data structures
57-
ax = axes(S, 1)
58-
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
59-
Slengths = filter!(>(0), map(counter, blocks(ax)))
60-
Sax = blockedrange(Slengths)
61-
= similar(U, axes(U, 1), Sax)
62-
= similar(S, Sax, Sax)
63-
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))
73+
Ũ, S̃, Ṽᴴ = similar_truncate(svd_trunc!, (U, S, Vᴴ), strategy, indexmask)
6474

6575
# then loop over the blocks and assign the data
6676
# TODO: figure out if we can presort and loop over the blocks -
@@ -70,6 +80,7 @@ function MatrixAlgebraKit.truncate!(
7080
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))
7181

7282
I′ = 0 # number of skipped blocks that got fully truncated
83+
ax = axes(S, 1)
7384
for I in 1:blocksize(ax, 1)
7485
b = ax[Block(I)]
7586
mask = indexmask[b]

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ Adapt = "4"
2323
Aqua = "0.8"
2424
ArrayLayouts = "1"
2525
BlockArrays = "1"
26-
BlockSparseArrays = "0.5"
26+
BlockSparseArrays = "0.6"
2727
DiagonalArrays = "0.3"
2828
GPUArraysCore = "0.2"
2929
JLArrays = "0.2"
3030
LinearAlgebra = "1"
31-
MatrixAlgebraKit = "0.1"
31+
MatrixAlgebraKit = "0.2"
3232
Random = "1"
3333
SafeTestsets = "0.1"
3434
SparseArraysBase = "0.5"

0 commit comments

Comments
 (0)