|
| 1 | +""" |
| 2 | + abstract type AbstractNetworkIterator |
| 3 | +
|
| 4 | +A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins |
| 5 | +with a call to `increment!` before executing `compute!`, however the initial call to |
| 6 | +`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that |
| 7 | +this call is implict. Termination of the iterator is controlled by the function `done`. |
| 8 | +""" |
| 9 | +abstract type AbstractNetworkIterator end |
| 10 | + |
| 11 | +# We use greater than or equals here as we increment the state at the start of the iteration |
| 12 | +done(NI::AbstractNetworkIterator) = state(NI) >= length(NI) |
| 13 | + |
| 14 | +function Base.iterate(NI::AbstractNetworkIterator, init=true) |
| 15 | + done(NI) && return nothing |
| 16 | + # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* |
| 17 | + # define a method for increment! This way we avoid cases where one may wish to nest |
| 18 | + # calls to different step! methods accidentaly incrementing multiple times. |
| 19 | + init || increment!(NI) |
| 20 | + rv = compute!(NI) |
| 21 | + return rv, false |
| 22 | +end |
| 23 | + |
| 24 | +function increment! end |
| 25 | +compute!(NI::AbstractNetworkIterator) = NI |
| 26 | + |
| 27 | +step!(NI::AbstractNetworkIterator) = step!(identity, NI) |
| 28 | +function step!(f, NI::AbstractNetworkIterator) |
| 29 | + compute!(NI) |
| 30 | + f(NI) |
| 31 | + increment!(NI) |
| 32 | + return NI |
| 33 | +end |
| 34 | + |
1 | 35 | # |
2 | 36 | # RegionIterator |
3 | 37 | # |
4 | | - |
5 | | -@kwdef mutable struct RegionIterator{Problem,RegionPlan} |
| 38 | +""" |
| 39 | + struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator |
| 40 | +""" |
| 41 | +mutable struct RegionIterator{Problem,RegionPlan} <: AbstractNetworkIterator |
6 | 42 | problem::Problem |
7 | 43 | region_plan::RegionPlan |
8 | | - which_region::Int = 1 |
| 44 | + const sweep::Int |
| 45 | + which_region::Int |
| 46 | + function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P,R} |
| 47 | + return new{P,R}(problem, region_plan, sweep, 1) |
| 48 | + end |
9 | 49 | end |
10 | 50 |
|
| 51 | +state(R::RegionIterator) = R.which_region |
| 52 | +Base.length(R::RegionIterator) = length(R.region_plan) |
| 53 | + |
11 | 54 | problem(R::RegionIterator) = R.problem |
| 55 | + |
12 | 56 | current_region_plan(R::RegionIterator) = R.region_plan[R.which_region] |
13 | | -current_region(R::RegionIterator) = current_region_plan(R)[1] |
14 | | -region_kwargs(R::RegionIterator) = current_region_plan(R)[2] |
15 | | -function previous_region(R::RegionIterator) |
16 | | - return R.which_region == 1 ? nothing : R.region_plan[R.which_region - 1][1] |
| 57 | + |
| 58 | +function current_region(R::RegionIterator) |
| 59 | + region, _ = current_region_plan(R) |
| 60 | + return region |
17 | 61 | end |
18 | | -function next_region(R::RegionIterator) |
19 | | - return if R.which_region == length(R.region_plan) |
20 | | - nothing |
21 | | - else |
22 | | - R.region_plan[R.which_region + 1][1] |
23 | | - end |
| 62 | + |
| 63 | +function current_region_kwargs(R::RegionIterator) |
| 64 | + _, kwargs = current_region_plan(R) |
| 65 | + return kwargs |
24 | 66 | end |
25 | | -is_last_region(R::RegionIterator) = isnothing(next_region(R)) |
26 | 67 |
|
27 | | -function Base.iterate(R::RegionIterator, which=1) |
28 | | - R.which_region = which |
29 | | - region_plan_state = iterate(R.region_plan, which) |
30 | | - isnothing(region_plan_state) && return nothing |
31 | | - (current_region, region_kwargs), next = region_plan_state |
32 | | - R.problem = region_step(problem(R), R; region_kwargs...) |
33 | | - return R, next |
| 68 | +function previous_region(R::RegionIterator) |
| 69 | + state(R) <= 1 && return nothing |
| 70 | + prev, _ = R.region_plan[R.which_region - 1] |
| 71 | + return prev |
| 72 | +end |
| 73 | + |
| 74 | +function next_region(R::RegionIterator) |
| 75 | + is_last_region(R) && return nothing |
| 76 | + next, _ = R.region_plan[R.which_region + 1] |
| 77 | + return next |
34 | 78 | end |
| 79 | +is_last_region(R::RegionIterator) = length(R) === state(R) |
35 | 80 |
|
36 | 81 | # |
37 | 82 | # Functions associated with RegionIterator |
38 | 83 | # |
39 | 84 |
|
40 | | -function region_iterator(problem; sweep_kwargs...) |
41 | | - return RegionIterator(; problem, region_plan=region_plan(problem; sweep_kwargs...)) |
| 85 | +function compute!(R::RegionIterator) |
| 86 | + region_kwargs = current_region_kwargs(R) |
| 87 | + R.problem = region_step(R; region_kwargs...) |
| 88 | + return R |
| 89 | +end |
| 90 | +function increment!(R::RegionIterator) |
| 91 | + R.which_region += 1 |
| 92 | + return R |
| 93 | +end |
| 94 | + |
| 95 | +function RegionIterator(problem; sweep, sweep_kwargs...) |
| 96 | + plan = region_plan(problem; sweep, sweep_kwargs...) |
| 97 | + return RegionIterator(problem, plan, sweep) |
42 | 98 | end |
43 | 99 |
|
44 | 100 | function region_step( |
45 | | - problem, |
46 | | - region_iterator; |
47 | | - extract_kwargs=(;), |
48 | | - update_kwargs=(;), |
49 | | - insert_kwargs=(;), |
50 | | - sweep, |
51 | | - kws..., |
| 101 | + region_iterator; extract_kwargs=(;), update_kwargs=(;), insert_kwargs=(;), kws... |
52 | 102 | ) |
53 | | - problem, local_state = extract(problem, region_iterator; extract_kwargs..., sweep, kws...) |
54 | | - problem, local_state = update( |
55 | | - problem, local_state, region_iterator; update_kwargs..., kws... |
56 | | - ) |
57 | | - problem = insert(problem, local_state, region_iterator; sweep, insert_kwargs..., kws...) |
58 | | - return problem |
| 103 | + prob = problem(region_iterator) |
| 104 | + |
| 105 | + sweep = region_iterator.sweep |
| 106 | + |
| 107 | + prob, local_state = extract(prob, region_iterator; extract_kwargs..., sweep, kws...) |
| 108 | + prob, local_state = update(prob, local_state, region_iterator; update_kwargs..., kws...) |
| 109 | + prob = insert(prob, local_state, region_iterator; sweep, insert_kwargs..., kws...) |
| 110 | + return prob |
59 | 111 | end |
60 | 112 |
|
61 | 113 | function region_plan(problem; kws...) |
|
66 | 118 | # SweepIterator |
67 | 119 | # |
68 | 120 |
|
69 | | -mutable struct SweepIterator{Problem} |
| 121 | +mutable struct SweepIterator{Problem} <: AbstractNetworkIterator |
70 | 122 | sweep_kws |
71 | 123 | region_iter::RegionIterator{Problem} |
72 | 124 | which_sweep::Int |
| 125 | + function SweepIterator(problem, sweep_kws) |
| 126 | + sweep_kws = Iterators.Stateful(sweep_kws) |
| 127 | + first_kwargs, _ = Iterators.peel(sweep_kws) |
| 128 | + region_iter = RegionIterator(problem; sweep=1, first_kwargs...) |
| 129 | + return new{typeof(problem)}(sweep_kws, region_iter, 1) |
| 130 | + end |
73 | 131 | end |
74 | 132 |
|
75 | | -problem(S::SweepIterator) = problem(S.region_iter) |
76 | | - |
77 | | -Base.length(S::SweepIterator) = length(S.sweep_kws) |
| 133 | +done(SR::SweepIterator) = isnothing(peek(SR.sweep_kws)) |
78 | 134 |
|
79 | | -function Base.iterate(S::SweepIterator, which=nothing) |
80 | | - if isnothing(which) |
81 | | - sweep_kws_state = iterate(S.sweep_kws) |
82 | | - else |
83 | | - sweep_kws_state = iterate(S.sweep_kws, which) |
84 | | - end |
85 | | - isnothing(sweep_kws_state) && return nothing |
86 | | - current_sweep_kws, next = sweep_kws_state |
| 135 | +region_iterator(S::SweepIterator) = S.region_iter |
| 136 | +problem(S::SweepIterator) = problem(region_iterator(S)) |
87 | 137 |
|
88 | | - if !isnothing(which) |
89 | | - S.region_iter = region_iterator( |
90 | | - problem(S.region_iter); sweep=S.which_sweep, current_sweep_kws... |
91 | | - ) |
92 | | - end |
93 | | - S.which_sweep += 1 |
94 | | - return S.region_iter, next |
| 138 | +state(SR::SweepIterator) = SR.which_sweep |
| 139 | +Base.length(S::SweepIterator) = length(S.sweep_kws) |
| 140 | +function increment!(SR::SweepIterator) |
| 141 | + SR.which_sweep += 1 |
| 142 | + sweep_kwargs, _ = Iterators.peel(SR.sweep_kws) |
| 143 | + SR.region_iter = RegionIterator(problem(SR); sweep=state(SR), sweep_kwargs...) |
| 144 | + return SR |
95 | 145 | end |
96 | 146 |
|
97 | | -function sweep_iterator(problem, sweep_kws) |
98 | | - region_iter = region_iterator(problem; sweep=1, first(sweep_kws)...) |
99 | | - return SweepIterator(sweep_kws, region_iter, 1) |
| 147 | +function compute!(SR::SweepIterator) |
| 148 | + for _ in SR.region_iter |
| 149 | + # TODO: Is it sensible to execute the default region callback function? |
| 150 | + end |
100 | 151 | end |
101 | 152 |
|
102 | | -function sweep_iterator(problem, nsweeps::Integer; sweep_kws...) |
103 | | - return sweep_iterator(problem, Iterators.repeated(sweep_kws, nsweeps)) |
| 153 | +# More basic constructor where sweep_kwargs are constant throughout sweeps |
| 154 | +function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) |
| 155 | + # Initialize this to an empty RegionIterator |
| 156 | + sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) |
| 157 | + return SweepIterator(problem, sweep_kwargs_iter) |
104 | 158 | end |
0 commit comments