Skip to content

Commit 0bbf584

Browse files
author
Jack Dunham
committed
Reintroduce AbstractNetworkIterator and AbstractProblem interface
1 parent 7fd8cd9 commit 0bbf584

File tree

4 files changed

+337
-0
lines changed

4 files changed

+337
-0
lines changed

src/ITensorNetworksNext.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ include("lazynameddimsarrays.jl")
44
include("abstracttensornetwork.jl")
55
include("tensornetwork.jl")
66
include("contract_network.jl")
7+
include("abstract_problem.jl")
8+
include("iterators.jl")
79

810
include("beliefpropagation/abstractbeliefpropagationcache.jl")
911
include("beliefpropagation/beliefpropagationcache.jl")

src/abstract_problem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
abstract type AbstractProblem end

src/iterators.jl

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""
2+
abstract type AbstractNetworkIterator
3+
4+
A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
5+
with a call to `increment!` before executing `compute!`, however the initial call to
6+
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
7+
this call is implict. Termination of the iterator is controlled by the function `done`.
8+
"""
9+
abstract type AbstractNetworkIterator end
10+
11+
# We use greater than or equals here as we increment the state at the start of the iteration
12+
islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)
13+
14+
function Base.iterate(iterator::AbstractNetworkIterator, init = true)
15+
# The assumption is that first "increment!" is implicit, therefore we must skip the
16+
# the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
17+
# defined when length < 1,
18+
init || islaststep(iterator) && return nothing
19+
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
20+
# define a method for increment! This way we avoid cases where one may wish to nest
21+
# calls to different step! methods accidentaly incrementing multiple times.
22+
init || increment!(iterator)
23+
rv = compute!(iterator)
24+
return rv, false
25+
end
26+
27+
function increment! end
28+
compute!(iterator::AbstractNetworkIterator) = iterator
29+
30+
step!(iterator::AbstractNetworkIterator) = step!(identity, iterator)
31+
function step!(f, iterator::AbstractNetworkIterator)
32+
compute!(iterator)
33+
f(iterator)
34+
increment!(iterator)
35+
return iterator
36+
end
37+
38+
#
39+
# RegionIterator
40+
#
41+
"""
42+
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
43+
"""
44+
mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
45+
problem::Problem
46+
region_plan::RegionPlan
47+
which_region::Int
48+
const which_sweep::Int
49+
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
50+
if length(region_plan) == 0
51+
throw(BoundsError("Cannot construct a region iterator with 0 elements."))
52+
end
53+
return new{P, R}(problem, region_plan, 1, sweep)
54+
end
55+
end
56+
57+
function RegionIterator(problem; sweep, sweep_kwargs...)
58+
plan = region_plan(problem; sweep_kwargs...)
59+
return RegionIterator(problem, plan, sweep)
60+
end
61+
62+
function new_region_iterator(iterator::RegionIterator; sweep_kwargs...)
63+
return RegionIterator(iterator.problem; sweep_kwargs...)
64+
end
65+
66+
state(region_iter::RegionIterator) = region_iter.which_region
67+
Base.length(region_iter::RegionIterator) = length(region_iter.region_plan)
68+
69+
problem(region_iter::RegionIterator) = region_iter.problem
70+
71+
function current_region_plan(region_iter::RegionIterator)
72+
return region_iter.region_plan[region_iter.which_region]
73+
end
74+
75+
function current_region(region_iter::RegionIterator)
76+
region, _ = current_region_plan(region_iter)
77+
return region
78+
end
79+
80+
function region_kwargs(region_iter::RegionIterator)
81+
_, kwargs = current_region_plan(region_iter)
82+
return kwargs
83+
end
84+
function region_kwargs(f::Function, iter::RegionIterator)
85+
return get(region_kwargs(iter), Symbol(f, :_kwargs), (;))
86+
end
87+
88+
function prev_region(region_iter::RegionIterator)
89+
state(region_iter) <= 1 && return nothing
90+
prev, _ = region_iter.region_plan[region_iter.which_region - 1]
91+
return prev
92+
end
93+
94+
function next_region(region_iter::RegionIterator)
95+
islaststep(region_iter) && return nothing
96+
next, _ = region_iter.region_plan[region_iter.which_region + 1]
97+
return next
98+
end
99+
100+
#
101+
# Functions associated with RegionIterator
102+
#
103+
function increment!(region_iter::RegionIterator)
104+
region_iter.which_region += 1
105+
return region_iter
106+
end
107+
108+
function compute!(iter::RegionIterator)
109+
_, local_state = extract!(iter; region_kwargs(extract!, iter)...)
110+
_, local_state = update!(iter, local_state; region_kwargs(update!, iter)...)
111+
insert!(iter, local_state; region_kwargs(insert!, iter)...)
112+
113+
return iter
114+
end
115+
116+
region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...)
117+
118+
#
119+
# SweepIterator
120+
#
121+
122+
mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
123+
region_iter::RegionIterator{Problem}
124+
sweep_kwargs::Iterators.Stateful{Iter}
125+
which_sweep::Int
126+
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
127+
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)
128+
first_state = Iterators.peel(stateful_sweep_kwargs)
129+
130+
if isnothing(first_state)
131+
throw(BoundsError("Cannot construct a sweep iterator with 0 elements."))
132+
end
133+
134+
first_kwargs, _ = first_state
135+
region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)
136+
137+
return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1)
138+
end
139+
end
140+
141+
islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))
142+
143+
region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter
144+
problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))
145+
146+
state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
147+
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)
148+
function increment!(sweep_iter::SweepIterator)
149+
sweep_iter.which_sweep += 1
150+
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
151+
update_region_iterator!(sweep_iter; sweep_kwargs...)
152+
return sweep_iter
153+
end
154+
155+
function update_region_iterator!(iterator::SweepIterator; kwargs...)
156+
sweep = state(iterator)
157+
iterator.region_iter = new_region_iterator(iterator.region_iter; sweep, kwargs...)
158+
return iterator
159+
end
160+
161+
function compute!(sweep_iter::SweepIterator)
162+
for _ in sweep_iter.region_iter
163+
# TODO: Is it sensible to execute the default region callback function?
164+
end
165+
return
166+
end
167+
168+
# More basic constructor where sweep_kwargs are constant throughout sweeps
169+
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
170+
# Initialize this to an empty RegionIterator
171+
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
172+
return SweepIterator(problem, sweep_kwargs_iter)
173+
end

test/test_iterators.jl

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
using Test: @test, @testset, @test_throws
2+
import ITensorNetworksNext as ITensorNetworks
3+
using .ITensorNetworks: SweepIterator, RegionIterator, islaststep, state, increment!, compute!, eachregion
4+
5+
module TestIteratorUtils
6+
7+
import ITensorNetworksNext as ITensorNetworks
8+
using .ITensorNetworks
9+
10+
struct TestProblem <: ITensorNetworks.AbstractProblem
11+
data::Vector{Int}
12+
end
13+
ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)]
14+
function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem})
15+
kwargs = ITensorNetworks.region_kwargs(iter)
16+
push!(ITensorNetworks.problem(iter).data, kwargs.val)
17+
return iter
18+
end
19+
20+
21+
mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator
22+
state::Int
23+
max::Int
24+
output::Vector{Int}
25+
end
26+
27+
ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1
28+
Base.length(TI::TestIterator) = TI.max
29+
ITensorNetworks.state(TI::TestIterator) = TI.state
30+
function ITensorNetworks.compute!(TI::TestIterator)
31+
push!(TI.output, ITensorNetworks.state(TI))
32+
return TI
33+
end
34+
35+
mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator
36+
parent::TestIterator
37+
end
38+
39+
Base.length(SA::SquareAdapter) = length(SA.parent)
40+
ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent)
41+
ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent)
42+
function ITensorNetworks.compute!(SA::SquareAdapter)
43+
ITensorNetworks.compute!(SA.parent)
44+
return last(SA.parent.output)^2
45+
end
46+
47+
end
48+
49+
@testset "Iterators" begin
50+
51+
import .TestIteratorUtils
52+
53+
@testset "`AbstractNetworkIterator` Interface" begin
54+
55+
@testset "Edge cases" begin
56+
TI = TestIteratorUtils.TestIterator(1, 1, [])
57+
cb = []
58+
@test islaststep(TI)
59+
for _ in TI
60+
@test islaststep(TI)
61+
push!(cb, state(TI))
62+
end
63+
@test length(cb) == 1
64+
@test length(TI.output) == 1
65+
@test only(cb) == 1
66+
67+
prob = TestIteratorUtils.TestProblem([])
68+
@test_throws BoundsError SweepIterator(prob, 0)
69+
@test_throws BoundsError RegionIterator(prob, [], 1)
70+
end
71+
72+
TI = TestIteratorUtils.TestIterator(1, 4, [])
73+
74+
@test !islaststep((TI))
75+
76+
# First iterator should compute only
77+
rv, st = iterate(TI)
78+
@test !islaststep((TI))
79+
@test !st
80+
@test rv === TI
81+
@test length(TI.output) == 1
82+
@test only(TI.output) == 1
83+
@test state(TI) == 1
84+
@test !st
85+
86+
rv, st = iterate(TI, st)
87+
@test !islaststep((TI))
88+
@test !st
89+
@test length(TI.output) == 2
90+
@test state(TI) == 2
91+
@test TI.output == [1, 2]
92+
93+
increment!(TI)
94+
@test !islaststep((TI))
95+
@test state(TI) == 3
96+
@test length(TI.output) == 2
97+
@test TI.output == [1, 2]
98+
99+
compute!(TI)
100+
@test !islaststep((TI))
101+
@test state(TI) == 3
102+
@test length(TI.output) == 3
103+
@test TI.output == [1, 2, 3]
104+
105+
# Final Step
106+
iterate(TI, false)
107+
@test islaststep((TI))
108+
@test state(TI) == 4
109+
@test length(TI.output) == 4
110+
@test TI.output == [1, 2, 3, 4]
111+
112+
@test iterate(TI, false) === nothing
113+
114+
TI = TestIteratorUtils.TestIterator(1, 5, [])
115+
116+
cb = []
117+
118+
for _ in TI
119+
@test length(cb) == length(TI.output) - 1
120+
@test cb == (TI.output)[1:(end - 1)]
121+
push!(cb, state(TI))
122+
@test cb == TI.output
123+
end
124+
125+
@test islaststep((TI))
126+
@test length(TI.output) == 5
127+
@test length(cb) == 5
128+
@test cb == TI.output
129+
130+
131+
TI = TestIteratorUtils.TestIterator(1, 5, [])
132+
end
133+
134+
@testset "Adapters" begin
135+
TI = TestIteratorUtils.TestIterator(1, 5, [])
136+
SA = TestIteratorUtils.SquareAdapter(TI)
137+
138+
@testset "Generic" begin
139+
140+
i = 0
141+
for rv in SA
142+
i += 1
143+
@test rv isa Int
144+
@test rv == i^2
145+
@test state(SA) == i
146+
end
147+
148+
@test islaststep((SA))
149+
150+
TI = TestIteratorUtils.TestIterator(1, 5, [])
151+
SA = TestIteratorUtils.SquareAdapter(TI)
152+
153+
SA_c = collect(SA)
154+
155+
@test SA_c isa Vector
156+
@test length(SA_c) == 5
157+
@test SA_c == [1, 4, 9, 16, 25]
158+
159+
end
160+
end
161+
end

0 commit comments

Comments
 (0)