Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions docs/src/sumspaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,21 @@ These spaces are a natural extension of the `TensorKit` vector spaces, and you c
### `SumSpace`

In `BlockTensorKit`, we provide a type `SumSpace` that allows you to define such direct sums.
They can be defined either directly via the constructor, or by using the `⊕` operator.
They can be defined either directly via the constructor, or by using the `⊞` (`\boxplus<TAB>`) operator.
In order for the direct sum to be wll-defined, all components must have the same value of `isdual`.

Essentially, that is all there is to it, and you can now use these `SumSpace` objects much in the same way as you would use an `IndexSpace` object in `TensorKit`.
In particular, it adheres to the interface of `ElementarySpace`, which means that you can query the properties as you would expect.

!!! note

The operator `⊕` is used in both TensorKit and BlockTensorKit, and therefore it must be explicitly imported to avoid name clashes.
Both functions achieve almost the same thing, as `BlockTensorKit.` can be thought of as a _lazy_ version of `TensorKit.⊕`.
The notion of a direct sum of vector spaces is used in both TensorKit (`⊕` or `oplus`) and BlockTensorKit (`⊞` or `boxplus`).
Both functions achieve almost the same thing, and `BlockTensorKit.` can be thought of as a _lazy_ version of `TensorKit.⊕`.

```@repl sumspaces
using TensorKit, BlockTensorKit
using BlockTensorKit: ⊕
V = ℂ^1 ⊕ ℂ^2 ⊕ ℂ^3
ℂ^2 ⊕ (ℂ^2)' ⊕ ℂ^2 # error
V = ℂ^1 ⊞ ℂ^2 ⊞ ℂ^3
ℂ^2 ⊞ (ℂ^2)' ⊞ ℂ^2 # error
dim(V)
isdual(V)
isdual(V')
Expand All @@ -43,7 +42,7 @@ Because these objects are naturally `ElementarySpace` objects, they can be used
Additionally, when mixing spaces and their sumspaces, all components are promoted to `SumSpace` instances.

```@repl sumspaces
V1 = ℂ^1 ℂ^2 ℂ^3
V1 = ℂ^1 ℂ^2 ℂ^3
V2 = ℂ^2
V1 ⊗ V2 ⊗ V1' == V1 * V2 * V1' == ProductSpace(V1,V2,V1') == ProductSpace(V1,V2) ⊗ V1'
V1^3
Expand Down
2 changes: 1 addition & 1 deletion src/BlockTensorKit.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module BlockTensorKit

export SumSpace, ProductSumSpace
export SumSpace, ProductSumSpace, ⊞, boxplus
export eachspace, SumSpaceIndices, sumspacetype

export AbstractBlockTensorMap, BlockTensorMap, SparseBlockTensorMap
Expand Down
30 changes: 15 additions & 15 deletions src/linalg/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function MAK.check_input(::typeof(qr_full!), t::AbstractBlockTensorMap, QR, ::Ab
@check_scalar R t

# space checks
V_Q = TK.oplus(fuse(codomain(t)))
V_Q = (fuse(codomain(t)))
@check_space(Q, codomain(t) ← V_Q)
@check_space(R, V_Q ← domain(t))

Expand All @@ -81,7 +81,7 @@ end
MAK.check_input(::typeof(qr_full!), t::AbstractBlockTensorMap, QR, ::DiagonalAlgorithm) = error()

function MAK.initialize_output(::typeof(qr_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
V_Q = TK.oplus(fuse(codomain(t)))
V_Q = (fuse(codomain(t)))
Q = dense_similar(t, codomain(t) ← V_Q)
R = dense_similar(t, V_Q ← domain(t))
return Q, R
Expand Down Expand Up @@ -111,7 +111,7 @@ function MAK.check_input(::typeof(lq_full!), t::AbstractBlockTensorMap, LQ, ::Ab
@check_scalar Q t

# space checks
V_Q = TK.oplus(fuse(domain(t)))
V_Q = (fuse(domain(t)))
@check_space(L, codomain(t) ← V_Q)
@check_space(Q, V_Q ← domain(t))

Expand All @@ -120,7 +120,7 @@ end
MAK.check_input(::typeof(lq_full!), t::AbstractBlockTensorMap, LQ, ::DiagonalAlgorithm) = error()

function MAK.initialize_output(::typeof(lq_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
V_Q = TK.oplus(fuse(domain(t)))
V_Q = (fuse(domain(t)))
L = dense_similar(t, codomain(t) ← V_Q)
Q = dense_similar(t, V_Q ← domain(t))
return L, Q
Expand Down Expand Up @@ -151,7 +151,7 @@ function MAK.check_input(::typeof(MAK.left_orth_polar!), t::AbstractBlockTensorM
@check_scalar P t

# space checks
VW = TK.oplus(fuse(domain(t)))
VW = (fuse(domain(t)))
@check_space(W, codomain(t) ← VW)
@check_space(P, VW ← domain(t))

Expand All @@ -176,7 +176,7 @@ function MAK.check_input(::typeof(MAK.right_orth_polar!), t::AbstractBlockTensor
@check_scalar Wᴴ t

# space checks
VW = TK.oplus(fuse(codomain(t)))
VW = (fuse(codomain(t)))
@check_space(P, codomain(t) ← VW)
@check_space(Wᴴ, VW ← domain(t))

Expand Down Expand Up @@ -214,7 +214,7 @@ function MAK.check_input(::typeof(eigh_full!), t::AbstractBlockTensorMap, DV, ::
@check_scalar V t

# space checks
V_D = TK.oplus(fuse(domain(t)))
V_D = (fuse(domain(t)))
@check_space(D, V_D ← V_D)
@check_space(V, codomain(t) ← V_D)

Expand All @@ -225,14 +225,14 @@ MAK.check_input(::typeof(eigh_full!), t::AbstractBlockTensorMap, DV, ::DiagonalA
function MAK.check_input(::typeof(eigh_vals!), t::AbstractBlockTensorMap, D, ::AbstractAlgorithm)
@check_scalar D t real
@assert D isa DiagonalTensorMap
V_D = TK.oplus(fuse(domain(t)))
V_D = (fuse(domain(t)))
@check_space(D, V_D ← V_D)
return nothing
end
MAK.check_input(::typeof(eigh_vals!), t::AbstractBlockTensorMap, D, ::DiagonalAlgorithm) = error()

function MAK.initialize_output(::typeof(eigh_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
V_D = TK.oplus(fuse(domain(t)))
V_D = (fuse(domain(t)))
T = real(scalartype(t))
D = DiagonalTensorMap{T}(undef, V_D)
V = dense_similar(t, codomain(t) ← V_D)
Expand All @@ -255,7 +255,7 @@ function MAK.check_input(::typeof(eig_full!), t::AbstractBlockTensorMap, DV, ::A
@check_scalar V t complex

# space checks
V_D = TK.oplus(fuse(domain(t)))
V_D = (fuse(domain(t)))
@check_space(D, V_D ← V_D)
@check_space(V, codomain(t) ← V_D)

Expand All @@ -264,7 +264,7 @@ end
MAK.check_input(::typeof(eig_full!), t::AbstractBlockTensorMap, DV, ::DiagonalAlgorithm) = error()

function MAK.initialize_output(::typeof(eig_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
V_D = TK.oplus(fuse(domain(t)))
V_D = (fuse(domain(t)))
Tc = complex(scalartype(t))
D = DiagonalTensorMap{Tc}(undef, V_D)
V = dense_similar(t, Tc, codomain(t) ← V_D)
Expand All @@ -285,17 +285,17 @@ function MAK.check_input(::typeof(svd_full!), t::AbstractBlockTensorMap, USVᴴ,
@check_scalar Vᴴ t

# space checks
V_cod = TK.oplus(fuse(codomain(t)))
V_dom = TK.oplus(fuse(domain(t)))
V_cod = (fuse(codomain(t)))
V_dom = (fuse(domain(t)))
@check_space(U, codomain(t) ← V_cod)
@check_space(S, V_cod ← V_dom)
@check_space(Vᴴ, V_dom ← domain(t))

return nothing
end
function MAK.initialize_output(::typeof(svd_full!), t::AbstractBlockTensorMap, ::AbstractAlgorithm)
V_cod = TK.oplus(fuse(codomain(t)))
V_dom = TK.oplus(fuse(domain(t)))
V_cod = (fuse(codomain(t)))
V_dom = (fuse(domain(t)))
U = dense_similar(t, codomain(t) ← V_cod)
S = similar(t, real(scalartype(t)), V_cod ← V_dom)
Vᴴ = dense_similar(t, V_dom ← domain(t))
Expand Down
77 changes: 41 additions & 36 deletions src/vectorspaces/sumspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,35 @@ const ProductSumSpace{S, N} = ProductSpace{SumSpace{S}, N}
const TensorSumSpace{S} = TensorSpace{SumSpace{S}}
const TensorMapSumSpace{S, N₁, N₂} = TensorMapSpace{SumSpace{S}, N₁, N₂}

# unicode name
"""
V1 ⊞ V2...
boxplus(V1::ElementarySpace, V2::ElementarySpace...)

Create a lazy representation of the direct sum of the supplied vector spaces, which retains the order.
See also [`SumSpace`](@ref).
"""
function ⊞ end
const boxplus = ⊞

⊞(V₁::VectorSpace, V₂::VectorSpace) = ⊞(promote(V₁, V₂)...)
⊞(V::Vararg{VectorSpace}) = reduce(⊞, V)

⊞(V::ElementarySpace) = V isa SumSpace ? V : SumSpace(V)
function (V₁::S ⊞ V₂::S) where {S <: ElementarySpace}
return if isdual(V₁) == isdual(V₂)
SumSpace(V₁, V₂)
else
throw(SpaceMismatch("Direct sum of a vector space and its dual does not exist"))
end
end
function (V₁::SumSpace{S} ⊞ V₂::SumSpace{S}) where {S}
V = SumSpace(vcat(V₁.spaces, V₂.spaces))
allequal(isdual, V.spaces) ||
throw(SpaceMismatch("Direct sum of a vector space and its dual does not exist"))
return V
end

# AbstractArray behavior
# ----------------------
Base.size(S::SumSpace) = size(S.spaces)
Expand Down Expand Up @@ -128,36 +157,12 @@ function Base.:(==)(V::TensorMapSumSpace{S}, W::TensorMapSumSpace{S}) where {S <
end


TensorKit.infimum(V::S, W::S) where {S <: SumSpace} = infimum(TensorKit.oplus(V), TensorKit.oplus(W))
TensorKit.supremum(V::S, W::S) where {S <: SumSpace} = supremum(TensorKit.oplus(V), TensorKit.oplus(W))
TensorKit.ominus(V::S, W::S) where {S <: SumSpace} = ominus(TensorKit.oplus(V), TensorKit.oplus(W))
# this conflicts with the definition in TensorKit, so users always need to specify
# ⊕(Vs::IndexSpace...) = SumSpace(Vs...)

function ⊕ end
⊕(V₁::VectorSpace, V₂::VectorSpace) = ⊕(promote(V₁, V₂)...)
⊕(V::Vararg{VectorSpace}) = foldl(⊕, V)
const oplus = ⊕

⊕(V::ElementarySpace) = V isa SumSpace ? V : SumSpace(V)
function ⊕(V₁::S, V₂::S) where {S <: ElementarySpace}
return if isdual(V₁) == isdual(V₂)
SumSpace(V₁, V₂)
else
throw(SpaceMismatch("Direct sum of a vector space and its dual does not exist"))
end
end
function ⊕(V₁::SumSpace{S}, V₂::SumSpace{S}) where {S}
V = SumSpace(vcat(V₁.spaces, V₂.spaces))
allequal(isdual, V.spaces) ||
throw(SpaceMismatch("Direct sum of a vector space and its dual does not exist"))
return V
end
TensorKit.infimum(V::S, W::S) where {S <: SumSpace} = infimum(⊕(V), ⊕(W))
TensorKit.supremum(V::S, W::S) where {S <: SumSpace} = supremum(⊕(V), ⊕(W))
TensorKit.ominus(V::S, W::S) where {S <: SumSpace} = ominus(⊕(V), ⊕(W))

#! format: off
TensorKit.:⊕(V::SumSpace{S}) where {S} = reduce(TK.oplus, V.spaces; init = isdual(V) ? zero(S)' : zero(S))
TensorKit.:⊕(V1::SumSpace{S}, V2::SumSpace{S}...) where {S} = TensorKit.oplus(⊕(V1, V2...))
#! format: on
TensorKit.oplus(V::SumSpace{S}) where {S} = reduce(⊕, V.spaces; init = isdual(V) ? zero(S)' : zero(S))
TensorKit.oplus(V1::SumSpace{S}, V2::SumSpace{S}...) where {S} = mapreduce(⊕, ⊕, (V1, V2...))

function TensorKit.fuse(V1::S, V2::S) where {S <: SumSpace}
return SumSpace(vec([fuse(v1, v2) for (v1, v2) in Base.product(V1.spaces, V2.spaces)]))
Expand Down Expand Up @@ -186,7 +191,7 @@ function Base.promote_rule(
return TensorMapSumSpace{S}
end

Base.convert(::Type{I}, S::SumSpace{I}) where {I <: ElementarySpace} = TensorKit.oplus(S)
Base.convert(::Type{I}, S::SumSpace{I}) where {I <: ElementarySpace} = (S)
Base.convert(::Type{SumSpace{S}}, V::S) where {S <: ElementarySpace} = SumSpace(V)
function Base.convert(::Type{<:ProductSumSpace{S, N}}, V::ProductSpace{S, N}) where {S, N}
return ProductSumSpace{S, N}(SumSpace.(V.spaces)...)
Expand All @@ -195,7 +200,7 @@ function Base.convert(::Type{<:ProductSumSpace{S}}, V::ProductSpace{S, N}) where
return ProductSumSpace{S, N}(SumSpace.(V.spaces)...)
end
function Base.convert(::Type{<:ProductSpace{S, N}}, V::ProductSumSpace{S, N}) where {S, N}
return ProductSpace{S, N}(TensorKit.oplus.(V.spaces)...)
return ProductSpace{S, N}(map(⊕, V.spaces)...)
end
function Base.convert(
::Type{<:TensorMapSumSpace{S}}, V::TensorMapSpace{S, N₁, N₂}
Expand All @@ -216,7 +221,7 @@ end
const SUMSPACE_SHOW_LIMIT = Ref(5)
function Base.show(io::IO, V::SumSpace)
if length(V) == 1
print(io, "(")
print(io, "(")
show(io, V[1])
print(io, ")")
return nothing
Expand All @@ -227,11 +232,11 @@ function Base.show(io::IO, V::SumSpace)
ax = axes(V.spaces, 1)
f, l = first(ax), last(ax)
h = SUMSPACE_SHOW_LIMIT[] ÷ 2
Base.show_delim_array(io, V.spaces, "(", " ", "", false, f, f + h)
print(io, " ")
Base.show_delim_array(io, V.spaces, "", " ", ")", false, l - h, l)
Base.show_delim_array(io, V.spaces, "(", " ", "", false, f, f + h)
print(io, " ")
Base.show_delim_array(io, V.spaces, "", " ", ")", false, l - h, l)
else
Base.show_delim_array(io, V.spaces, "(", " ", ")", false)
Base.show_delim_array(io, V.spaces, "(", " ", ")", false)
end
return nothing
end
Expand Down
2 changes: 1 addition & 1 deletion src/vectorspaces/sumspaceindices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function Base._cat(
allA = (A, As...)
Vs = ntuple(N₁ + N₂) do i
return if i <= length(catdims) && catdims[i]
((allA[j].sumspaces[i] for j in 1:length(allA))...)
((allA[j].sumspaces[i] for j in 1:length(allA))...)
else
A.sumspaces[i]
end
Expand Down
2 changes: 1 addition & 1 deletion test/abstracttensor/sparseblocktensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ end
@test i1 * i2 == @constinferred(id(storagetype(t), V1 ⊗ V2))
@test i2 * i1 == @constinferred(id(storagetype(t), V2 ⊗ V1))

w = @constinferred(isometry(storagetype(t), V1 ⊗ (oneunit(V1) oneunit(V1)), V1))
w = @constinferred(isometry(storagetype(t), V1 ⊗ (oneunit(V1) oneunit(V1)), V1))
@test dim(w) == 2 * dim(V1 ← V1)
@test w' * w == id(storagetype(t), V1)
@test w * w' == (w * w')^2
Expand Down
33 changes: 15 additions & 18 deletions test/vectorspaces/sumspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using TensorKit, BlockTensorKit
using Test, TestExtras

using TensorKit: hassector
using BlockTensorKit: ⊕

ds = [2, 3, 2]
d = sum(ds)
Expand Down Expand Up @@ -42,25 +41,24 @@ using TensorKit, BlockTensorKit
@test @constinferred(axes(V)) == Base.OneTo(d)
W = @constinferred SumSpace(ℝ^1)
@test @constinferred(oneunit(V)) == W == oneunit(typeof(V))
@test @constinferred((V, V)) == SumSpace(vcat(V.spaces, V.spaces))
@test @constinferred((V, oneunit(V))) == SumSpace(vcat(V.spaces, ℝ^1))
@test @constinferred((V, V, V, V)) == SumSpace(repeat(V.spaces, 4))
@test @constinferred((V, V)) == SumSpace(vcat(V.spaces, V.spaces))
@test @constinferred((V, oneunit(V))) == SumSpace(vcat(V.spaces, ℝ^1))
@test @constinferred((V, V, V, V)) == SumSpace(repeat(V.spaces, 4))
@test @constinferred(fuse(V, V)) ≅ SumSpace(ℝ^(d^2))
@test @constinferred(fuse(V, V', V, V')) ≅ SumSpace(ℝ^(d^4))
@test @constinferred(flip(V)) ≅ V'
@test flip(V) ≅ V
@test flip(V) ≾ V
@test flip(V) ≿ V
@test V ≺ (V, V)
@test !(V ≻ (V, V))
@test V ≺ (V, V)
@test !(V ≻ (V, V))
end

@testset "ComplexSpace" begin
using TensorKit, BlockTensorKit
using Test, TestExtras

using TensorKit: hassector
using BlockTensorKit: ⊕

ds = [2, 3, 2]
d = sum(ds)
Expand Down Expand Up @@ -90,25 +88,24 @@ end
@test @constinferred(axes(V)) == Base.OneTo(d)
W = @constinferred SumSpace(ℂ^1)
@test @constinferred(oneunit(V)) == W == oneunit(typeof(V))
@test @constinferred((V, V)) == SumSpace(vcat(V.spaces, V.spaces))
@test @constinferred((V, oneunit(V))) == SumSpace(vcat(V.spaces, ℂ^1))
@test @constinferred((V, V, V, V)) == SumSpace(repeat(V.spaces, 4))
@test @constinferred((V, V)) == SumSpace(vcat(V.spaces, V.spaces))
@test @constinferred((V, oneunit(V))) == SumSpace(vcat(V.spaces, ℂ^1))
@test @constinferred((V, V, V, V)) == SumSpace(repeat(V.spaces, 4))
@test @constinferred(fuse(V, V)) ≅ SumSpace(ℂ^(d^2))
@test @constinferred(fuse(V, V', V, V')) ≅ SumSpace(ℂ^(d^4))
@test @constinferred(flip(V)) ≅ V'
@test flip(V) ≅ V
@test flip(V) ≾ V
@test flip(V) ≿ V
@test V ≺ (V, V)
@test !(V ≻ (V, V))
@test V ≺ (V, V)
@test !(V ≻ (V, V))
end

@testset"GradedSpace" begin
using TensorKit, BlockTensorKit
using Test, TestExtras

using TensorKit: hassector
using BlockTensorKit: ⊕

V1 = U1Space(0 => 1, 1 => 1)
V2 = U1Space(0 => 1, 1 => 2)
Expand Down Expand Up @@ -143,16 +140,16 @@ end
@test @constinferred(axes(V)) == Base.OneTo(d)
W = @constinferred SumSpace(U1Space(0 => 1))
@test @constinferred(oneunit(V)) == W == @constinferred(oneunit(typeof(V)))
@test @constinferred((V, V)) == SumSpace(vcat(V.spaces, V.spaces))
@test @constinferred((V, oneunit(V))) == SumSpace(vcat(V.spaces, oneunit(V1)))
@test @constinferred((V, V, V, V)) == SumSpace(repeat(V.spaces, 4))
@test @constinferred((V, V)) == SumSpace(vcat(V.spaces, V.spaces))
@test @constinferred((V, oneunit(V))) == SumSpace(vcat(V.spaces, oneunit(V1)))
@test @constinferred((V, V, V, V)) == SumSpace(repeat(V.spaces, 4))
@test @constinferred(fuse(V, V)) ≅ SumSpace(U1Space(0 => 9, 1 => 24, 2 => 16))
@test @constinferred(fuse(V, V', V, V')) ≅
SumSpace(U1Space(0 => 913, 1 => 600, -1 => 600, 2 => 144, -2 => 144))
@test @constinferred(flip(V)) ≅ SumSpace(flip.(V.spaces)...)
@test flip(V) ≅ V
@test flip(V) ≾ V
@test flip(V) ≿ V
@test V ≺ (V, V)
@test !(V ≻ (V, V))
@test V ≺ (V, V)
@test !(V ≻ (V, V))
end