Skip to content

Commit a3307f7

Browse files
committed
make blocks iterator adn use where possible
1 parent 7fde8a2 commit a3307f7

File tree

7 files changed

+67
-110
lines changed

7 files changed

+67
-110
lines changed

src/TensorKit.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,23 +146,23 @@ const FusionTreeDict{K,V} = Dict{K,V}
146146
abstract type TensorException <: Exception end
147147

148148
# Exception type for all errors related to sector mismatch
149-
struct SectorMismatch{S<:Union{Nothing,String}} <: TensorException
149+
struct SectorMismatch{S<:Union{Nothing,AbstractString}} <: TensorException
150150
message::S
151151
end
152152
SectorMismatch() = SectorMismatch{Nothing}(nothing)
153153
Base.show(io::IO, ::SectorMismatch{Nothing}) = print(io, "SectorMismatch()")
154154
Base.show(io::IO, e::SectorMismatch) = print(io, "SectorMismatch(\"", e.message, "\")")
155155

156156
# Exception type for all errors related to vector space mismatch
157-
struct SpaceMismatch{S<:Union{Nothing,String}} <: TensorException
157+
struct SpaceMismatch{S<:Union{Nothing,AbstractString}} <: TensorException
158158
message::S
159159
end
160160
SpaceMismatch() = SpaceMismatch{Nothing}(nothing)
161161
Base.show(io::IO, ::SpaceMismatch{Nothing}) = print(io, "SpaceMismatch()")
162162
Base.show(io::IO, e::SpaceMismatch) = print(io, "SpaceMismatch(\"", e.message, "\")")
163163

164164
# Exception type for all errors related to invalid tensor index specification.
165-
struct IndexError{S<:Union{Nothing,String}} <: TensorException
165+
struct IndexError{S<:Union{Nothing,AbstractString}} <: TensorException
166166
message::S
167167
end
168168
IndexError() = IndexError{Nothing}(nothing)

src/tensors/abstracttensor.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,19 +286,24 @@ end
286286
# tensor data: block access
287287
#---------------------------
288288
@doc """
289-
blocks(t::AbstractTensorMap) -> SectorDict{<:Sector,<:DenseMatrix}
289+
blocks(t::AbstractTensorMap)
290290
291291
Return an iterator over all blocks of a tensor, i.e. all coupled sectors and their
292-
corresponding blocks.
292+
corresponding matrix blocks.
293293
294294
See also [`block`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasblock`](@ref).
295295
"""
296-
blocks(t::AbstractTensorMap) = SectorDict(c => block(t, c) for c in blocksectors(t)) # TODO: make iterator
296+
function blocks(t::AbstractTensorMap)
297+
iter = Base.Iterators.map(blocksectors(t)) do c
298+
return c => block(t, c)
299+
end
300+
return iter
301+
end
297302

298303
@doc """
299-
block(t::AbstractTensorMap, c::Sector) -> DenseMatrix
304+
block(t::AbstractTensorMap, c::Sector)
300305
301-
Return the block of a tensor corresponding to a coupled sector `c`.
306+
Return the matrix block of a tensor corresponding to a coupled sector `c`.
302307
303308
See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasblock`](@ref).
304309
""" block

src/tensors/adjoint.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@ storagetype(::Type{AdjointTensorMap{T,S,N₁,N₂,TT}}) where {T,S,N₁,N₂,TT}
2323

2424
# Blocks and subblocks
2525
#----------------------
26-
blocksectors(t::AdjointTensorMap) = blocksectors(parent(t))
2726
block(t::AdjointTensorMap, s::Sector) = block(parent(t), s)'
2827

28+
function blocks(t::AdjointTensorMap)
29+
iter = Base.Iterators.map(blocks(parent(t))) do (c, b)
30+
return c => b'
31+
end
32+
return iter
33+
end
34+
2935
function Base.getindex(t::AdjointTensorMap{T,S,N₁,N₂},
3036
f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {T,S,N₁,N₂,I}
3137
tp = parent(t)

src/tensors/linalg.jl

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::SectorDict
174174
blockdim(dom, c), b)
175175
for (c, b) in v), codom dom)
176176
end
177-
LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag, values(blocks(t)))
177+
LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag last, blocks(t))
178178

179179
# In-place methods
180180
#------------------
@@ -184,8 +184,8 @@ LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag, values(bl
184184
# Copy, adjoint and fill:
185185
function Base.copy!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap)
186186
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
187-
for c in blocksectors(tdst)
188-
copy!(StridedView(block(tdst, c)), StridedView(block(tsrc, c)))
187+
for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc))
188+
copy!(StridedView(bdst), StridedView(bsrc))
189189
end
190190
return tdst
191191
end
@@ -284,20 +284,44 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
284284
tA::AbstractTensorMap,
285285
tB::AbstractTensorMap, α=true, β=false)
286286
compose(space(tA), space(tB)) == space(tC) ||
287-
throw(SpaceMismatch("$(space(tC))$(space(tA)) * $(space(tB))"))
288-
289-
for c in blocksectors(tC)
290-
if hasblock(tA, c) # then also tB should have such a block
291-
A = block(tA, c)
292-
B = block(tB, c)
293-
C = block(tC, c)
294-
mul!(C, A, B, α, β)
295-
elseif β != one(β)
296-
rmul!(block(tC, c), β)
287+
throw(SpaceMismatch(lazy"$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
288+
289+
iterC = blocks(tC)
290+
iterA = blocks(tA)
291+
iterB = blocks(tB)
292+
nextA = iterate(iterA)
293+
nextB = iterate(iterB)
294+
nextC = iterate(iterC)
295+
while !isnothing(nextC)
296+
(cC, C), stateC = nextC
297+
if !isnothing(nextA) && !isnothing(nextB)
298+
(cA, A), stateA = nextA
299+
(cB, B), stateB = nextB
300+
if cA == cC && cB == cC
301+
mul!(C, A, B, α, β)
302+
nextA = iterate(iterA, stateA)
303+
nextB = iterate(iterB, stateB)
304+
nextC = iterate(iterC, stateC)
305+
elseif cA < cC
306+
nextA = iterate(iterA, stateA)
307+
elseif cB < cC
308+
nextB = iterate(iterB, stateB)
309+
else
310+
if β != one(β)
311+
rmul!(C, β)
312+
end
313+
nextC = iterate(iterC, stateC)
314+
end
315+
else
316+
if β != one(β)
317+
rmul!(C, β)
318+
end
319+
nextC = iterate(iterC, stateC)
297320
end
298321
end
299322
return tC
300323
end
324+
301325
# TODO: consider spawning threads for different blocks, support backends
302326

303327
# TensorMap inverse

src/tensors/tensor.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,14 @@ function block(t::TensorMap, s::Sector)
426426
return reshape(view(t.data, r), (d₁, d₂))
427427
end
428428

429-
blocks(t::TensorMap) = SectorDict(c => block(t, c) for c in blocksectors(t)) # TODO: make iterator
429+
function blocks(t::TensorMap)
430+
structure = fusionblockstructure(t).blockstructure
431+
iter = Base.Iterators.map(structure) do (c, ((d₁, d₂), r))
432+
b = reshape(view(t.data, r), (d₁, d₂))
433+
return c => b
434+
end
435+
return iter
436+
end
430437

431438
# Indexing and getting and setting the data at the subblock level
432439
#-----------------------------------------------------------------

src/tensors/tensortreeiterator.jl

Lines changed: 0 additions & 85 deletions
This file was deleted.

test/planar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ function force_planar(tsrc::TensorMap{<:Any,ComplexSpace})
1717
tdst = TensorMap{scalartype(tsrc)}(undef,
1818
force_planar(codomain(tsrc))
1919
force_planar(domain(tsrc)))
20-
copyto!(blocks(tdst)[PlanarTrivial()], blocks(tsrc)[Trivial()])
20+
copyto!(block(tdst, PlanarTrivial()), block(tsrc, Trivial()))
2121
return tdst
2222
end
2323
function force_planar(tsrc::TensorMap{<:Any,<:GradedSpace})
2424
tdst = TensorMap{scalartype(tsrc)}(undef,
2525
force_planar(codomain(tsrc))
2626
force_planar(domain(tsrc)))
2727
for (c, b) in blocks(tsrc)
28-
copyto!(blocks(tdst)[c PlanarTrivial()], b)
28+
copyto!(block(tdst, c PlanarTrivial()), b)
2929
end
3030
return tdst
3131
end

0 commit comments

Comments
 (0)