Skip to content

Commit b46bfc4

Browse files
authored
Show improvements + subblocks iterator (#304)
1 parent 9ba3b6c commit b46bfc4

File tree

9 files changed

+344
-184
lines changed

9 files changed

+344
-184
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1010
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
1111
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
12+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1314
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1415
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
@@ -37,6 +38,7 @@ LinearAlgebra = "1"
3738
MatrixAlgebraKit = "0.5.0"
3839
OhMyThreads = "0.8.0"
3940
PackageExtensionCompat = "1"
41+
Printf = "1"
4042
Random = "1"
4143
SafeTestsets = "0.1"
4244
ScopedValues = "1.3.0"

docs/src/lib/tensors.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,15 @@ blocks
118118

119119
To access the data associated with a specific fusion tree pair, you can use:
120120
```@docs
121-
Base.getindex(::TensorMap{T,S,N₁,N₂}, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}
122-
Base.setindex!(::TensorMap{T,S,N₁,N₂}, ::Any, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}
121+
Base.getindex(::AbstractTensorMap, ::FusionTree, ::FusionTree)
122+
Base.setindex!(::AbstractTensorMap, ::Any, ::FusionTree, ::FusionTree)
123123
```
124124

125125
For a tensor `t` with `FusionType(sectortype(t)) isa UniqueFusion`, fusion trees are
126126
completely determined by the outcoming sectors, and the data can be accessed in a more
127127
straightforward way:
128128
```@docs
129-
Base.getindex(::TensorMap, ::Tuple{I,Vararg{I}}) where {I<:Sector}
129+
Base.getindex(::AbstractTensorMap, ::Tuple{I,Vararg{I}}) where {I<:Sector}
130130
```
131131

132132
For tensor `t` with `sectortype(t) == Trivial`, the data can be accessed and manipulated

src/TensorKit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ export ℤ₂Space, ℤ₃Space, ℤ₄Space, U₁Space, CU₁Space, SU₂Space
5757
# Export tensor map methods
5858
export domain, codomain, numind, numout, numin, domainind, codomainind, allind
5959
export spacetype, storagetype, scalartype, tensormaptype
60-
export blocksectors, blockdim, block, blocks
60+
export blocksectors, blockdim, block, blocks, subblocks, subblock
6161

6262
# random methods for constructor
6363
export randisometry, randisometry!, rand, rand!, randn, randn!
@@ -127,6 +127,7 @@ using Base: @boundscheck, @propagate_inbounds, @constprop,
127127
tuple_type_head, tuple_type_tail, tuple_type_cons,
128128
SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype
129129
using Base.Iterators: product, filter
130+
using Printf: @sprintf
130131

131132
using LinearAlgebra: LinearAlgebra, BlasFloat
132133
using LinearAlgebra: norm, dot, normalize, normalize!, tr,

src/spaces/gradedspace.jl

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -197,21 +197,53 @@ function supremum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I <: Sector
197197
)
198198
end
199199

200-
function Base.show(io::IO, V::GradedSpace{I}) where {I <: Sector}
201-
print(io, type_repr(typeof(V)), "(")
202-
separator = ""
203-
comma = ", "
204-
io2 = IOContext(io, :typeinfo => I)
205-
for c in sectors(V)
206-
if isdual(V)
207-
print(io2, separator, dual(c), "=>", dim(V, c))
208-
else
209-
print(io2, separator, c, "=>", dim(V, c))
210-
end
211-
separator = comma
200+
Base.summary(io::IO, V::GradedSpace) = print(io, type_repr(typeof(V)))
201+
202+
function Base.show(io::IO, V::GradedSpace)
203+
opn = (get(io, :typeinfo, Any)::DataType == typeof(V) ? "" : type_repr(typeof(V)))
204+
opn *= "("
205+
if isdual(V)
206+
cls = ")'"
207+
V = dual(V)
208+
else
209+
cls = ")"
210+
end
211+
212+
v = [c => dim(V, c) for c in sectors(V)]
213+
214+
# logic stolen from Base.show_vector
215+
limited = get(io, :limit, false)::Bool
216+
io = IOContext(io, :typeinfo => eltype(v))
217+
218+
if limited && length(v) > 20
219+
axs1 = axes(v, 1)
220+
f, l = first(axs1), last(axs1)
221+
Base.show_delim_array(io, v, opn, ",", "", false, f, f + 9)
222+
print(io, "")
223+
Base.show_delim_array(io, v, "", ",", cls, false, l - 9, l)
224+
else
225+
Base.show_delim_array(io, v, opn, ",", cls, false)
212226
end
213-
print(io, ")")
214-
V.dual && print(io, "'")
227+
return nothing
228+
end
229+
230+
function Base.show(io::IO, ::MIME"text/plain", V::GradedSpace)
231+
# print small summary, e.g.: Vect[I](…) of dim d
232+
d = dim(V)
233+
print(io, type_repr(typeof(d)), "(…)")
234+
isdual(V) && print(io, "'")
235+
print(io, " of dim ", d)
236+
237+
compact = get(io, :compact, false)::Bool
238+
(iszero(d) || compact) && return nothing
239+
240+
# print detailed sector information - hijack Base.Vector printing
241+
print(io, ":\n")
242+
isdual(V) && (V = dual(V))
243+
print_data = [c => dim(V, c) for c in sectors(V)]
244+
ioc = IOContext(io, :typeinfo => eltype(print_data))
245+
Base.print_matrix(ioc, print_data)
246+
215247
return nothing
216248
end
217249

src/tensors/abstracttensor.jl

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,12 @@ Return an iterator over all splitting - fusion tree pairs of a tensor.
250250
"""
251251
fusiontrees(t::AbstractTensorMap) = fusionblockstructure(t).fusiontreelist
252252

253+
fusiontreetype(t::AbstractTensorMap) = fusiontreetype(typeof(t))
254+
function fusiontreetype(::Type{T}) where {T <: AbstractTensorMap}
255+
I = sectortype(T)
256+
return Tuple{fusiontreetype(I, numout(T)), fusiontreetype(I, numin(T))}
257+
end
258+
253259
# auxiliary function
254260
@inline function trivial_fusiontree(t::AbstractTensorMap)
255261
sectortype(t) === Trivial ||
@@ -295,6 +301,126 @@ function blocktype(::Type{T}) where {T <: AbstractTensorMap}
295301
return Core.Compiler.return_type(block, Tuple{T, sectortype(T)})
296302
end
297303

304+
# tensor data: subblock access
305+
# ----------------------------
306+
@doc """
307+
subblocks(t::AbstractTensorMap)
308+
309+
Return an iterator over all subblocks of a tensor, i.e. all fusiontrees and their
310+
corresponding tensor subblocks.
311+
312+
See also [`subblock`](@ref), [`fusiontrees`](@ref), and [`hassubblock`](@ref).
313+
"""
314+
subblocks(t::AbstractTensorMap) = SubblockIterator(t, fusiontrees(t))
315+
316+
const _doc_subblock = """
317+
Return a view into the data of `t` corresponding to the splitting - fusion tree pair
318+
`(f₁, f₂)`. In particular, this is an `AbstractArray{T}` with `T = scalartype(t)`, of size
319+
`(dims(codomain(t), f₁.uncoupled)..., dims(codomain(t), f₂.uncoupled)...)`.
320+
321+
Whenever `FusionStyle(sectortype(t)) isa UniqueFusion` , it is also possible to provide only
322+
the external `sectors`, in which case the fusion tree pair will be constructed automatically.
323+
"""
324+
325+
@doc """
326+
subblock(t::AbstractTensorMap, (f₁, f₂)::Tuple{FusionTree,FusionTree})
327+
subblock(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}})
328+
329+
$_doc_subblock
330+
331+
In general, new tensor types should provide an implementation of this function for the
332+
fusion tree signature.
333+
334+
See also [`subblocks`](@ref) and [`fusiontrees`](@ref).
335+
""" subblock
336+
337+
Base.@propagate_inbounds function subblock(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector}
338+
# input checking
339+
I === sectortype(t) || throw(SectorMismatch("Not a valid sectortype for this tensor."))
340+
FusionStyle(I) isa UniqueFusion ||
341+
throw(SectorMismatch("Indexing with sectors is only possible for unique fusion styles."))
342+
length(sectors) == numind(t) || throw(ArgumentError("invalid number of sectors"))
343+
344+
# convert to fusiontrees
345+
s₁ = TupleTools.getindices(sectors, codomainind(t))
346+
s₂ = map(dual, TupleTools.getindices(sectors, domainind(t)))
347+
c1 = length(s₁) == 0 ? unit(I) : (length(s₁) == 1 ? s₁[1] : first((s₁...)))
348+
@boundscheck begin
349+
hassector(codomain(t), s₁) && hassector(domain(t), s₂) || throw(BoundsError(t, sectors))
350+
c2 = length(s₂) == 0 ? unit(I) : (length(s₂) == 1 ? s₂[1] : first((s₂...)))
351+
c2 == c1 || throw(SectorMismatch("Not a valid fusion channel for this tensor"))
352+
end
353+
f₁ = FusionTree(s₁, c1, map(isdual, tuple(codomain(t)...)))
354+
f₂ = FusionTree(s₂, c1, map(isdual, tuple(domain(t)...)))
355+
return @inbounds subblock(t, (f₁, f₂))
356+
end
357+
Base.@propagate_inbounds function subblock(t::AbstractTensorMap, sectors::Tuple)
358+
return subblock(t, map(Base.Fix1(convert, sectortype(t)), sectors))
359+
end
360+
# attempt to provide better error messages
361+
function subblock(t::AbstractTensorMap, (f₁, f₂)::Tuple{FusionTree, FusionTree})
362+
(sectortype(t)) == sectortype(f₁) == sectortype(f₂) ||
363+
throw(SectorMismatch("Not a valid sectortype for this tensor."))
364+
numout(t) == length(f₁) && numin(t) == length(f₂) ||
365+
throw(DimensionMismatch("Invalid number of fusiontree legs for this tensor."))
366+
throw(MethodError(subblock, (t, (f₁, f₂))))
367+
end
368+
369+
@doc """
370+
subblocktype(t)
371+
subblocktype(::Type{T})
372+
373+
Return the type of the tensor subblocks of a tensor.
374+
""" subblocktype
375+
376+
function subblocktype(::Type{T}) where {T <: AbstractTensorMap}
377+
return Core.Compiler.return_type(subblock, Tuple{T, fusiontreetype(T)})
378+
end
379+
subblocktype(t) = subblocktype(typeof(t))
380+
subblocktype(T::Type) = throw(MethodError(subblocktype, (T,)))
381+
382+
# Indexing behavior
383+
# -----------------
384+
# by default getindex returns views!
385+
@doc """
386+
Base.getindex(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}})
387+
t[sectors]
388+
Base.getindex(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree)
389+
t[f₁, f₂]
390+
391+
$_doc_subblock
392+
393+
!!! warning
394+
Contrary to Julia's array types, the default behavior is to return a view into the tensor data.
395+
As a result, modifying the view will modify the data in the tensor.
396+
397+
See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref).
398+
""" Base.getindex(::AbstractTensorMap, ::Tuple{I, Vararg{I}}) where {I <: Sector},
399+
Base.getindex(::AbstractTensorMap, ::FusionTree, ::FusionTree)
400+
401+
@inline Base.getindex(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} =
402+
subblock(t, sectors)
403+
@inline Base.getindex(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree) =
404+
subblock(t, (f₁, f₂))
405+
406+
@doc """
407+
Base.setindex!(t::AbstractTensorMap, v, sectors::Tuple{Vararg{Sector}})
408+
t[sectors] = v
409+
Base.setindex!(t::AbstractTensorMap, v, f₁::FusionTree, f₂::FusionTree)
410+
t[f₁, f₂] = v
411+
412+
Copies `v` into the data slice of `t` corresponding to the splitting - fusion tree pair `(f₁, f₂)`.
413+
By default, `v` can be any object that can be copied into the view associated with `t[f₁, f₂]`.
414+
415+
See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref).
416+
""" Base.setindex!(::AbstractTensorMap, ::Any, ::Tuple{I, Vararg{I}}) where {I <: Sector},
417+
Base.setindex!(::AbstractTensorMap, ::Any, ::FusionTree, ::FusionTree)
418+
419+
@inline Base.setindex!(t::AbstractTensorMap, v, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} =
420+
copy!(subblock(t, sectors), v)
421+
@inline Base.setindex!(t::AbstractTensorMap, v, f₁::FusionTree, f₂::FusionTree) =
422+
copy!(subblock(t, (f₁, f₂)), v)
423+
298424
# Derived indexing behavior for tensors with trivial symmetry
299425
#-------------------------------------------------------------
300426
using TensorKit.Strided: SliceIndex
@@ -499,3 +625,38 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
499625
return A
500626
end
501627
end
628+
629+
# Show and friends
630+
# ----------------
631+
632+
function Base.dims2string(V::HomSpace)
633+
str_cod = numout(V) == 0 ? "()" : join(dim.(codomain(V)), '×')
634+
str_dom = numin(V) == 0 ? "()" : join(dim.(domain(V)), '×')
635+
return str_cod * "" * str_dom
636+
end
637+
638+
function Base.summary(io::IO, t::AbstractTensorMap)
639+
V = space(t)
640+
print(io, Base.dims2string(V), " ")
641+
Base.showarg(io, t, true)
642+
return nothing
643+
end
644+
645+
# Human-readable:
646+
function Base.show(io::IO, ::MIME"text/plain", t::AbstractTensorMap)
647+
# 1) show summary: typically d₁×d₂×… ← d₃×d₄×… $(typeof(t)):
648+
summary(io, t)
649+
println(io, ":")
650+
651+
# 2) show spaces
652+
# println(io, " space(t):")
653+
println(io, " codomain: ", codomain(t))
654+
println(io, " domain: ", domain(t))
655+
656+
# 3) [optional]: show data
657+
get(io, :compact, true) && return nothing
658+
ioc = IOContext(io, :typeinfo => sectortype(t))
659+
println(io, "\n\n blocks: ")
660+
show_blocks(io, MIME"text/plain"(), blocks(t))
661+
return nothing
662+
end

src/tensors/adjoint.jl

Lines changed: 7 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -42,45 +42,17 @@ function Base.getindex(iter::BlockIterator{<:AdjointTensorMap}, c::Sector)
4242
return adjoint(Base.getindex(iter.structure, c))
4343
end
4444

45-
function Base.getindex(
46-
t::AdjointTensorMap{T, S, N₁, N₂}, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂}
47-
) where {T, S, N₁, N₂, I}
45+
Base.@propagate_inbounds function subblock(t::AdjointTensorMap, (f₁, f₂)::Tuple{FusionTree, FusionTree})
4846
tp = parent(t)
49-
subblock = getindex(tp, f₂, f₁)
50-
return permutedims(conj(subblock), (domainind(tp)..., codomainind(tp)...))
51-
end
52-
function Base.setindex!(
53-
t::AdjointTensorMap{T, S, N₁, N₂}, v, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂}
54-
) where {T, S, N₁, N₂, I}
55-
return copy!(getindex(t, f₁, f₂), v)
47+
data = subblock(tp, (f₂, f₁))
48+
return permutedims(conj(data), (domainind(tp)..., codomainind(tp)...))
5649
end
5750

5851
# Show
5952
#------
60-
function Base.summary(io::IO, t::AdjointTensorMap)
61-
return print(io, "AdjointTensorMap(", codomain(t), "", domain(t), ")")
62-
end
63-
function Base.show(io::IO, t::AdjointTensorMap)
64-
if get(io, :compact, false)
65-
print(io, "AdjointTensorMap(", codomain(t), "", domain(t), ")")
66-
return
67-
end
68-
println(io, "AdjointTensorMap(", codomain(t), "", domain(t), "):")
69-
if sectortype(t) === Trivial
70-
Base.print_array(io, t[])
71-
println(io)
72-
elseif FusionStyle(sectortype(t)) isa UniqueFusion
73-
for (f₁, f₂) in fusiontrees(t)
74-
println(io, "* Data for sector ", f₁.uncoupled, "", f₂.uncoupled, ":")
75-
Base.print_array(io, t[f₁, f₂])
76-
println(io)
77-
end
78-
else
79-
for (f₁, f₂) in fusiontrees(t)
80-
println(io, "* Data for fusiontree ", f₁, "", f₂, ":")
81-
Base.print_array(io, t[f₁, f₂])
82-
println(io)
83-
end
84-
end
53+
function Base.showarg(io::IO, t::AdjointTensorMap, toplevel::Bool)
54+
print(io, "adjoint(")
55+
Base.showarg(io, parent(t), false)
56+
print(io, ")")
8557
return nothing
8658
end

0 commit comments

Comments
 (0)