diff --git a/src/utility/styles.jl b/src/utility/styles.jl index 104b4e562..962859019 100644 --- a/src/utility/styles.jl +++ b/src/utility/styles.jl @@ -10,6 +10,12 @@ Trait to describe the operator behavior of the input `x` or type `T`, which can abstract type OperatorStyle end OperatorStyle(x) = OperatorStyle(typeof(x)) OperatorStyle(T::Type) = throw(MethodError(OperatorStyle, T)) # avoid stackoverflow if not defined +OperatorStyle(x::OperatorStyle) = x + +OperatorStyle(x, y) = OperatorStyle(OperatorStyle(x)::OperatorStyle, OperatorStyle(y)::OperatorStyle) +OperatorStyle(::T, ::T) where {T <: OperatorStyle} = T() +OperatorStyle(x::OperatorStyle, y::OperatorStyle) = error("Unknown combination of operator styles $x and $y") +@inline OperatorStyle(x, y, zs...) = OperatorStyle(OperatorStyle(x, y), zs...) struct MPOStyle <: OperatorStyle end struct HamiltonianStyle <: OperatorStyle end @@ -29,6 +35,12 @@ Trait to describe the geometry of the input `x` or type `T`, which can be either abstract type GeometryStyle end GeometryStyle(x) = GeometryStyle(typeof(x)) GeometryStyle(T::Type) = throw(MethodError(GeometryStyle, T)) # avoid stackoverflow if not defined +GeometryStyle(x::GeometryStyle) = x + +GeometryStyle(x, y) = GeometryStyle(GeometryStyle(x)::GeometryStyle, GeometryStyle(y)::GeometryStyle) +GeometryStyle(::T, ::T) where {T <: GeometryStyle} = T() +GeometryStyle(x::GeometryStyle, y::GeometryStyle) = error("Unknown combination of geometry styles $x and $y") +@inline GeometryStyle(x, y, zs...) = GeometryStyle(GeometryStyle(x, y), zs...) struct FiniteChainStyle <: GeometryStyle end struct InfiniteChainStyle <: GeometryStyle end diff --git a/test/operators.jl b/test/operators.jl index b2390437d..f5dce96b5 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -71,6 +71,8 @@ module TestOperators mps₁ = FiniteMPS(ψ₁) mps₂ = FiniteMPS(ψ₂) + @test @constinferred GeometryStyle(mps₁, mpo₁, mps₁) == GeometryStyle(mps₁) + @test convert(TensorMap, mpo₁ * mps₁) ≈ O₁ * ψ₁ @test mpo₁ * ψ₁ ≈ O₁ * ψ₁ @test convert(TensorMap, mpo₃ * mps₁) ≈ O₃ * ψ₁ @@ -140,6 +142,7 @@ module TestOperators @test GeometryStyle(H) == FiniteChainStyle() @test OperatorStyle(typeof(H)) == HamiltonianStyle() @test OperatorStyle(H) == HamiltonianStyle() + @test OperatorStyle(H, H′) == OperatorStyle(H) # Infinite Ws = [Wmid] @@ -410,6 +413,8 @@ module TestOperators ψ = InfiniteMPS([pspace], [ou ⊕ pspace]) W = MPSKit.DenseMPO(make_time_mpo(ham, 1im * 0.5, WII())) + + @test GeometryStyle(ψ, W) == GeometryStyle(ψ) @test W * (W * ψ) ≈ (W * W) * ψ atol = 1.0e-2 # TODO: there is a normalization issue here end diff --git a/test/other.jl b/test/other.jl index 17b3138fe..2eb3443ca 100644 --- a/test/other.jl +++ b/test/other.jl @@ -112,6 +112,9 @@ module TestMiscellaneous @test OperatorStyle(MPO) == MPOStyle() @test OperatorStyle(InfiniteMPO) == MPOStyle() + @test OperatorStyle(HamiltonianStyle()) == HamiltonianStyle() + @test @constinferred OperatorStyle(MPO, InfiniteMPO, MPO) == MPOStyle() + @test_throws ErrorException OperatorStyle(MPO, HamiltonianStyle()) @test GeometryStyle(FiniteMPOHamiltonian) == FiniteChainStyle() @test GeometryStyle(InfiniteMPS) == InfiniteChainStyle() @@ -120,5 +123,11 @@ module TestMiscellaneous @test GeometryStyle(FiniteMPOHamiltonian) == FiniteChainStyle() @test GeometryStyle(InfiniteMPO) == InfiniteChainStyle() @test GeometryStyle(InfiniteMPOHamiltonian) == InfiniteChainStyle() + + @test GeometryStyle(GeometryStyle(FiniteMPS)) == GeometryStyle(FiniteMPS) + @test GeometryStyle(FiniteMPS, FiniteMPO) == FiniteChainStyle() + @test_throws ErrorException GeometryStyle(FiniteMPS, InfiniteMPO) + @test @constinferred GeometryStyle(InfiniteMPS, InfiniteMPO, InfiniteMPS) == InfiniteChainStyle() + @test_throws ErrorException GeometryStyle(FiniteMPS, FiniteMPO, InfiniteMPS) end end