Skip to content

Commit cdab73b

Browse files
committed
add SubblockIterator
1 parent 74f006b commit cdab73b

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

src/tensors/blockiterator.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,69 @@ function Base.show(io::IO, mime::MIME"text/plain", b::BlockIterator)
7676
show_blocks(io, mime, b)
7777
return nothing
7878
end
79+
80+
"""
81+
struct SubblockIterator{T <: AbstractTensorMap, S}
82+
83+
Iterator over the subblocks of a tensor of type `T`, possibly holding some pre-computed data of type `S`.
84+
This is typically constructed through of [`subblocks`](@ref).
85+
"""
86+
struct SubblockIterator{T <: AbstractTensorMap, S}
87+
t::T
88+
structure::S
89+
end
90+
91+
Base.IteratorSize(::SubblockIterator) = Base.HasLength()
92+
Base.IteratorEltype(::SubblockIterator) = Base.HasEltype()
93+
Base.eltype(::Type{<:SubblockIterator{T}}) where {T} = Pair{fusiontreetype(T), subblocktype(T)}
94+
Base.length(iter::SubblockIterator) = length(iter.structure)
95+
Base.isdone(iter::SubblockIterator, state...) = Base.isdone(iter.structure, state...)
96+
97+
# default implementation assumes `structure = fusiontrees(t)`
98+
function Base.iterate(iter::SubblockIterator, state...)
99+
next = Base.iterate(iter.structure, state...)
100+
isnothing(next) && return nothing
101+
(f₁, f₂), state = next
102+
@inbounds data = subblock(iter.t, (f₁, f₂))
103+
return (f₁, f₂) => data, state
104+
end
105+
106+
107+
function Base.showarg(io::IO, iter::SubblockIterator, toplevel::Bool)
108+
print(io, "subblocks(")
109+
Base.showarg(io, iter.t, false)
110+
print(io, ")")
111+
return nothing
112+
end
113+
function Base.summary(io::IO, iter::SubblockIterator)
114+
Base.showarg(io, iter, true)
115+
return nothing
116+
end
117+
118+
function show_subblocks(io::IO, mime::MIME"text/plain", iter::SubblockIterator)
119+
if FusionStyle(sectortype(iter.t)) isa UniqueFusion
120+
first = true
121+
for ((f₁, f₂), b) in iter
122+
first || print(io, "\n\n")
123+
print(io, " * ", f₁.uncoupled, "", f₂.uncoupled, " => ")
124+
show(io, mime, b)
125+
first = false
126+
end
127+
else
128+
first = true
129+
for ((f₁, f₂), b) in iter
130+
first || print(io, "\n\n")
131+
print(io, " * ", (f₁, f₂), " => ")
132+
show(io, mime, b)
133+
first = false
134+
end
135+
end
136+
return nothing
137+
end
138+
139+
function Base.show(io::IO, mime::MIME"text/plain", iter::SubblockIterator)
140+
summary(io, iter)
141+
println(io, ":")
142+
show_subblocks(io, mime, iter)
143+
return nothing
144+
end

0 commit comments

Comments
 (0)