|
| 1 | +""" |
| 2 | + abstract type AbstractNetworkIterator |
| 3 | +
|
| 4 | +A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins |
| 5 | +with a call to `increment!` before executing `compute!`, however the initial call to |
| 6 | +`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that |
| 7 | +this call is implict. Termination of the iterator is controlled by the function `done`. |
| 8 | +""" |
| 9 | +abstract type AbstractNetworkIterator end |
| 10 | + |
| 11 | +# We use greater than or equals here as we increment the state at the start of the iteration |
| 12 | +islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) |
| 13 | + |
| 14 | +function Base.iterate(iterator::AbstractNetworkIterator, init = true) |
| 15 | + # The assumption is that first "increment!" is implicit, therefore we must skip the |
| 16 | + # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not |
| 17 | + # defined when length < 1, |
| 18 | + init || islaststep(iterator) && return nothing |
| 19 | + # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* |
| 20 | + # define a method for increment! This way we avoid cases where one may wish to nest |
| 21 | + # calls to different step! methods accidentaly incrementing multiple times. |
| 22 | + init || increment!(iterator) |
| 23 | + rv = compute!(iterator) |
| 24 | + return rv, false |
| 25 | +end |
| 26 | + |
| 27 | +increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)})) |
| 28 | +compute!(iterator::AbstractNetworkIterator) = iterator |
| 29 | + |
| 30 | +step!(iterator::AbstractNetworkIterator) = step!(identity, iterator) |
| 31 | +function step!(f, iterator::AbstractNetworkIterator) |
| 32 | + compute!(iterator) |
| 33 | + f(iterator) |
| 34 | + increment!(iterator) |
| 35 | + return iterator |
| 36 | +end |
| 37 | + |
| 38 | +# |
| 39 | +# RegionIterator |
| 40 | +# |
| 41 | +""" |
| 42 | + struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator |
| 43 | +""" |
| 44 | +mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator |
| 45 | + problem::Problem |
| 46 | + region_plan::RegionPlan |
| 47 | + which_region::Int |
| 48 | + const which_sweep::Int |
| 49 | + function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} |
| 50 | + if isempty(region_plan) |
| 51 | + throw(ArgumentError("Cannot construct a region iterator with 0 elements.")) |
| 52 | + end |
| 53 | + return new{P, R}(problem, region_plan, 1, sweep) |
| 54 | + end |
| 55 | +end |
| 56 | + |
| 57 | +function RegionIterator(problem; sweep, sweep_kwargs...) |
| 58 | + plan = region_plan(problem; sweep_kwargs...) |
| 59 | + return RegionIterator(problem, plan, sweep) |
| 60 | +end |
| 61 | + |
| 62 | +state(region_iter::RegionIterator) = region_iter.which_region |
| 63 | +Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) |
| 64 | + |
| 65 | +problem(region_iter::RegionIterator) = region_iter.problem |
| 66 | + |
| 67 | +function current_region_plan(region_iter::RegionIterator) |
| 68 | + return region_iter.region_plan[region_iter.which_region] |
| 69 | +end |
| 70 | + |
| 71 | +function current_region(region_iter::RegionIterator) |
| 72 | + region, _ = current_region_plan(region_iter) |
| 73 | + return region |
| 74 | +end |
| 75 | + |
| 76 | +function region_kwargs(region_iter::RegionIterator) |
| 77 | + _, kwargs = current_region_plan(region_iter) |
| 78 | + return kwargs |
| 79 | +end |
| 80 | +function region_kwargs(f::Function, iter::RegionIterator) |
| 81 | + return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) |
| 82 | +end |
| 83 | + |
| 84 | +function prev_region(region_iter::RegionIterator) |
| 85 | + state(region_iter) <= 1 && return nothing |
| 86 | + prev, _ = region_iter.region_plan[region_iter.which_region - 1] |
| 87 | + return prev |
| 88 | +end |
| 89 | + |
| 90 | +function next_region(region_iter::RegionIterator) |
| 91 | + islaststep(region_iter) && return nothing |
| 92 | + next, _ = region_iter.region_plan[region_iter.which_region + 1] |
| 93 | + return next |
| 94 | +end |
| 95 | + |
| 96 | +# |
| 97 | +# Functions associated with RegionIterator |
| 98 | +# |
| 99 | +function increment!(region_iter::RegionIterator) |
| 100 | + region_iter.which_region += 1 |
| 101 | + return region_iter |
| 102 | +end |
| 103 | + |
| 104 | +function compute!(iter::RegionIterator) |
| 105 | + extract!(iter; region_kwargs(extract!, iter)...) |
| 106 | + update!(iter; region_kwargs(update!, iter)...) |
| 107 | + insert!(iter; region_kwargs(insert!, iter)...) |
| 108 | + |
| 109 | + return iter |
| 110 | +end |
| 111 | + |
| 112 | +region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) |
| 113 | + |
| 114 | +# |
| 115 | +# SweepIterator |
| 116 | +# |
| 117 | + |
| 118 | +mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator |
| 119 | + region_iter::RegionIterator{Problem} |
| 120 | + sweep_kwargs::Iterators.Stateful{Iter} |
| 121 | + which_sweep::Int |
| 122 | + function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} |
| 123 | + stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) |
| 124 | + |
| 125 | + first_state = Iterators.peel(stateful_sweep_kwargs) |
| 126 | + |
| 127 | + if isnothing(first_state) |
| 128 | + throw(ArgumentError("Cannot construct a sweep iterator with 0 elements.")) |
| 129 | + end |
| 130 | + |
| 131 | + first_kwargs, _ = first_state |
| 132 | + region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) |
| 133 | + |
| 134 | + return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) |
| 135 | + end |
| 136 | +end |
| 137 | + |
| 138 | +islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) |
| 139 | + |
| 140 | +region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter |
| 141 | +problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) |
| 142 | + |
| 143 | +state(sweep_iter::SweepIterator) = sweep_iter.which_sweep |
| 144 | +Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) |
| 145 | +function increment!(sweep_iter::SweepIterator) |
| 146 | + sweep_iter.which_sweep += 1 |
| 147 | + sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) |
| 148 | + update_region_iterator!(sweep_iter; sweep_kwargs...) |
| 149 | + return sweep_iter |
| 150 | +end |
| 151 | + |
| 152 | +function update_region_iterator!(iterator::SweepIterator; kwargs...) |
| 153 | + sweep = state(iterator) |
| 154 | + iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...) |
| 155 | + return iterator |
| 156 | +end |
| 157 | + |
| 158 | +function compute!(sweep_iter::SweepIterator) |
| 159 | + for _ in sweep_iter.region_iter |
| 160 | + # TODO: Is it sensible to execute the default region callback function? |
| 161 | + end |
| 162 | + return |
| 163 | +end |
| 164 | + |
| 165 | +# More basic constructor where sweep_kwargs are constant throughout sweeps |
| 166 | +function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) |
| 167 | + # Initialize this to an empty RegionIterator |
| 168 | + sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) |
| 169 | + return SweepIterator(problem, sweep_kwargs_iter) |
| 170 | +end |
0 commit comments