Skip to content

Commit 8df9de3

Browse files
AFeuerpfeillkdvos
andauthored
Define GeometryStyle and OperatorStyle for multiple arguments (#354)
* define GeometryStyle and OperatorStyle for multiple arguments by recursion * fix formatting * fix test * Update src/utility/styles.jl Rephrase error messages based on Lukas' suggestion Co-authored-by: Lukas Devos <[email protected]> * also change error message of GeometryStyle * fix tests * fix tests * improve code coverage --------- Co-authored-by: Lukas Devos <[email protected]>
1 parent e8a1e7d commit 8df9de3

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

src/utility/styles.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ Trait to describe the operator behavior of the input `x` or type `T`, which can
1010
abstract type OperatorStyle end
1111
OperatorStyle(x) = OperatorStyle(typeof(x))
1212
OperatorStyle(T::Type) = throw(MethodError(OperatorStyle, T)) # avoid stackoverflow if not defined
13+
OperatorStyle(x::OperatorStyle) = x
14+
15+
OperatorStyle(x, y) = OperatorStyle(OperatorStyle(x)::OperatorStyle, OperatorStyle(y)::OperatorStyle)
16+
OperatorStyle(::T, ::T) where {T <: OperatorStyle} = T()
17+
OperatorStyle(x::OperatorStyle, y::OperatorStyle) = error("Unknown combination of operator styles $x and $y")
18+
@inline OperatorStyle(x, y, zs...) = OperatorStyle(OperatorStyle(x, y), zs...)
1319

1420
struct MPOStyle <: OperatorStyle end
1521
struct HamiltonianStyle <: OperatorStyle end
@@ -29,6 +35,12 @@ Trait to describe the geometry of the input `x` or type `T`, which can be either
2935
abstract type GeometryStyle end
3036
GeometryStyle(x) = GeometryStyle(typeof(x))
3137
GeometryStyle(T::Type) = throw(MethodError(GeometryStyle, T)) # avoid stackoverflow if not defined
38+
GeometryStyle(x::GeometryStyle) = x
39+
40+
GeometryStyle(x, y) = GeometryStyle(GeometryStyle(x)::GeometryStyle, GeometryStyle(y)::GeometryStyle)
41+
GeometryStyle(::T, ::T) where {T <: GeometryStyle} = T()
42+
GeometryStyle(x::GeometryStyle, y::GeometryStyle) = error("Unknown combination of geometry styles $x and $y")
43+
@inline GeometryStyle(x, y, zs...) = GeometryStyle(GeometryStyle(x, y), zs...)
3244

3345
struct FiniteChainStyle <: GeometryStyle end
3446
struct InfiniteChainStyle <: GeometryStyle end

test/operators.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ module TestOperators
7171
mps₁ = FiniteMPS(ψ₁)
7272
mps₂ = FiniteMPS(ψ₂)
7373

74+
@test @constinferred GeometryStyle(mps₁, mpo₁, mps₁) == GeometryStyle(mps₁)
75+
7476
@test convert(TensorMap, mpo₁ * mps₁) O₁ * ψ₁
7577
@test mpo₁ * ψ₁ O₁ * ψ₁
7678
@test convert(TensorMap, mpo₃ * mps₁) O₃ * ψ₁
@@ -140,6 +142,7 @@ module TestOperators
140142
@test GeometryStyle(H) == FiniteChainStyle()
141143
@test OperatorStyle(typeof(H)) == HamiltonianStyle()
142144
@test OperatorStyle(H) == HamiltonianStyle()
145+
@test OperatorStyle(H, H′) == OperatorStyle(H)
143146

144147
# Infinite
145148
Ws = [Wmid]
@@ -410,6 +413,8 @@ module TestOperators
410413
ψ = InfiniteMPS([pspace], [ou pspace])
411414

412415
W = MPSKit.DenseMPO(make_time_mpo(ham, 1im * 0.5, WII()))
416+
417+
@test GeometryStyle(ψ, W) == GeometryStyle(ψ)
413418
@test W * (W * ψ) (W * W) * ψ atol = 1.0e-2 # TODO: there is a normalization issue here
414419
end
415420

test/other.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ module TestMiscellaneous
112112

113113
@test OperatorStyle(MPO) == MPOStyle()
114114
@test OperatorStyle(InfiniteMPO) == MPOStyle()
115+
@test OperatorStyle(HamiltonianStyle()) == HamiltonianStyle()
116+
@test @constinferred OperatorStyle(MPO, InfiniteMPO, MPO) == MPOStyle()
117+
@test_throws ErrorException OperatorStyle(MPO, HamiltonianStyle())
115118

116119
@test GeometryStyle(FiniteMPOHamiltonian) == FiniteChainStyle()
117120
@test GeometryStyle(InfiniteMPS) == InfiniteChainStyle()
@@ -120,5 +123,11 @@ module TestMiscellaneous
120123
@test GeometryStyle(FiniteMPOHamiltonian) == FiniteChainStyle()
121124
@test GeometryStyle(InfiniteMPO) == InfiniteChainStyle()
122125
@test GeometryStyle(InfiniteMPOHamiltonian) == InfiniteChainStyle()
126+
127+
@test GeometryStyle(GeometryStyle(FiniteMPS)) == GeometryStyle(FiniteMPS)
128+
@test GeometryStyle(FiniteMPS, FiniteMPO) == FiniteChainStyle()
129+
@test_throws ErrorException GeometryStyle(FiniteMPS, InfiniteMPO)
130+
@test @constinferred GeometryStyle(InfiniteMPS, InfiniteMPO, InfiniteMPS) == InfiniteChainStyle()
131+
@test_throws ErrorException GeometryStyle(FiniteMPS, FiniteMPO, InfiniteMPS)
123132
end
124133
end

0 commit comments

Comments
 (0)