Skip to content

Commit e6fdb68

Browse files
committed
[WIP] Start handling abstract block types
1 parent 89faa55 commit e6fdb68

File tree

5 files changed

+48
-9
lines changed

5 files changed

+48
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.14"
4+
version = "0.7.15"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,37 @@ function ArrayLayouts.MemoryLayout(
2323
end
2424

2525
function Base.similar(
26-
mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
26+
mul::MulAdd{
27+
<:BlockLayout{<:SparseLayout,BlockLayoutA},
28+
<:BlockLayout{<:SparseLayout,BlockLayoutB},
29+
LayoutC,
30+
T,
31+
A,
32+
B,
33+
C,
34+
},
2735
elt::Type,
2836
axes,
29-
) where {A,B}
30-
# TODO: Use something like `Base.promote_op(*, A, B)` to determine the output block type.
31-
output_blocktype = similartype(blocktype(A), Type{elt}, Tuple{blockaxistype.(axes)...})
32-
return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes)
37+
) where {BlockLayoutA,BlockLayoutB,LayoutC,T,A,B,C}
38+
39+
# TODO: Consider using this instead:
40+
# ```julia
41+
# blockmultype = MulAdd{BlockLayoutA,BlockLayoutB,LayoutC,T,blocktype(A),blocktype(B),C}
42+
# output_blocktype = Base.promote_op(
43+
# similar, blockmultype, Type{elt}, Tuple{eltype.(eachblockaxis.(axes))...}
44+
# )
45+
# ```
46+
# The issue is that it in some cases it seems to lose some information about the block types.
47+
48+
# TODO: Maybe this should be:
49+
# output_blocktype = Base.promote_op(
50+
# mul!, blocktype(mul.A), blocktype(mul.B), blocktype(mul.C), typeof(mul.α), typeof(mul.β)
51+
# )
52+
53+
output_blocktype = Base.promote_op(*, blocktype(mul.A), blocktype(mul.B))
54+
output_blocktype′ =
55+
!isconcretetype(output_blocktype) ? AbstractMatrix{elt} : output_blocktype
56+
return similar(BlockSparseArray{elt,length(axes),output_blocktype′}, axes)
3357
end
3458

3559
# Materialize a SubArray view.

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,11 @@ end
231231

232232
function blocksparse_similar(a, elt::Type, axes::Tuple)
233233
ndims = length(axes)
234-
blockt = similartype(blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
235-
return BlockSparseArray{elt,ndims,blockt}(undef, axes)
234+
# TODO: Define a version of `similartype` that catches the case
235+
# where the output isn't concrete and returns an `AbstractArray`.
236+
blockt = Base.promote_op(similar, blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
237+
blockt′ = !isconcretetype(blockt) ? AbstractArray{elt,ndims} : blockt
238+
return BlockSparseArray{elt,ndims,blockt′}(undef, axes)
236239
end
237240
@interface ::AbstractBlockSparseArrayInterface function Base.similar(
238241
a::AbstractArray, elt::Type, axes::Tuple{Vararg{Int}}

src/blocksparsearrayinterface/getunstoredblock.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ end
1111
ax = ntuple(N) do d
1212
return only(axes(f.axes[d][Block(I[d])]))
1313
end
14+
!isconcretetype(A) && return zero!(similar(Array{eltype(A),N}, ax))
1415
return zero!(similar(A, ax))
1516
end
1617
@inline function (f::GetUnstoredBlock)(

src/factorizations/svd.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,21 @@ function MatrixAlgebraKit.default_svd_algorithm(
2222
return BlockPermutedDiagonalAlgorithm(alg)
2323
end
2424

25+
function output_type(
26+
::typeof(svd_compact!),
27+
A::Type{<:AbstractMatrix{T}},
28+
Alg::Type{<:MatrixAlgebraKit.AbstractAlgorithm},
29+
) where {T}
30+
USVᴴ = Base.promote_op(svd_compact!, A, Alg)
31+
!isconcretetype(USVᴴ) &&
32+
return Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
33+
return USVᴴ
34+
end
35+
2536
function similar_output(
2637
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
2738
)
28-
BU, BS, BVᴴ = fieldtypes(Base.promote_op(svd_compact!, blocktype(A), typeof(alg.alg)))
39+
BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A), typeof(alg.alg)))
2940
U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1]))
3041
S = similar(A, BlockType(BS), S_axes)
3142
Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2)))

0 commit comments

Comments
 (0)