Skip to content

Commit 9c64d2a

Browse files
committed
Move blocked unit range functionality from GradedUnitRanges
1 parent 24d8cec commit 9c64d2a

File tree

4 files changed

+187
-3
lines changed

4 files changed

+187
-3
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.9"
4+
version = "0.3.10"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -12,7 +12,6 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
1212
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
1313
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1414
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
15-
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
1615
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1716
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1817
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
@@ -21,6 +20,7 @@ SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
2120
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2221

2322
[weakdeps]
23+
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
2424
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2525
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2626

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ using BlockArrays:
2020
findblock,
2121
findblockindex
2222
using Dictionaries: Dictionary, Indices
23-
using GradedUnitRanges: blockedunitrange_getindices, to_blockindices
2423
using SparseArraysBase:
2524
SparseArraysBase,
2625
eachstoredindex,
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
using BlockArrays:
2+
BlockArrays,
3+
AbstractBlockedUnitRange,
4+
AbstractBlockVector,
5+
Block,
6+
BlockIndex,
7+
BlockIndexRange,
8+
BlockRange,
9+
BlockSlice,
10+
BlockVector,
11+
block,
12+
blockindex,
13+
findblock,
14+
findblockindex,
15+
mortar
16+
17+
# Custom `BlockedUnitRange` constructor that takes a unit range
18+
# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`.
19+
function blockedunitrange(a::AbstractUnitRange, blocklengths)
20+
blocklengths_shifted = copy(blocklengths)
21+
blocklengths_shifted[1] += (first(a) - 1)
22+
blocklasts = cumsum(blocklengths_shifted)
23+
return BlockArrays._BlockedUnitRange(first(a), blocklasts)
24+
end
25+
26+
# TODO: Move this to a `BlockArraysExtensions` library.
27+
# TODO: Rename this. `BlockArrays.findblock(a, k)` finds the
28+
# block of the value `k`, while this finds the block of the index `k`.
29+
# This could make use of the `BlockIndices` object, i.e. `block(BlockIndices(a)[index])`.
30+
function blockedunitrange_findblock(a::AbstractBlockedUnitRange, index::Integer)
31+
@boundscheck index in 1:length(a) || throw(BoundsError(a, index))
32+
return @inbounds findblock(a, index + first(a) - 1)
33+
end
34+
35+
# TODO: Move this to a `BlockArraysExtensions` library.
36+
# TODO: Rename this. `BlockArrays.findblockindex(a, k)` finds the
37+
# block index of the value `k`, while this finds the block index of the index `k`.
38+
# This could make use of the `BlockIndices` object, i.e. `BlockIndices(a)[index]`.
39+
function blockedunitrange_findblockindex(a::AbstractBlockedUnitRange, index::Integer)
40+
@boundscheck index in 1:length(a) || throw(BoundsError())
41+
return @inbounds findblockindex(a, index + first(a) - 1)
42+
end
43+
44+
function blockedunitrange_getindices(a::AbstractUnitRange, indices)
45+
return a[indices]
46+
end
47+
48+
# TODO: Move this to a `BlockArraysExtensions` library.
49+
# Like `a[indices]` but preserves block structure.
50+
# TODO: Consider calling this something else, for example
51+
# `blocked_getindex`. See the discussion here:
52+
# https://github.com/JuliaArrays/BlockArrays.jl/issues/347
53+
function blockedunitrange_getindices(
54+
a::AbstractBlockedUnitRange, indices::AbstractUnitRange{<:Integer}
55+
)
56+
first_blockindex = blockedunitrange_findblockindex(a, first(indices))
57+
last_blockindex = blockedunitrange_findblockindex(a, last(indices))
58+
first_block = block(first_blockindex)
59+
last_block = block(last_blockindex)
60+
blocklengths = if first_block == last_block
61+
[length(indices)]
62+
else
63+
map(first_block:last_block) do block
64+
if block == first_block
65+
return length(a[first_block]) - blockindex(first_blockindex) + 1
66+
end
67+
if block == last_block
68+
return blockindex(last_blockindex)
69+
end
70+
return length(a[block])
71+
end
72+
end
73+
return blockedunitrange(indices .+ (first(a) - 1), blocklengths)
74+
end
75+
76+
# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly.
77+
# TODO: Make a special case for `BlockedVector{<:Block{1},<:BlockRange{1}}`?
78+
# For example:
79+
# ```julia
80+
# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices))
81+
# return blockedrange(blocklengths)
82+
# ```
83+
function blockedunitrange_getindices(
84+
a::AbstractBlockedUnitRange, indices::AbstractBlockVector{<:Block{1}}
85+
)
86+
blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices))
87+
# We pass `length.(blks)` to `mortar` in order
88+
# to pass block labels to the axes of the output,
89+
# if they exist. This makes it so that
90+
# `only(axes(a[indices])) isa `GradedUnitRange`
91+
# if `a isa `GradedUnitRange`, for example.
92+
# Note there is a more specialized definition:
93+
# ```julia
94+
# function blockedunitrange_getindices(
95+
# a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}}
96+
# )
97+
# ```
98+
# that does a better job of preserving labels, since `length`
99+
# may drop labels for certain block types.
100+
return mortar(blks, length.(blks))
101+
end
102+
103+
# TODO: Move this to a `BlockArraysExtensions` library.
104+
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockIndexRange)
105+
return a[block(indices)][only(indices.indices)]
106+
end
107+
108+
# TODO: Move this to a `BlockArraysExtensions` library.
109+
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockSlice)
110+
# TODO: Is this a good definition? It ignores `indices.indices`.
111+
return a[indices.block]
112+
end
113+
114+
# TODO: Move this to a `BlockArraysExtensions` library.
115+
function blockedunitrange_getindices(
116+
a::AbstractBlockedUnitRange, indices::Vector{<:Integer}
117+
)
118+
return map(index -> a[index], indices)
119+
end
120+
121+
# TODO: Move this to a `BlockArraysExtensions` library.
122+
# TODO: Make a special definition for `BlockedVector{<:Block{1}}` in order
123+
# to merge blocks.
124+
function blockedunitrange_getindices(
125+
a::AbstractBlockedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
126+
)
127+
# Without converting `indices` to `Vector`,
128+
# mapping `indices` outputs a `BlockVector`
129+
# which is harder to reason about.
130+
blocks = map(index -> a[index], Vector(indices))
131+
# We pass `length.(blocks)` to `mortar` in order
132+
# to pass block labels to the axes of the output,
133+
# if they exist. This makes it so that
134+
# `only(axes(a[indices])) isa `GradedUnitRange`
135+
# if `a isa `GradedUnitRange`, for example.
136+
return mortar(blocks, length.(blocks))
137+
end
138+
139+
# TODO: Move this to a `BlockArraysExtensions` library.
140+
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::Block{1})
141+
return a[indices]
142+
end
143+
144+
function blockedunitrange_getindices(
145+
a::AbstractBlockedUnitRange,
146+
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
147+
)
148+
return mortar(map(b -> a[b], blocks(indices)))
149+
end
150+
151+
# TODO: Move this to a `BlockArraysExtensions` library.
152+
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices)
153+
return error("Not implemented.")
154+
end
155+
156+
# The blocks of the corresponding slice.
157+
_blocks(a::AbstractUnitRange, indices) = error("Not implemented")
158+
function _blocks(a::AbstractUnitRange, indices::AbstractUnitRange)
159+
return findblock(a, first(indices)):findblock(a, last(indices))
160+
end
161+
function _blocks(a::AbstractUnitRange, indices::BlockRange)
162+
return indices
163+
end
164+
165+
# Slice `a` by `I`, returning a:
166+
# `BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}`
167+
# with the `BlockIndex{1}` corresponding to each value of `I`.
168+
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:Integer})
169+
return mortar(
170+
map(blocks(blockedunitrange_getindices(a, I))) do r
171+
bi_first = findblockindex(a, first(r))
172+
bi_last = findblockindex(a, last(r))
173+
@assert block(bi_first) == block(bi_last)
174+
return block(bi_first)[blockindex(bi_first):blockindex(bi_last)]
175+
end,
176+
)
177+
end
178+
179+
# This handles non-blocked slices.
180+
# For example:
181+
# a = BlockSparseArray{Float64}([2, 2, 2, 2])
182+
# I = BlockedVector(Block.(1:4), [2, 2])
183+
# @views a[I][Block(1)]
184+
to_blockindices(a::Base.OneTo{<:Integer}, I::UnitRange{<:Integer}) = I

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export BlockSparseArray,
1212
include("factorizations/svd.jl")
1313

1414
# possible upstream contributions
15+
include("BlockArraysExtensions/blockedunitrange.jl")
1516
include("BlockArraysExtensions/BlockArraysExtensions.jl")
1617

1718
# interface functions that don't have to specialize

0 commit comments

Comments
 (0)