Skip to content

Commit 3c7b830

Browse files
committed
Add block iterator for TensorMap
1 parent c34f460 commit 3c7b830

File tree

3 files changed

+53
-27
lines changed

3 files changed

+53
-27
lines changed

src/TensorKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ include("spaces/vectorspaces.jl")
185185
#-------------------------------------
186186
# general definitions
187187
include("tensors/abstracttensor.jl")
188-
# include("tensors/tensortreeiterator.jl")
188+
include("tensors/blockiterator.jl")
189189
include("tensors/tensor.jl")
190190
include("tensors/adjoint.jl")
191191
include("tensors/linalg.jl")

src/tensors/blockiterator.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
struct BlockIterator{T<:AbstractTensorMap,S}
3+
4+
Iterator over the blocks of type `T`, possibly holding some pre-computed data of type `S`
5+
"""
6+
struct BlockIterator{T<:AbstractTensorMap,S}
7+
t::T
8+
structure::S
9+
end
10+
11+
Base.IteratorSize(::BlockIterator) = Base.HasLength()
12+
Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
13+
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
14+
Base.length(iter::BlockIterator) = length(iter.structure)
15+
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)

src/tensors/tensor.jl

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -421,28 +421,35 @@ end
421421

422422
# Getting and setting the data at the block level
423423
#-------------------------------------------------
424-
function block(t::TensorMap, s::Sector)
425-
sectortype(t) == typeof(s) || throw(SectorMismatch())
426-
structure = fusionblockstructure(t).blockstructure
427-
(d₁, d₂), r = get(structure, s) do
424+
block(t::TensorMap, c::Sector) = blocks(t)[c]
425+
426+
blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)
427+
428+
function blocktype(::Type{TT}) where {TT<:TensorMap}
429+
A = storagetype(TT)
430+
T = eltype(A)
431+
return Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}
432+
end
433+
434+
function Base.iterate(iter::BlockIterator{<:TensorMap}, state...)
435+
next = iterate(iter.structure, state...)
436+
isnothing(next) && return next
437+
(c, (sz, r)), newstate = next
438+
return c => reshape(view(iter.t.data, r), sz), newstate
439+
end
440+
441+
function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector)
442+
sectortype(iter.t) === typeof(c) || throw(SectorMismatch())
443+
(d₁, d₂), r = get(iter.structure, c) do
428444
# is s is not a key, at least one of the two dimensions will be zero:
429445
# it then does not matter where exactly we construct a view in `t.data`,
430446
# as it will have length zero anyway
431-
d₁′ = blockdim(codomain(t), s)
432-
d₂′ = blockdim(domain(t), s)
447+
d₁′ = blockdim(codomain(iter.t), c)
448+
d₂′ = blockdim(domain(iter.t), c)
433449
l = d₁′ * d₂′
434450
return (d₁′, d₂′), 1:l
435451
end
436-
return reshape(view(t.data, r), (d₁, d₂))
437-
end
438-
439-
function blocks(t::TensorMap)
440-
structure = fusionblockstructure(t).blockstructure
441-
iter = Base.Iterators.map(structure) do (c, ((d₁, d₂), r))
442-
b = reshape(view(t.data, r), (d₁, d₂))
443-
return c => b
444-
end
445-
return iter
452+
return reshape(view(iter.t.data, r), (d₁, d₂))
446453
end
447454

448455
# Indexing and getting and setting the data at the subblock level
@@ -559,18 +566,22 @@ function Base.show(io::IO, t::TensorMap)
559566
if sectortype(t) == Trivial
560567
Base.print_array(io, t[])
561568
println(io)
562-
elseif FusionStyle(sectortype(t)) isa UniqueFusion
563-
for (f₁, f₂) in fusiontrees(t)
564-
println(io, "* Data for sector ", f₁.uncoupled, "", f₂.uncoupled, ":")
565-
Base.print_array(io, t[f₁, f₂])
566-
println(io)
567-
end
568569
else
569-
for (f₁, f₂) in fusiontrees(t)
570-
println(io, "* Data for fusiontree ", f₁, "", f₂, ":")
571-
Base.print_array(io, t[f₁, f₂])
572-
println(io)
570+
for (c, b) in blocks(t)
571+
println(io, "* Data for sector $c:")
572+
Base.print_array(io, b)
573573
end
574+
# for (f₁, f₂) in fusiontrees(t)
575+
# println(io, "* Data for sector ", f₁.uncoupled, " ← ", f₂.uncoupled, ":")
576+
# Base.print_array(io, t[f₁, f₂])
577+
# println(io)
578+
# end
579+
# else
580+
# for (f₁, f₂) in fusiontrees(t)
581+
# println(io, "* Data for fusiontree ", f₁, " ← ", f₂, ":")
582+
# Base.print_array(io, t[f₁, f₂])
583+
# println(io)
584+
# end
574585
end
575586
end
576587

0 commit comments

Comments
 (0)