Skip to content

Commit d4365f0

Browse files
committed
Generalize matricize
1 parent 83ed436 commit d4365f0

File tree

7 files changed

+52
-19
lines changed

7 files changed

+52
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.8.1"
4+
version = "0.8.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,41 @@ function TensorAlgebra.FusionStyle(::Type{<:AbstractBlockSparseArray})
1515
return BlockReshapeFusion()
1616
end
1717

18+
using BlockArrays: Block, blocklength, blocks
19+
using BlockSparseArrays: blocksparse
20+
using SparseArraysBase: eachstoredindex
21+
using TensorAlgebra: TensorAlgebra, matricize, unmatricize
1822
function TensorAlgebra.matricize(
1923
::BlockReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
2024
)
21-
new_axes = fuseaxes(axes(a), biperm)
22-
return blockreshape(a, new_axes)
25+
ax = fuseaxes(axes(a), biperm)
26+
reshaped_blocks_a = reshape(blocks(a), map(blocklength, ax))
27+
bs = Dict(
28+
Block(Tuple(I)) => matricize(reshaped_blocks_a[I], biperm) for
29+
I in eachstoredindex(reshaped_blocks_a)
30+
)
31+
return blocksparse(bs, ax)
2332
end
2433

34+
using BlockArrays: blocklengths
2535
function TensorAlgebra.unmatricize(
2636
::BlockReshapeFusion,
2737
m::AbstractMatrix,
28-
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
38+
blocked_ax::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
2939
)
30-
return blockreshape(m, Tuple(blocked_axes)...)
40+
ax = Tuple(blocked_ax)
41+
reshaped_blocks_m = reshape(blocks(m), map(blocklength, ax))
42+
function f(I)
43+
block_axes_I = BlockedTuple(
44+
map(ntuple(identity, length(ax))) do i
45+
return Base.axes1(ax[i][Block(I[i])])
46+
end,
47+
blocklengths(blocked_ax),
48+
)
49+
return unmatricize(reshaped_blocks_m[I], block_axes_I)
50+
end
51+
bs = Dict(Block(Tuple(I)) => f(I) for I in eachstoredindex(reshaped_blocks_m))
52+
return blocksparse(bs, ax)
3153
end
3254

3355
end

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,15 @@ tuple_oneto(n) = ntuple(identity, n)
267267

268268
function _blockreshape(a::AbstractArray, axes::Tuple{Vararg{AbstractUnitRange}})
269269
reshaped_blocks_a = reshape(blocks(a), blocklength.(axes))
270-
reshaped_a = similar(a, axes)
271-
for I in eachstoredindex(reshaped_blocks_a)
272-
block_size_I = map(i -> length(axes[i][Block(I[i])]), tuple_oneto(length(axes)))
270+
function f(I)
271+
block_axes_I = map(ntuple(identity, length(axes))) do i
272+
return Base.axes1(axes[i][Block(I[i])])
273+
end
273274
# TODO: Better converter here.
274-
reshaped_a[Block(Tuple(I))] = reshape(reshaped_blocks_a[I], block_size_I)
275+
return reshape(reshaped_blocks_a[I], block_axes_I)
275276
end
276-
return reshaped_a
277+
bs = Dict(Block(Tuple(I)) => f(I) for I in eachstoredindex(reshaped_blocks_a))
278+
return blocksparse(bs, axes)
277279
end
278280

279281
function blockreshape(

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ function Base.setindex!(
9292
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
9393
# https://github.com/JuliaLang/julia/pull/52487).
9494
# TODO: Delete once we drop support for Julia v1.10.
95-
blocks(a)[Int.(I)...] = _convert(blocktype(a), value)
95+
aI = @view! a[I...]
96+
copyto!(aI, value)
9697
return a
9798
end
9899

src/blocksparsearray/blocksparsearray.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,19 @@ function blocksparsezeros(::BlockType{A}, axes...) where {A<:AbstractArray}
262262
# to make a bit more generic.
263263
return BlockSparseArray{eltype(A),ndims(A),A}(undef, axes...)
264264
end
265-
function blocksparse(d::Dict{<:Block,<:AbstractArray}, axes...)
266-
a = blocksparsezeros(BlockType(valtype(d)), axes...)
265+
function blocksparse(d::Dict{<:Block,<:AbstractArray}, ax::Tuple)
266+
blockaxtype = Tuple{map(eltype eachblockaxis, ax)...}
267+
# TODO: Catch if inference fails and use `valtype(d)` instead.
268+
blockt = Base.promote_op(similar, Type{valtype(d)}, blockaxtype)
269+
a = blocksparsezeros(BlockType(blockt), ax)
267270
for I in eachindex(d)
268271
a[I] = d[I]
269272
end
270273
return a
271274
end
275+
function blocksparse(d::Dict{<:Block,<:AbstractArray}, ax::AbstractUnitRange...)
276+
return blocksparse(d, ax)
277+
end
272278

273279
# Base `AbstractArray` interface
274280
Base.axes(a::BlockSparseArray) = a.axes

src/blocksparsearrayinterface/getunstoredblock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ end
2020
@inline function Base.getindex(a::ZeroBlocks{N,A}, I::Vararg{Int,N}) where {N,A}
2121
# TODO: Use `BlockArrays.eachblockaxes`.
2222
ax = ntuple(N) do d
23-
return only(axes(a.parentaxes[d][Block(I[d])]))
23+
return eachblockaxis(a.parentaxes[d])[I[d]]
2424
end
2525
!isconcretetype(A) && return zero!(similar(Array{eltype(A),N}, ax))
2626
return zero!(similar(A, ax))

src/factorizations/svd.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,22 @@ function MatrixAlgebraKit.initialize_output(
4949
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
5050
)
5151
bm, bn = blocksize(A)
52-
bmn = min(bm, bn)
52+
(bmn, mindim) = findmin((bm, bn))
5353

5454
brows = eachblockaxis(axes(A, 1))
5555
bcols = eachblockaxis(axes(A, 2))
5656
u_axes = similar(brows, bmn)
57-
v_axes = similar(brows, bmn)
57+
v_axes = similar(bcols, bmn)
5858

5959
# fill in values for blocks that are present
6060
bIs = collect(eachblockstoredindex(A))
6161
browIs = Int.(first.(Tuple.(bIs)))
6262
bcolIs = Int.(last.(Tuple.(bIs)))
6363
for bI in eachblockstoredindex(A)
6464
row, col = Int.(Tuple(bI))
65-
u_axes[col] = infimum(brows[row], bcols[col])
66-
v_axes[col] = infimum(bcols[col], brows[row])
65+
dim = (row, col)[mindim]
66+
u_axes[dim] = infimum(brows[row], bcols[col])
67+
v_axes[dim] = infimum(bcols[col], brows[row])
6768
end
6869

6970
# fill in values for blocks that aren't present, pairing them in order of occurence
@@ -83,9 +84,10 @@ function MatrixAlgebraKit.initialize_output(
8384
# allocate output
8485
for bI in eachblockstoredindex(A)
8586
brow, bcol = Tuple(bI)
87+
bdim = (brow, bcol)[mindim]
8688
block = @view!(A[bI])
8789
block_alg = block_algorithm(alg, block)
88-
U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output(
90+
U[brow, bdim], S[bdim, bdim], Vt[bdim, bcol] = MatrixAlgebraKit.initialize_output(
8991
svd_compact!, block, block_alg
9092
)
9193
end

0 commit comments

Comments
 (0)