Skip to content

Commit b1a077b

Browse files
committed
define GeometryStyle and OperatorStyle for multiple arguments by recursion
1 parent e8a1e7d commit b1a077b

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

src/utility/styles.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ Trait to describe the operator behavior of the input `x` or type `T`, which can
1010
abstract type OperatorStyle end
1111
OperatorStyle(x) = OperatorStyle(typeof(x))
1212
OperatorStyle(T::Type) = throw(MethodError(OperatorStyle, T)) # avoid stackoverflow if not defined
13+
OperatorStyle(x::OperatorStyle) = x
14+
15+
OperatorStyle(x, y) = OperatorStyle(OperatorStyle(x)::OperatorStyle, OperatorStyle(y)::OperatorStyle)
16+
OperatorStyle(::T, ::T) where {T<:OperatorStyle} = T()
17+
OperatorStyle(x::OperatorStyle, y::OperatorStyle) = throw(MethodError(OperatorStyle, (x, y)))
18+
@inline OperatorStyle(x, y, zs...) = OperatorStyle(OperatorStyle(x, y), zs...)
1319

1420
struct MPOStyle <: OperatorStyle end
1521
struct HamiltonianStyle <: OperatorStyle end
@@ -29,6 +35,12 @@ Trait to describe the geometry of the input `x` or type `T`, which can be either
2935
abstract type GeometryStyle end
3036
GeometryStyle(x) = GeometryStyle(typeof(x))
3137
GeometryStyle(T::Type) = throw(MethodError(GeometryStyle, T)) # avoid stackoverflow if not defined
38+
GeometryStyle(x::GeometryStyle) = x
39+
40+
GeometryStyle(x, y) = GeometryStyle(GeometryStyle(x)::GeometryStyle, GeometryStyle(y)::GeometryStyle)
41+
GeometryStyle(::T, ::T) where {T<:GeometryStyle} = T()
42+
GeometryStyle(x::GeometryStyle, y::GeometryStyle) = throw(MethodError(GeometryStyle, (x, y)))
43+
@inline GeometryStyle(x, y, zs...) = GeometryStyle(GeometryStyle(x, y), zs...)
3244

3345
struct FiniteChainStyle <: GeometryStyle end
3446
struct InfiniteChainStyle <: GeometryStyle end

test/operators.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ module TestOperators
7171
mps₁ = FiniteMPS(ψ₁)
7272
mps₂ = FiniteMPS(ψ₂)
7373

74+
@test @constinferred GeometryStyle(mps₁, mpo₁, mps₁) == GeometryStyle(mps₁)
75+
7476
@test convert(TensorMap, mpo₁ * mps₁) O₁ * ψ₁
7577
@test mpo₁ * ψ₁ O₁ * ψ₁
7678
@test convert(TensorMap, mpo₃ * mps₁) O₃ * ψ₁
@@ -140,6 +142,7 @@ module TestOperators
140142
@test GeometryStyle(H) == FiniteChainStyle()
141143
@test OperatorStyle(typeof(H)) == HamiltonianStyle()
142144
@test OperatorStyle(H) == HamiltonianStyle()
145+
@test OperatorStyle(H, H´) == OperatorStyle(H)
143146

144147
# Infinite
145148
Ws = [Wmid]
@@ -410,6 +413,8 @@ module TestOperators
410413
ψ = InfiniteMPS([pspace], [ou pspace])
411414

412415
W = MPSKit.DenseMPO(make_time_mpo(ham, 1im * 0.5, WII()))
416+
417+
@test GeometryStyle(ψ, W) == GeometryStyle(ψ)
413418
@test W * (W * ψ) (W * W) * ψ atol = 1.0e-2 # TODO: there is a normalization issue here
414419
end
415420

test/other.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,5 +120,11 @@ module TestMiscellaneous
120120
@test GeometryStyle(FiniteMPOHamiltonian) == FiniteChainStyle()
121121
@test GeometryStyle(InfiniteMPO) == InfiniteChainStyle()
122122
@test GeometryStyle(InfiniteMPOHamiltonian) == InfiniteChainStyle()
123+
124+
@test GeometryStyle(GeometryStyle(FiniteMPS)) == GeometryStyle(FiniteMPS)
125+
@test GeometryStyle(FiniteMPS, FiniteMPO) == FiniteChainStyle()
126+
@test_throws MethodError GeometryStyle(FiniteMPS, InfiniteMPO)
127+
@test @constinferred GeometryStyle(InfiniteMPS, InfiniteMPO, InfiniteMPS) == InfiniteChainStyle()
128+
@test_throws MethodError GeometryStyle(FiniteMPS, FiniteMPO, InfiniteMPS)
123129
end
124130
end

0 commit comments

Comments
 (0)