Skip to content

Commit 7af4b25

Browse files
author
Jack Dunham
committed
Redesign iterator interface by introducing AbstractNetworkIterator abstract type
Other changes: - Both `sweep_callback` and `region_callback` in `sweep_solve` now take only one positional argument, the sweep iterator. - Iterating `SweepIterator` now automatically performs the RegionIteration - Added an 'adapter' `PauseAfterIncrement` that allows `SweepIterator` to be iterated without performing region iteration - `RegionIteration` now tracks the current sweep number - Replaced some function calls with explict calls to constructors to make it clear when new iterators are being constructed (instead of returned from a field etc). Note, AbstractNetworkIterator interface requires some documentation.
1 parent 245182b commit 7af4b25

File tree

7 files changed

+187
-120
lines changed

7 files changed

+187
-120
lines changed

src/solvers/adapters.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,21 @@ iterator which outputs a tuple of the form (current_region, current_region_kwarg
3030
at each step.
3131
"""
3232
region_tuples(R::RegionIterator) = TupleRegionIterator(R)
33+
34+
"""
35+
struct PauseAfterIncrement{S<:AbstractNetworkIterator}
36+
37+
Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
38+
process. This allows one to manually call a custom `compute!` or insert their own code it in
39+
the loop body in place of `compute!`.
40+
"""
41+
struct PauseAfterIncrement{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
42+
parent::S
43+
end
44+
45+
done(NC::PauseAfterIncrement) = done(NC.parent)
46+
state(NC::PauseAfterIncrement) = state(NC.parent)
47+
increment!(NC::PauseAfterIncrement) = increment!(NC.parent)
48+
compute!(NC::PauseAfterIncrement) = NC
49+
50+
PauseAfterIncrement(NC::PauseAfterIncrement) = NC

src/solvers/applyexp.jl

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ operator(A::ApplyExpProblem) = A.operator
1111
state(A::ApplyExpProblem) = A.state
1212
current_exponent(A::ApplyExpProblem) = A.current_exponent
1313
function current_time(A::ApplyExpProblem)
14-
t = im*A.current_exponent
14+
t = im * A.current_exponent
1515
return iszero(imag(t)) ? real(t) : t
1616
end
1717

@@ -36,41 +36,42 @@ function update(
3636
iszero(abs(exponent_step)) && return prob, local_state
3737

3838
local_state, info = solver(
39-
x->optimal_map(operator(prob), x), exponent_step, local_state; kws...
39+
x -> optimal_map(operator(prob), x), exponent_step, local_state; kws...
4040
)
41-
if nsites==1
41+
if nsites == 1
4242
curr_reg = current_region(region_iterator)
4343
next_reg = next_region(region_iterator)
4444
if !isnothing(next_reg) && next_reg != curr_reg
4545
next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg))
4646
v1, v2 = src(next_edge), dst(next_edge)
4747
psi = copy(state(prob))
4848
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
49-
shifted_operator = position(operator(prob), psi, NamedEdge(v1=>v2))
50-
R_t, _ = solver(x->optimal_map(shifted_operator, x), -exponent_step, R; kws...)
51-
local_state = psi[v1]*R_t
49+
shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2))
50+
R_t, _ = solver(x -> optimal_map(shifted_operator, x), -exponent_step, R; kws...)
51+
local_state = psi[v1] * R_t
5252
end
5353
end
5454

55-
prob = set_current_exponent(prob, current_exponent(prob)+exponent_step)
55+
prob = set_current_exponent(prob, current_exponent(prob) + exponent_step)
5656

5757
return prob, local_state
5858
end
5959

60-
function sweep_callback(
61-
problem::ApplyExpProblem;
60+
function default_sweep_callback(
61+
sweep_iterator::SweepIterator{<:ApplyExpProblem};
6262
exponent_description="exponent",
6363
outputlevel,
64-
sweep,
65-
nsweeps,
6664
process_time=identity,
67-
kws...,
65+
kwargs...,
6866
)
6967
if outputlevel >= 1
68+
the_problem = problem(sweep_iterator)
7069
@printf(
71-
" Current %s = %s, ", exponent_description, process_time(current_exponent(problem))
70+
" Current %s = %s, ",
71+
exponent_description,
72+
process_time(current_exponent(the_problem))
7273
)
73-
@printf("maxlinkdim=%d", maxlinkdim(state(problem)))
74+
@printf("maxlinkdim=%d", maxlinkdim(state(the_problem)))
7475
println()
7576
flush(stdout)
7677
end
@@ -88,9 +89,10 @@ function applyexp(
8889
kws...,
8990
)
9091
exponent_steps = diff([zero(eltype(exponents)); exponents])
92+
# exponent_steps = diff(exponents)
9193
sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs)
9294
kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps]
93-
sweep_iter = sweep_iterator(init_prob, kws_array)
95+
sweep_iter = SweepIterator(init_prob, kws_array)
9496
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)
9597
return state(converged_prob)
9698
end
@@ -111,11 +113,10 @@ function time_evolve(
111113
time_points,
112114
init_state;
113115
process_time=process_real_times,
114-
sweep_callback=(
115-
a...; k...
116-
)->sweep_callback(a...; exponent_description="time", process_time, k...),
116+
sweep_callback=(a...; k...) ->
117+
default_sweep_callback(a...; exponent_description="time", process_time, k...),
117118
kws...,
118119
)
119-
exponents = [-im*t for t in time_points]
120+
exponents = [-im * t for t in time_points]
120121
return applyexp(operator, exponents, init_state; sweep_callback, kws...)
121122
end

src/solvers/eigsolve.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function update(
3434
solver=eigsolve_solver,
3535
kws...,
3636
)
37-
eigval, local_state = solver->optimal_map(operator(prob), ψ), local_state; kws...)
37+
eigval, local_state = solver -> optimal_map(operator(prob), ψ), local_state; kws...)
3838
prob = set_eigenvalue(prob, eigval)
3939
if outputlevel >= 2
4040
@printf(
@@ -44,12 +44,16 @@ function update(
4444
return prob, local_state
4545
end
4646

47-
function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, kws...)
47+
function default_sweep_callback(
48+
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel
49+
)
4850
if outputlevel >= 1
49-
if nsweeps >= 10
50-
@printf("After sweep %02d/%d ", sweep, nsweeps)
51+
nsweeps = length(sweep_iterator)
52+
current_sweep = sweep_iterator.which_sweep
53+
if length(sweep_iterator) >= 10
54+
@printf("After sweep %02d/%d ", current_sweep, nsweeps)
5155
else
52-
@printf("After sweep %d/%d ", sweep, nsweeps)
56+
@printf("After sweep %d/%d ", current_sweep, nsweeps)
5357
end
5458
@printf("eigenvalue=%.12f", eigenvalue(problem))
5559
@printf(" maxlinkdim=%d", maxlinkdim(state(problem)))
@@ -73,7 +77,7 @@ function eigsolve(
7377
init_prob = EigsolveProblem(;
7478
state=align_indices(init_state), operator=ProjTTN(align_indices(operator))
7579
)
76-
sweep_iter = sweep_iterator(
80+
sweep_iter = SweepIterator(
7781
init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs
7882
)
7983
prob = sweep_solve(sweep_iter; outputlevel, kws...)

src/solvers/fitting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function fit_tensornetwork(
7979
insert_kwargs = (; insert_kwargs..., normalize, set_orthogonal_region=false)
8080
common_sweep_kwargs = (; nsites, outputlevel, update_kwargs, insert_kwargs)
8181
kwargs_array = [(; common_sweep_kwargs..., sweep=s) for s in 1:nsweeps]
82-
sweep_iter = sweep_iterator(init_prob, kwargs_array)
82+
sweep_iter = SweepIterator(init_prob, kwargs_array)
8383
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)
8484
return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob))
8585
end

src/solvers/iterators.jl

Lines changed: 114 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,113 @@
1+
"""
2+
abstract type AbstractNetworkIterator
3+
4+
A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins
5+
with a call to `increment!` before executing `compute!`, however the initial call to
6+
`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that
7+
this call is implict. Termination of the iterator is controlled by the function `done`.
8+
"""
9+
abstract type AbstractNetworkIterator end
10+
11+
# We use greater than or equals here as we increment the state at the start of the iteration
12+
done(NI::AbstractNetworkIterator) = state(NI) >= length(NI)
13+
14+
function Base.iterate(NI::AbstractNetworkIterator, init=true)
15+
done(NI) && return nothing
16+
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
17+
# define a method for increment! This way we avoid cases where one may wish to nest
18+
# calls to different step! methods accidentaly incrementing multiple times.
19+
init || increment!(NI)
20+
rv = compute!(NI)
21+
return rv, false
22+
end
23+
24+
function increment! end
25+
compute!(NI::AbstractNetworkIterator) = NI
26+
27+
step!(NI::AbstractNetworkIterator) = step!(identity, NI)
28+
function step!(f, NI::AbstractNetworkIterator)
29+
compute!(NI)
30+
f(NI)
31+
increment!(NI)
32+
return NI
33+
end
34+
135
#
236
# RegionIterator
337
#
4-
5-
@kwdef mutable struct RegionIterator{Problem,RegionPlan}
38+
"""
39+
struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
40+
"""
41+
mutable struct RegionIterator{Problem,RegionPlan} <: AbstractNetworkIterator
642
problem::Problem
743
region_plan::RegionPlan
8-
which_region::Int = 1
44+
const sweep::Int
45+
which_region::Int
46+
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P,R}
47+
return new{P,R}(problem, region_plan, sweep, 1)
48+
end
949
end
1050

51+
state(R::RegionIterator) = R.which_region
52+
Base.length(R::RegionIterator) = length(R.region_plan)
53+
1154
problem(R::RegionIterator) = R.problem
55+
1256
current_region_plan(R::RegionIterator) = R.region_plan[R.which_region]
13-
current_region(R::RegionIterator) = current_region_plan(R)[1]
14-
region_kwargs(R::RegionIterator) = current_region_plan(R)[2]
15-
function previous_region(R::RegionIterator)
16-
return R.which_region == 1 ? nothing : R.region_plan[R.which_region - 1][1]
57+
58+
function current_region(R::RegionIterator)
59+
region, _ = current_region_plan(R)
60+
return region
1761
end
18-
function next_region(R::RegionIterator)
19-
return if R.which_region == length(R.region_plan)
20-
nothing
21-
else
22-
R.region_plan[R.which_region + 1][1]
23-
end
62+
63+
function current_region_kwargs(R::RegionIterator)
64+
_, kwargs = current_region_plan(R)
65+
return kwargs
2466
end
25-
is_last_region(R::RegionIterator) = isnothing(next_region(R))
2667

27-
function Base.iterate(R::RegionIterator, which=1)
28-
R.which_region = which
29-
region_plan_state = iterate(R.region_plan, which)
30-
isnothing(region_plan_state) && return nothing
31-
(current_region, region_kwargs), next = region_plan_state
32-
R.problem = region_step(problem(R), R; region_kwargs...)
33-
return R, next
68+
function previous_region(R::RegionIterator)
69+
state(R) <= 1 && return nothing
70+
prev, _ = R.region_plan[R.which_region - 1]
71+
return prev
72+
end
73+
74+
function next_region(R::RegionIterator)
75+
is_last_region(R) && return nothing
76+
next, _ = R.region_plan[R.which_region + 1]
77+
return next
3478
end
79+
is_last_region(R::RegionIterator) = length(R) === state(R)
3580

3681
#
3782
# Functions associated with RegionIterator
3883
#
3984

40-
function region_iterator(problem; sweep_kwargs...)
41-
return RegionIterator(; problem, region_plan=region_plan(problem; sweep_kwargs...))
85+
function compute!(R::RegionIterator)
86+
region_kwargs = current_region_kwargs(R)
87+
R.problem = region_step(R; region_kwargs...)
88+
return R
89+
end
90+
function increment!(R::RegionIterator)
91+
R.which_region += 1
92+
return R
93+
end
94+
95+
function RegionIterator(problem; sweep, sweep_kwargs...)
96+
plan = region_plan(problem; sweep, sweep_kwargs...)
97+
return RegionIterator(problem, plan, sweep)
4298
end
4399

44100
function region_step(
45-
problem,
46-
region_iterator;
47-
extract_kwargs=(;),
48-
update_kwargs=(;),
49-
insert_kwargs=(;),
50-
sweep,
51-
kws...,
101+
region_iterator; extract_kwargs=(;), update_kwargs=(;), insert_kwargs=(;), kws...
52102
)
53-
problem, local_state = extract(problem, region_iterator; extract_kwargs..., sweep, kws...)
54-
problem, local_state = update(
55-
problem, local_state, region_iterator; update_kwargs..., kws...
56-
)
57-
problem = insert(problem, local_state, region_iterator; sweep, insert_kwargs..., kws...)
58-
return problem
103+
prob = problem(region_iterator)
104+
105+
sweep = region_iterator.sweep
106+
107+
prob, local_state = extract(prob, region_iterator; extract_kwargs..., sweep, kws...)
108+
prob, local_state = update(prob, local_state, region_iterator; update_kwargs..., kws...)
109+
prob = insert(prob, local_state, region_iterator; sweep, insert_kwargs..., kws...)
110+
return prob
59111
end
60112

61113
function region_plan(problem; kws...)
@@ -66,39 +118,41 @@ end
66118
# SweepIterator
67119
#
68120

69-
mutable struct SweepIterator{Problem}
121+
mutable struct SweepIterator{Problem} <: AbstractNetworkIterator
70122
sweep_kws
71123
region_iter::RegionIterator{Problem}
72124
which_sweep::Int
125+
function SweepIterator(problem, sweep_kws)
126+
sweep_kws = Iterators.Stateful(sweep_kws)
127+
first_kwargs, _ = Iterators.peel(sweep_kws)
128+
region_iter = RegionIterator(problem; sweep=1, first_kwargs...)
129+
return new{typeof(problem)}(sweep_kws, region_iter, 1)
130+
end
73131
end
74132

75-
problem(S::SweepIterator) = problem(S.region_iter)
76-
77-
Base.length(S::SweepIterator) = length(S.sweep_kws)
133+
done(SR::SweepIterator) = isnothing(peek(SR.sweep_kws))
78134

79-
function Base.iterate(S::SweepIterator, which=nothing)
80-
if isnothing(which)
81-
sweep_kws_state = iterate(S.sweep_kws)
82-
else
83-
sweep_kws_state = iterate(S.sweep_kws, which)
84-
end
85-
isnothing(sweep_kws_state) && return nothing
86-
current_sweep_kws, next = sweep_kws_state
135+
region_iterator(S::SweepIterator) = S.region_iter
136+
problem(S::SweepIterator) = problem(region_iterator(S))
87137

88-
if !isnothing(which)
89-
S.region_iter = region_iterator(
90-
problem(S.region_iter); sweep=S.which_sweep, current_sweep_kws...
91-
)
92-
end
93-
S.which_sweep += 1
94-
return S.region_iter, next
138+
state(SR::SweepIterator) = SR.which_sweep
139+
Base.length(S::SweepIterator) = length(S.sweep_kws)
140+
function increment!(SR::SweepIterator)
141+
SR.which_sweep += 1
142+
sweep_kwargs, _ = Iterators.peel(SR.sweep_kws)
143+
SR.region_iter = RegionIterator(problem(SR); sweep=state(SR), sweep_kwargs...)
144+
return SR
95145
end
96146

97-
function sweep_iterator(problem, sweep_kws)
98-
region_iter = region_iterator(problem; sweep=1, first(sweep_kws)...)
99-
return SweepIterator(sweep_kws, region_iter, 1)
147+
function compute!(SR::SweepIterator)
148+
for _ in SR.region_iter
149+
# TODO: Is it sensible to execute the default region callback function?
150+
end
100151
end
101152

102-
function sweep_iterator(problem, nsweeps::Integer; sweep_kws...)
103-
return sweep_iterator(problem, Iterators.repeated(sweep_kws, nsweeps))
153+
# More basic constructor where sweep_kwargs are constant throughout sweeps
154+
function SweepIterator(problem, nsweeps::Int; sweep_kwargs...)
155+
# Initialize this to an empty RegionIterator
156+
sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
157+
return SweepIterator(problem, sweep_kwargs_iter)
104158
end

0 commit comments

Comments
 (0)