Skip to content

Commit 255d482

Browse files
jack-dunhamJack Dunhammtfishman
authored
Reintroduce AbstractNetworkIterator and AbstractProblem interface (#17)
Co-authored-by: Jack Dunham <[email protected]> Co-authored-by: Matt Fishman <[email protected]>
1 parent c4085f7 commit 255d482

File tree

6 files changed

+441
-1
lines changed

6 files changed

+441
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.11"
4+
version = "0.1.12"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ITensorNetworksNext.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,8 @@ include("lazynameddimsarrays.jl")
44
include("abstracttensornetwork.jl")
55
include("tensornetwork.jl")
66
include("contract_network.jl")
7+
include("abstract_problem.jl")
8+
include("iterators.jl")
9+
include("adapters.jl")
710

811
end

src/abstract_problem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
abstract type AbstractProblem end

src/adapters.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
3+
4+
Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
5+
process. This allows one to manually call a custom `compute!` or insert their own code it in
6+
the loop body in place of `compute!`.
7+
"""
8+
struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator
9+
parent::S
10+
end
11+
12+
islaststep(adapter::IncrementOnly) = islaststep(adapter.parent)
13+
increment!(adapter::IncrementOnly) = increment!(adapter.parent)
14+
compute!(adapter::IncrementOnly) = adapter
15+
16+
IncrementOnly(adapter::IncrementOnly) = adapter
17+
18+
"""
19+
struct EachRegion{SweepIterator} <: AbstractNetworkIterator
20+
21+
Adapter that flattens each region iterator in the parent sweep iterator into a single
22+
iterator.
23+
"""
24+
struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator
25+
parent::SI
26+
end
27+
28+
# In keeping with Julia convention.
29+
eachregion(iter::SweepIterator) = EachRegion(iter)
30+
31+
# Essential definitions
32+
function islaststep(adapter::EachRegion)
33+
region_iter = region_iterator(adapter.parent)
34+
return islaststep(adapter.parent) && islaststep(region_iter)
35+
end
36+
function increment!(adapter::EachRegion)
37+
region_iter = region_iterator(adapter.parent)
38+
islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter)
39+
return adapter
40+
end
41+
function compute!(adapter::EachRegion)
42+
region_iter = region_iterator(adapter.parent)
43+
compute!(region_iter)
44+
return adapter
45+
end

src/iterators.jl

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
"""
2+
abstract type AbstractNetworkIterator
3+
4+
A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
5+
with a call to `increment!` before executing `compute!`, however the initial call to
6+
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
7+
this call is implict. Termination of the iterator is controlled by the function `done`.
8+
"""
9+
abstract type AbstractNetworkIterator end
10+
11+
# We use greater than or equals here as we increment the state at the start of the iteration
12+
islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)
13+
14+
function Base.iterate(iterator::AbstractNetworkIterator, init = true)
15+
# The assumption is that first "increment!" is implicit, therefore we must skip the
16+
# the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
17+
# defined when length < 1,
18+
init || islaststep(iterator) && return nothing
19+
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
20+
# define a method for increment! This way we avoid cases where one may wish to nest
21+
# calls to different step! methods accidentaly incrementing multiple times.
22+
init || increment!(iterator)
23+
rv = compute!(iterator)
24+
return rv, false
25+
end
26+
27+
increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)}))
28+
compute!(iterator::AbstractNetworkIterator) = iterator
29+
30+
step!(iterator::AbstractNetworkIterator) = step!(identity, iterator)
31+
function step!(f, iterator::AbstractNetworkIterator)
32+
compute!(iterator)
33+
f(iterator)
34+
increment!(iterator)
35+
return iterator
36+
end
37+
38+
#
39+
# RegionIterator
40+
#
41+
"""
42+
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
43+
"""
44+
mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
45+
problem::Problem
46+
region_plan::RegionPlan
47+
which_region::Int
48+
const which_sweep::Int
49+
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
50+
if isempty(region_plan)
51+
throw(ArgumentError("Cannot construct a region iterator with 0 elements."))
52+
end
53+
return new{P, R}(problem, region_plan, 1, sweep)
54+
end
55+
end
56+
57+
function RegionIterator(problem; sweep, sweep_kwargs...)
58+
plan = region_plan(problem; sweep_kwargs...)
59+
return RegionIterator(problem, plan, sweep)
60+
end
61+
62+
state(region_iter::RegionIterator) = region_iter.which_region
63+
Base.length(region_iter::RegionIterator) = length(region_iter.region_plan)
64+
65+
problem(region_iter::RegionIterator) = region_iter.problem
66+
67+
function current_region_plan(region_iter::RegionIterator)
68+
return region_iter.region_plan[region_iter.which_region]
69+
end
70+
71+
function current_region(region_iter::RegionIterator)
72+
region, _ = current_region_plan(region_iter)
73+
return region
74+
end
75+
76+
function region_kwargs(region_iter::RegionIterator)
77+
_, kwargs = current_region_plan(region_iter)
78+
return kwargs
79+
end
80+
function region_kwargs(f::Function, iter::RegionIterator)
81+
return get(region_kwargs(iter), Symbol(f, :_kwargs), (;))
82+
end
83+
84+
function prev_region(region_iter::RegionIterator)
85+
state(region_iter) <= 1 && return nothing
86+
prev, _ = region_iter.region_plan[region_iter.which_region - 1]
87+
return prev
88+
end
89+
90+
function next_region(region_iter::RegionIterator)
91+
islaststep(region_iter) && return nothing
92+
next, _ = region_iter.region_plan[region_iter.which_region + 1]
93+
return next
94+
end
95+
96+
#
97+
# Functions associated with RegionIterator
98+
#
99+
function increment!(region_iter::RegionIterator)
100+
region_iter.which_region += 1
101+
return region_iter
102+
end
103+
104+
function compute!(iter::RegionIterator)
105+
extract!(iter; region_kwargs(extract!, iter)...)
106+
update!(iter; region_kwargs(update!, iter)...)
107+
insert!(iter; region_kwargs(insert!, iter)...)
108+
109+
return iter
110+
end
111+
112+
region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...)
113+
114+
#
115+
# SweepIterator
116+
#
117+
118+
mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
119+
region_iter::RegionIterator{Problem}
120+
sweep_kwargs::Iterators.Stateful{Iter}
121+
which_sweep::Int
122+
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
123+
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)
124+
125+
first_state = Iterators.peel(stateful_sweep_kwargs)
126+
127+
if isnothing(first_state)
128+
throw(ArgumentError("Cannot construct a sweep iterator with 0 elements."))
129+
end
130+
131+
first_kwargs, _ = first_state
132+
region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)
133+
134+
return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1)
135+
end
136+
end
137+
138+
islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))
139+
140+
region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter
141+
problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))
142+
143+
state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
144+
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)
145+
function increment!(sweep_iter::SweepIterator)
146+
sweep_iter.which_sweep += 1
147+
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
148+
update_region_iterator!(sweep_iter; sweep_kwargs...)
149+
return sweep_iter
150+
end
151+
152+
function update_region_iterator!(iterator::SweepIterator; kwargs...)
153+
sweep = state(iterator)
154+
iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...)
155+
return iterator
156+
end
157+
158+
function compute!(sweep_iter::SweepIterator)
159+
for _ in sweep_iter.region_iter
160+
# TODO: Is it sensible to execute the default region callback function?
161+
end
162+
return
163+
end
164+
165+
# More basic constructor where sweep_kwargs are constant throughout sweeps
166+
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
167+
# Initialize this to an empty RegionIterator
168+
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
169+
return SweepIterator(problem, sweep_kwargs_iter)
170+
end

0 commit comments

Comments
 (0)