diff --git a/Project.toml b/Project.toml index 0b305d80..4e4b03fc 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" @@ -37,6 +38,7 @@ LinearAlgebra = "1" MatrixAlgebraKit = "0.5.0" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" +Printf = "1" Random = "1" SafeTestsets = "0.1" ScopedValues = "1.3.0" diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index 4d1c5c9c..54a67258 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -118,15 +118,15 @@ blocks To access the data associated with a specific fusion tree pair, you can use: ```@docs -Base.getindex(::TensorMap{T,S,N₁,N₂}, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector} -Base.setindex!(::TensorMap{T,S,N₁,N₂}, ::Any, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector} +Base.getindex(::AbstractTensorMap, ::FusionTree, ::FusionTree) +Base.setindex!(::AbstractTensorMap, ::Any, ::FusionTree, ::FusionTree) ``` For a tensor `t` with `FusionType(sectortype(t)) isa UniqueFusion`, fusion trees are completely determined by the outcoming sectors, and the data can be accessed in a more straightforward way: ```@docs -Base.getindex(::TensorMap, ::Tuple{I,Vararg{I}}) where {I<:Sector} +Base.getindex(::AbstractTensorMap, ::Tuple{I,Vararg{I}}) where {I<:Sector} ``` For tensor `t` with `sectortype(t) == Trivial`, the data can be accessed and manipulated diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 6b59dac9..74d23c8c 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -57,7 +57,7 @@ export ℤ₂Space, ℤ₃Space, ℤ₄Space, U₁Space, CU₁Space, SU₂Space # Export tensor map methods export domain, codomain, numind, numout, numin, domainind, codomainind, allind export spacetype, storagetype, scalartype, tensormaptype -export blocksectors, blockdim, block, blocks +export blocksectors, blockdim, block, blocks, subblocks, subblock # random methods for constructor export randisometry, randisometry!, rand, rand!, randn, randn! @@ -127,6 +127,7 @@ using Base: @boundscheck, @propagate_inbounds, @constprop, tuple_type_head, tuple_type_tail, tuple_type_cons, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype using Base.Iterators: product, filter +using Printf: @sprintf using LinearAlgebra: LinearAlgebra, BlasFloat using LinearAlgebra: norm, dot, normalize, normalize!, tr, diff --git a/src/spaces/gradedspace.jl b/src/spaces/gradedspace.jl index 687c79ac..921bd0c6 100644 --- a/src/spaces/gradedspace.jl +++ b/src/spaces/gradedspace.jl @@ -197,21 +197,75 @@ function supremum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I <: Sector ) end -function Base.show(io::IO, V::GradedSpace{I}) where {I <: Sector} - print(io, type_repr(typeof(V)), "(") - separator = "" - comma = ", " - io2 = IOContext(io, :typeinfo => I) - for c in sectors(V) - if isdual(V) - print(io2, separator, dual(c), "=>", dim(V, c)) - else - print(io2, separator, c, "=>", dim(V, c)) - end - separator = comma +Base.summary(io::IO, V::GradedSpace) = print(io, type_repr(typeof(V))) + +function Base.show(io::IO, V::GradedSpace) + pre = (get(io, :typeinfo, Any)::DataType == typeof(V) ? "" : type_repr(typeof(V))) + + io = IOContext(io, :typeinfo => Pair{sectortype(V), Int}) + + pre *= "(" + if isdual(V) + post = ")'" + V = dual(V) + else + post = ")" + end + hdots = " \u2026 " + sep = ", " + sepsize = length(sep) + + limited = get(io, :limit, false)::Bool + screenwidth = limited ? displaysize(io)[2] : typemax(Int) + screenwidth -= length(pre) + length(post) + + cs = reshape(collect([c => dim(V, c) for c in sectors(V)]), 1, :) + ncols = length(cs) + + maxpossiblecols = screenwidth ÷ (1 + sepsize) + if ncols > maxpossiblecols + cols = [1:(maxpossiblecols - 1); (ncols - maxpossiblecols + 1):ncols] + else + cols = collect(1:ncols) + end + + A = Base.alignment(io, cs, [1], cols, screenwidth, screenwidth, sepsize, ncols) + + if ncols <= length(A) # fits on screen + print(io, pre) + Base.print_matrix_row(io, cs, A, 1, cols, sep, ncols) + print(io, post) + else # doesn't fit on screen + half = (screenwidth - length(hdots) + 1) ÷ 2 + 1 + Ralign = reverse(Base.alignment(io, cs, [1], reverse(cols), half, half, sepsize, ncols)) + half = screenwidth - sum(map(sum, Ralign)) - (length(Ralign) - 1) * sepsize - length(hdots) + Lalign = Base.alignment(io, cs, [1], cols, half, half, sepsize, ncols) + print(io, pre) + Base.print_matrix_row(io, cs, Lalign, 1, cols[1:length(Lalign)], sep, ncols) + print(io, hdots) + Base.print_matrix_row(io, cs, Ralign, 1, (length(cs) - length(Ralign)) .+ cols, sep, length(cs)) + print(io, post) end - print(io, ")") - V.dual && print(io, "'") + + return nothing +end + +function Base.show(io::IO, ::MIME"text/plain", V::GradedSpace) + # print small summary, e.g. + # d-element Vect[I] or d-element dual(Vect[I]) + d = reduceddim(V) + print(io, d, "-element ") + isdual(V) && print(io, "dual(") + print(io, type_repr(typeof(V))) + isdual(V) && print(io, ")") + + compact = get(io, :compact, false)::Bool + (iszero(d) || compact) && return nothing + + # print detailed sector information + print(io, ":\n ") + ioc = IOContext(io, :typeinfo => typeof(V)) + show(ioc, V) return nothing end diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 9d6bae66..a06fb36a 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -250,6 +250,12 @@ Return an iterator over all splitting - fusion tree pairs of a tensor. """ fusiontrees(t::AbstractTensorMap) = fusionblockstructure(t).fusiontreelist +fusiontreetype(t::AbstractTensorMap) = fusiontreetype(typeof(t)) +function fusiontreetype(::Type{T}) where {T <: AbstractTensorMap} + I = sectortype(T) + return Tuple{fusiontreetype(I, numout(T)), fusiontreetype(I, numin(T))} +end + # auxiliary function @inline function trivial_fusiontree(t::AbstractTensorMap) sectortype(t) === Trivial || @@ -295,6 +301,145 @@ function blocktype(::Type{T}) where {T <: AbstractTensorMap} return Core.Compiler.return_type(block, Tuple{T, sectortype(T)}) end +# tensor data: subblock access +# ---------------------------- +@doc """ + subblocks(t::AbstractTensorMap) + +Return an iterator over all subblocks of a tensor, i.e. all fusiontrees and their +corresponding tensor subblocks. + +See also [`subblock`](@ref), [`fusiontrees`](@ref), and [`hassubblock`](@ref). +""" +subblocks(t::AbstractTensorMap) = SubblockIterator(t, fusiontrees(t)) + +const _doc_subblock = """ +Return a view into the data of `t` corresponding to the splitting - fusion tree pair +`(f₁, f₂)`. In particular, this is an `AbstractArray{T}` with `T = scalartype(t)`, of size +`(dims(codomain(t), f₁.uncoupled)..., dims(codomain(t), f₂.uncoupled)...)`. + +Whenever `FusionStyle(sectortype(t)) isa UniqueFusion` , it is also possible to provide only +the external `sectors`, in which case the fusion tree pair will be constructed automatically. +""" + +@doc """ + subblock(t::AbstractTensorMap, (f₁, f₂)::Tuple{FusionTree,FusionTree}) + subblock(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}}) + +$_doc_subblock + +In general, new tensor types should provide an implementation of this function for the +fusion tree signature. + +See also [`subblocks`](@ref) and [`fusiontrees`](@ref). +""" subblock + +Base.@propagate_inbounds function subblock(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} + # input checking + I === sectortype(t) || throw(SectorMismatch("Not a valid sectortype for this tensor.")) + FusionStyle(I) isa UniqueFusion || + throw(SectorMismatch("Indexing with sectors is only possible for unique fusion styles.")) + length(sectors) == numind(t) || throw(ArgumentError("invalid number of sectors")) + + # convert to fusiontrees + s₁ = TupleTools.getindices(sectors, codomainind(t)) + s₂ = map(dual, TupleTools.getindices(sectors, domainind(t))) + c1 = length(s₁) == 0 ? unit(I) : (length(s₁) == 1 ? s₁[1] : first(⊗(s₁...))) + @boundscheck begin + hassector(codomain(t), s₁) && hassector(domain(t), s₂) || throw(BoundsError(t, sectors)) + c2 = length(s₂) == 0 ? unit(I) : (length(s₂) == 1 ? s₂[1] : first(⊗(s₂...))) + c2 == c1 || throw(SectorMismatch("Not a valid fusion channel for this tensor")) + end + f₁ = FusionTree(s₁, c1, map(isdual, tuple(codomain(t)...))) + f₂ = FusionTree(s₂, c1, map(isdual, tuple(domain(t)...))) + return @inbounds subblock(t, (f₁, f₂)) +end +Base.@propagate_inbounds function subblock(t::AbstractTensorMap, sectors::Tuple) + return subblock(t, map(Base.Fix1(convert, sectortype(t)), sectors)) +end +# attempt to provide better error messages +function subblock(t::AbstractTensorMap, (f₁, f₂)::Tuple{FusionTree, FusionTree}) + (sectortype(t)) == sectortype(f₁) == sectortype(f₂) || + throw(SectorMismatch("Not a valid sectortype for this tensor.")) + numout(t) == length(f₁) && numin(t) == length(f₂) || + throw(DimensionMismatch("Invalid number of fusiontree legs for this tensor.")) + throw(MethodError(subblock, (t, (f₁, f₂)))) +end + +@doc """ + subblocktype(t) + subblocktype(::Type{T}) + +Return the type of the tensor subblocks of a tensor. +""" subblocktype + +function subblocktype(::Type{T}) where {T <: AbstractTensorMap} + return Core.Compiler.return_type(subblock, Tuple{T, fusiontreetype(T)}) +end +subblocktype(t) = subblocktype(typeof(t)) +subblocktype(T::Type) = throw(MethodError(subblocktype, (T,))) + +# Indexing behavior +# ----------------- +@doc """ + Base.view(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}}) + Base.view(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree) + +$_doc_subblock + +!!! note + Contrary to Julia's array types, the default indexing behavior is to return a view + into the tensor data. As a result, `getindex` and `view` have the same behavior. + +See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref). +""" Base.view(::AbstractTensorMap, ::Tuple{I, Vararg{I}}) where {I <: Sector}, + Base.view(::AbstractTensorMap, ::FusionTree, ::FusionTree) + +@inline Base.view(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} = + subblock(t, sectors) +@inline Base.view(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree) = + subblock(t, (f₁, f₂)) + +# by default getindex returns views +@doc """ + Base.getindex(t::AbstractTensorMap, sectors::Tuple{Vararg{Sector}}) + t[sectors] + Base.getindex(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree) + t[f₁, f₂] + +$_doc_subblock + +!!! warning + Contrary to Julia's array types, the default behavior is to return a view into the tensor data. + As a result, modifying the view will modify the data in the tensor. + +See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref). +""" Base.getindex(::AbstractTensorMap, ::Tuple{I, Vararg{I}}) where {I <: Sector}, + Base.getindex(::AbstractTensorMap, ::FusionTree, ::FusionTree) + +@inline Base.getindex(t::AbstractTensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} = + view(t, sectors) +@inline Base.getindex(t::AbstractTensorMap, f₁::FusionTree, f₂::FusionTree) = + view(t, f₁, f₂) + +@doc """ + Base.setindex!(t::AbstractTensorMap, v, sectors::Tuple{Vararg{Sector}}) + t[sectors] = v + Base.setindex!(t::AbstractTensorMap, v, f₁::FusionTree, f₂::FusionTree) + t[f₁, f₂] = v + +Copies `v` into the data slice of `t` corresponding to the splitting - fusion tree pair `(f₁, f₂)`. +By default, `v` can be any object that can be copied into the view associated with `t[f₁, f₂]`. + +See also [`subblock`](@ref), [`subblocks`](@ref) and [`fusiontrees`](@ref). +""" Base.setindex!(::AbstractTensorMap, ::Any, ::Tuple{I, Vararg{I}}) where {I <: Sector}, + Base.setindex!(::AbstractTensorMap, ::Any, ::FusionTree, ::FusionTree) + +@inline Base.setindex!(t::AbstractTensorMap, v, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} = + copy!(view(t, sectors), v) +@inline Base.setindex!(t::AbstractTensorMap, v, f₁::FusionTree, f₂::FusionTree) = + copy!(view(t, (f₁, f₂)), v) + # Derived indexing behavior for tensors with trivial symmetry #------------------------------------------------------------- using TensorKit.Strided: SliceIndex @@ -480,6 +625,10 @@ end # Conversion to Array: #---------------------- +Base.ndims(t::AbstractTensorMap) = numind(t) +Base.size(t::AbstractTensorMap) = ntuple(Base.Fix1(size, t), numind(t)) +Base.size(t::AbstractTensorMap, i) = dim(space(t, i)) + # probably not optimized for speed, only for checking purposes function Base.convert(::Type{Array}, t::AbstractTensorMap) I = sectortype(t) @@ -499,3 +648,38 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap) return A end end + +# Show and friends +# ---------------- + +function Base.dims2string(V::HomSpace) + str_cod = numout(V) == 0 ? "()" : join(dim.(codomain(V)), '×') + str_dom = numin(V) == 0 ? "()" : join(dim.(domain(V)), '×') + return str_cod * "←" * str_dom +end + +function Base.summary(io::IO, t::AbstractTensorMap) + V = space(t) + print(io, Base.dims2string(V), " ") + Base.showarg(io, t, true) + return nothing +end + +# Human-readable: +function Base.show(io::IO, ::MIME"text/plain", t::AbstractTensorMap) + # 1) show summary: typically d₁×d₂×… ← d₃×d₄×… $(typeof(t)): + summary(io, t) + println(io, ":") + + # 2) show spaces + # println(io, " space(t):") + println(io, " codomain: ", codomain(t)) + println(io, " domain: ", domain(t)) + + # 3) [optional]: show data + get(io, :compact, true) && return nothing + ioc = IOContext(io, :typeinfo => sectortype(t)) + println(io, "\n\n blocks(t):") + show_blocks(io, MIME"text/plain"(), blocks(t)) + return nothing +end diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index a016e0b4..2a229f2c 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -42,45 +42,17 @@ function Base.getindex(iter::BlockIterator{<:AdjointTensorMap}, c::Sector) return adjoint(Base.getindex(iter.structure, c)) end -function Base.getindex( - t::AdjointTensorMap{T, S, N₁, N₂}, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂} - ) where {T, S, N₁, N₂, I} +Base.@propagate_inbounds function subblock(t::AdjointTensorMap, (f₁, f₂)::Tuple{FusionTree, FusionTree}) tp = parent(t) - subblock = getindex(tp, f₂, f₁) - return permutedims(conj(subblock), (domainind(tp)..., codomainind(tp)...)) -end -function Base.setindex!( - t::AdjointTensorMap{T, S, N₁, N₂}, v, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂} - ) where {T, S, N₁, N₂, I} - return copy!(getindex(t, f₁, f₂), v) + data = subblock(tp, (f₂, f₁)) + return permutedims(conj(data), (domainind(tp)..., codomainind(tp)...)) end # Show #------ -function Base.summary(io::IO, t::AdjointTensorMap) - return print(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), ")") -end -function Base.show(io::IO, t::AdjointTensorMap) - if get(io, :compact, false) - print(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), ")") - return - end - println(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), "):") - if sectortype(t) === Trivial - Base.print_array(io, t[]) - println(io) - elseif FusionStyle(sectortype(t)) isa UniqueFusion - for (f₁, f₂) in fusiontrees(t) - println(io, "* Data for sector ", f₁.uncoupled, " ← ", f₂.uncoupled, ":") - Base.print_array(io, t[f₁, f₂]) - println(io) - end - else - for (f₁, f₂) in fusiontrees(t) - println(io, "* Data for fusiontree ", f₁, " ← ", f₂, ":") - Base.print_array(io, t[f₁, f₂]) - println(io) - end - end +function Base.showarg(io::IO, t::AdjointTensorMap, toplevel::Bool) + print(io, "adjoint(") + Base.showarg(io, parent(t), false) + print(io, ")") return nothing end diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl index 409e9e63..e5f5e562 100644 --- a/src/tensors/blockiterator.jl +++ b/src/tensors/blockiterator.jl @@ -44,3 +44,101 @@ function foreachblock(f, t::AbstractTensorMap; scheduler = nothing) end return nothing end + +function show_blocks(io, mime::MIME"text/plain", iter) + first = true + for (c, b) in iter + first || print(io, "\n\n") + print(io, " * ", c, " => ") + show(io, mime, b) + first = false + end + return nothing +end + +function show_blocks(io, iter) + print(io, "(") + join(io, iter, ", ") + print(io, ")") + return nothing +end + +function Base.summary(io::IO, b::BlockIterator) + print(io, "blocks(") + Base.showarg(io, b.t, false) + print(io, ")") + return nothing +end + +function Base.show(io::IO, mime::MIME"text/plain", b::BlockIterator) + summary(io, b) + println(io, ":") + show_blocks(io, mime, b) + return nothing +end + +""" + struct SubblockIterator{T <: AbstractTensorMap, S} + +Iterator over the subblocks of a tensor of type `T`, possibly holding some pre-computed data of type `S`. +This is typically constructed through of [`subblocks`](@ref). +""" +struct SubblockIterator{T <: AbstractTensorMap, S} + t::T + structure::S +end + +Base.IteratorSize(::SubblockIterator) = Base.HasLength() +Base.IteratorEltype(::SubblockIterator) = Base.HasEltype() +Base.eltype(::Type{<:SubblockIterator{T}}) where {T} = Pair{fusiontreetype(T), subblocktype(T)} +Base.length(iter::SubblockIterator) = length(iter.structure) +Base.isdone(iter::SubblockIterator, state...) = Base.isdone(iter.structure, state...) + +# default implementation assumes `structure = fusiontrees(t)` +function Base.iterate(iter::SubblockIterator, state...) + next = Base.iterate(iter.structure, state...) + isnothing(next) && return nothing + (f₁, f₂), state = next + @inbounds data = subblock(iter.t, (f₁, f₂)) + return (f₁, f₂) => data, state +end + + +function Base.showarg(io::IO, iter::SubblockIterator, toplevel::Bool) + print(io, "subblocks(") + Base.showarg(io, iter.t, false) + print(io, ")") + return nothing +end +function Base.summary(io::IO, iter::SubblockIterator) + Base.showarg(io, iter, true) + return nothing +end + +function show_subblocks(io::IO, mime::MIME"text/plain", iter::SubblockIterator) + if FusionStyle(sectortype(iter.t)) isa UniqueFusion + first = true + for ((f₁, f₂), b) in iter + first || print(io, "\n\n") + print(io, " * ", f₁.uncoupled, " ← ", f₂.uncoupled, " => ") + show(io, mime, b) + first = false + end + else + first = true + for ((f₁, f₂), b) in iter + first || print(io, "\n\n") + print(io, " * ", (f₁, f₂), " => ") + show(io, mime, b) + first = false + end + end + return nothing +end + +function Base.show(io::IO, mime::MIME"text/plain", iter::SubblockIterator) + summary(io, iter) + println(io, ":") + show_subblocks(io, mime, iter) + return nothing +end diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index df04cb9b..904e40cb 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -166,19 +166,12 @@ end # Indexing and getting and setting the data at the subblock level #----------------------------------------------------------------- -@inline function Base.getindex( - d::DiagonalTensorMap, f₁::FusionTree{I, 1}, f₂::FusionTree{I, 1} +Base.@propagate_inbounds function subblock( + d::DiagonalTensorMap, (f₁, f₂)::Tuple{FusionTree{I, 1}, FusionTree{I, 1}} ) where {I <: Sector} s = f₁.uncoupled[1] s == f₁.coupled == f₂.uncoupled[1] == f₂.coupled || throw(SectorMismatch()) return block(d, s) - # TODO: do we want a StridedView here? Then we need to allocate a new matrix. -end - -function Base.setindex!( - d::DiagonalTensorMap, v, f₁::FusionTree{I, 1}, f₂::FusionTree{I, 1} - ) where {I <: Sector} - return copy!(getindex(d, f₁, f₂), v) end function Base.getindex(d::DiagonalTensorMap) @@ -335,23 +328,13 @@ end # Show #------ -function Base.summary(io::IO, t::DiagonalTensorMap) - return print(io, "DiagonalTensorMap(", space(t), ")") +function type_repr(::Type{DiagonalTensorMap{T, S, A}}) where {T, S, A} + return "DiagonalTensorMap{$T, $(type_repr(S)), $A}" end -function Base.show(io::IO, t::DiagonalTensorMap) - summary(io, t) - get(io, :compact, false) && return nothing - println(io, ":") - - if sectortype(t) == Trivial - Base.print_array(io, Diagonal(t.data)) - println(io) - else - for (c, b) in blocks(t) - println(io, "* Data for sector ", c, ":") - Base.print_array(io, b) - println(io) - end - end +function Base.showarg(io::IO, t::DiagonalTensorMap, toplevel::Bool) + !toplevel && print(io, "::") + print(io, type_repr(typeof(t))) return nothing end +Base.show(io::IO, t::DiagonalTensorMap) = + print(io, type_repr(typeof(t)), "(", t.data, ", ", space(t, 1), ")") diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 3f28f168..f8f04e99 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -464,26 +464,10 @@ function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector) return reshape(view(iter.t.data, r), (d₁, d₂)) end -# Indexing and getting and setting the data at the subblock level -#----------------------------------------------------------------- -""" - Base.getindex(t::TensorMap{T,S,N₁,N₂,I}, - f₁::FusionTree{I,N₁}, - f₂::FusionTree{I,N₂}) where {T,SN₁,N₂,I<:Sector} - -> StridedViews.StridedView - t[f₁, f₂] - -Return a view into the data slice of `t` corresponding to the splitting - fusion tree pair -`(f₁, f₂)`. In particular, if `f₁.coupled == f₂.coupled == c`, then a -`StridedViews.StridedView` of size -`(dims(codomain(t), f₁.uncoupled)..., dims(domain(t), f₂.uncoupled))` is returned which -represents the slice of `block(t, c)` whose row indices correspond to `f₁.uncoupled` and -column indices correspond to `f₂.uncoupled`. - -See also [`Base.setindex!(::TensorMap{T,S,N₁,N₂}, ::Any, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}`](@ref) -""" -@inline function Base.getindex( - t::TensorMap{T, S, N₁, N₂}, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂} +# Getting and setting the data at the subblock level +# -------------------------------------------------- +function subblock( + t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}} ) where {T, S, N₁, N₂, I <: Sector} structure = fusionblockstructure(t) @boundscheck begin @@ -495,9 +479,10 @@ See also [`Base.setindex!(::TensorMap{T,S,N₁,N₂}, ::Any, ::FusionTree{I,N₁ return StridedView(t.data, sz, str, offset) end end + # The following is probably worth special casing for trivial tensors -@inline function Base.getindex( - t::TensorMap{T, S, N₁, N₂}, f₁::FusionTree{Trivial, N₁}, f₂::FusionTree{Trivial, N₂} +@inline function subblock( + t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{Trivial, N₁}, FusionTree{Trivial, N₂}} ) where {T, S, N₁, N₂} @boundscheck begin sectortype(t) == Trivial || throw(SectorMismatch()) @@ -505,95 +490,21 @@ end return sreshape(StridedView(t.data), (dims(codomain(t))..., dims(domain(t))...)) end -""" - Base.setindex!(t::TensorMap{T,S,N₁,N₂,I}, - v, - f₁::FusionTree{I,N₁}, - f₂::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector} - t[f₁, f₂] = v - -Copies `v` into the data slice of `t` corresponding to the splitting - fusion tree pair -`(f₁, f₂)`. Here, `v` can be any object that can be copied into a `StridedViews.StridedView` -of size `(dims(codomain(t), f₁.uncoupled)..., dims(domain(t), f₂.uncoupled))` using -`Base.copy!`. - -See also [`Base.getindex(::TensorMap{T,S,N₁,N₂}, ::FusionTree{I,N₁}, ::FusionTree{I,N₂}) where {T,S,N₁,N₂,I<:Sector}`](@ref) -""" -@propagate_inbounds function Base.setindex!( - t::TensorMap{T, S, N₁, N₂}, v, f₁::FusionTree{I, N₁}, f₂::FusionTree{I, N₂} - ) where {T, S, N₁, N₂, I <: Sector} - return copy!(getindex(t, f₁, f₂), v) -end - -""" - Base.getindex(t::TensorMap - sectors::NTuple{N₁+N₂,I}) where {N₁,N₂,I<:Sector} - -> StridedViews.StridedView - t[sectors] - -Return a view into the data slice of `t` corresponding to the splitting - fusion tree pair -with combined uncoupled charges `sectors`. In particular, if `sectors == (s₁..., s₂...)` -where `s₁` and `s₂` correspond to the uncoupled charges in the codomain and domain -respectively, then a `StridedViews.StridedView` of size -`(dims(codomain(t), s₁)..., dims(domain(t), s₂))` is returned. - -This method is only available for the case where `FusionStyle(I) isa UniqueFusion`, -since it assumes a uniquely defined coupled charge. -""" -@inline function Base.getindex(t::TensorMap, sectors::Tuple{I, Vararg{I}}) where {I <: Sector} - I === sectortype(t) || throw(SectorMismatch("Not a valid sectortype for this tensor.")) - FusionStyle(I) isa UniqueFusion || - throw(SectorMismatch("Indexing with sectors only possible if unique fusion")) - length(sectors) == numind(t) || - throw(ArgumentError("Number of sectors does not match.")) - s₁ = TupleTools.getindices(sectors, codomainind(t)) - s₂ = map(dual, TupleTools.getindices(sectors, domainind(t))) - c1 = length(s₁) == 0 ? unit(I) : (length(s₁) == 1 ? s₁[1] : first(⊗(s₁...))) - @boundscheck begin - c2 = length(s₂) == 0 ? unit(I) : (length(s₂) == 1 ? s₂[1] : first(⊗(s₂...))) - c2 == c1 || throw(SectorMismatch("Not a valid sector for this tensor")) - hassector(codomain(t), s₁) && hassector(domain(t), s₂) - end - f₁ = FusionTree(s₁, c1, map(isdual, tuple(codomain(t)...))) - f₂ = FusionTree(s₂, c1, map(isdual, tuple(domain(t)...))) - @inbounds begin - return t[f₁, f₂] - end -end -@propagate_inbounds function Base.getindex(t::TensorMap, sectors::Tuple) - return t[map(sectortype(t), sectors)] -end - # Show #------ -function Base.summary(io::IO, t::TensorMap) - return print(io, "TensorMap(", space(t), ")") +function type_repr(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A} + return "TensorMap{$T, $(type_repr(S)), $N₁, $N₂, $A}" end -function Base.show(io::IO, t::TensorMap) - if get(io, :compact, false) - print(io, "TensorMap(", space(t), ")") - return - end - println(io, "TensorMap(", space(t), "):") - if sectortype(t) == Trivial - Base.print_array(io, t[]) - println(io) - elseif FusionStyle(sectortype(t)) isa UniqueFusion - for (f₁, f₂) in fusiontrees(t) - println(io, "* Data for sector ", f₁.uncoupled, " ← ", f₂.uncoupled, ":") - Base.print_array(io, t[f₁, f₂]) - println(io) - end - else - for (f₁, f₂) in fusiontrees(t) - println(io, "* Data for fusiontree ", f₁, " ← ", f₂, ":") - Base.print_array(io, t[f₁, f₂]) - println(io) - end - end + +function Base.showarg(io::IO, t::TensorMap, toplevel::Bool) + !toplevel && print(io, "::") + print(io, type_repr(typeof(t))) return nothing end +Base.show(io::IO, t::TensorMap) = + print(io, type_repr(typeof(t)), "(", t.data, ", ", space(t), ")") + # Complex, real and imaginary parts #----------------------------------- for f in (:real, :imag, :complex)