Skip to content

Commit 56ee5a5

Browse files
committed
[WIP] Make more operations more agnostic about the block type
1 parent 6be3e32 commit 56ee5a5

File tree

6 files changed

+46
-12
lines changed

6 files changed

+46
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.4"
4+
version = "0.7.5"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ end
3333
# Take a collection of axes and mortar them
3434
# into a single blocked axis.
3535
function mortar_axis(axs)
36-
return blockedrange(length.(axs))
36+
## return blockedrange(length.(axs))
37+
return blockrange(axs)
3738
end
3839

3940
# Custom `BlockedUnitRange` constructor that takes a unit range

src/abstractblocksparsearray/map.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,13 @@ function Base.isreal(a::AnyAbstractBlockSparseArray)
111111
return @interface interface(a) isreal(a)
112112
end
113113

114+
# Helps with specialization.
114115
function Base.:*(x::Number, a::AnyAbstractBlockSparseArray)
115116
return map(Base.Fix1(*, x), a)
116117
end
117118
function Base.:*(a::AnyAbstractBlockSparseArray, x::Number)
118119
return map(Base.Fix2(*, x), a)
119120
end
121+
function Base.:/(a::AnyAbstractBlockSparseArray, x::Number)
122+
return map(Base.Fix2(/, x), a)
123+
end

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,16 @@ function Base.similar(
230230
return similar(arraytype, eltype(arraytype), axes)
231231
end
232232

233+
# This circumvents some issues with `TypeParameterAccessors.similartype`.
234+
# TODO: Fix this poperly in `TypeParameterAccessors.jl`.
235+
function _similartype(arraytype::Type{<:AbstractArray}, elt::Type, axt)
236+
return Base.promote_op(similar, arraytype, elt, axt)
237+
end
238+
233239
function blocksparse_similar(a, elt::Type, axes::Tuple)
234-
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}(
235-
undef, axes
236-
)
240+
block_axt = Tuple{eltype.(eachblockaxis.(axes))...}
241+
blockt = _similartype(blocktype(a), Type{elt}, block_axt)
242+
return BlockSparseArray{elt,length(axes),blockt}(undef, axes)
237243
end
238244
@interface ::AbstractBlockSparseArrayInterface function Base.similar(
239245
a::AbstractArray, elt::Type, axes::Tuple{Vararg{Int}}

src/blocksparsearray/blocksparsearray.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ end
173173
function BlockSparseArray{T,N}(
174174
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}}
175175
) where {T,N}
176-
return BlockSparseArray{T,N,Array{T,N}}(undef, axes)
176+
# TODO: Use `similartype` to determine the block type.
177+
A = Base.promote_op(similar, Array{T}, Tuple{eltype.(eachblockaxis.(axes))...})
178+
return BlockSparseArray{T,N,A}(undef, axes)
177179
end
178180

179181
function BlockSparseArray{T,N}(
@@ -230,6 +232,20 @@ function BlockSparseArray{T}(
230232
return BlockSparseArray{T}(undef, axes)
231233
end
232234

235+
function blocksparsezeros(elt::Type, axes...)
236+
return BlockSparseArray{elt}(undef, axes...)
237+
end
238+
function blocksparsezeros(::BlockType{A}, axes...) where {A<:AbstractArray}
239+
return BlockSparseArray{eltype(A),ndims(A),A}(undef, axes...)
240+
end
241+
function blocksparse(d::Dict{<:Block,<:AbstractArray}, axes...)
242+
a = blocksparsezeros(BlockType(valtype(d)), axes...)
243+
for I in eachindex(d)
244+
a[I] = d[I]
245+
end
246+
return a
247+
end
248+
233249
# Base `AbstractArray` interface
234250
Base.axes(a::BlockSparseArray) = a.axes
235251

src/factorizations/svd.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,21 @@ function MatrixAlgebraKit.default_svd_algorithm(
2020
return BlockPermutedDiagonalAlgorithm(alg)
2121
end
2222

23+
# TODO: Put this in a common location or package,
24+
# maybe `TypeParameterAccessors.jl`?
25+
# Also define `imagtype`, `complextype`, etc.
26+
realtype(a::AbstractArray) = realtype(typeof(a))
27+
function realtype(A::Type{<:AbstractArray})
28+
return Base.promote_op(real, A)
29+
end
30+
31+
using DiagonalArrays: diagonaltype
2332
function similar_output(
2433
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
2534
)
2635
U = similar(A, axes(A, 1), S_axes[1])
2736
T = real(eltype(A))
28-
# TODO: this should be replaced with a more general similar function that can handle setting
29-
# the blocktype and element type - something like S = similar(A, BlockType(...))
30-
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, S_axes)
37+
S = similar(A, BlockType(diagonaltype(realtype(blocktype(A)))), S_axes)
3138
Vt = similar(A, S_axes[2], axes(A, 2))
3239
return U, S, Vt
3340
end
@@ -49,9 +56,9 @@ function MatrixAlgebraKit.initialize_output(
4956
bcolIs = Int.(last.(Tuple.(bIs)))
5057
for bI in eachblockstoredindex(A)
5158
row, col = Int.(Tuple(bI))
52-
len = minimum(length, (brows[row], bcols[col]))
53-
u_axes[col] = brows[row][Base.OneTo(len)]
54-
v_axes[col] = bcols[col][Base.OneTo(len)]
59+
b = argmin(length, (brows[row], bcols[col]))
60+
u_axes[col] = b
61+
v_axes[col] = b
5562
end
5663

5764
# fill in values for blocks that aren't present, pairing them in order of occurence

0 commit comments

Comments
 (0)