Skip to content

Commit 568c631

Browse files
author
Jack Dunham
committed
Some minor refactoring of the iterators.
- Reordered the struct fields to be consistant with each other - Some field and function renames
1 parent 20bf783 commit 568c631

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

src/solvers/iterators.jl

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,22 @@ end
4141
mutable struct RegionIterator{Problem,RegionPlan} <: AbstractNetworkIterator
4242
problem::Problem
4343
region_plan::RegionPlan
44-
const sweep::Int
4544
which_region::Int
45+
const which_sweep::Int
4646
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P,R}
47-
return new{P,R}(problem, region_plan, sweep, 1)
47+
return new{P,R}(problem, region_plan, 1, sweep)
4848
end
4949
end
5050

51+
function RegionIterator(problem; sweep, sweep_kwargs...)
52+
plan = region_plan(problem; sweep_kwargs...)
53+
return RegionIterator(problem, plan, sweep)
54+
end
55+
56+
function new_region_iterator(iterator::RegionIterator; sweep_kwargs...)
57+
return RegionIterator(iterator.problem; sweep_kwargs...)
58+
end
59+
5160
state(region_iter::RegionIterator) = region_iter.which_region
5261
Base.length(region_iter::RegionIterator) = length(region_iter.region_plan)
5362

@@ -74,11 +83,10 @@ function prev_region(region_iter::RegionIterator)
7483
end
7584

7685
function next_region(region_iter::RegionIterator)
77-
is_last_region(region_iter) && return nothing
86+
laststep(region_iter) && return nothing
7887
next, _ = region_iter.region_plan[region_iter.which_region + 1]
7988
return next
8089
end
81-
is_last_region(region_iter::RegionIterator) = length(region_iter) === state(region_iter)
8290

8391
#
8492
# Functions associated with RegionIterator
@@ -96,45 +104,44 @@ function compute!(iter::RegionIterator)
96104
return iter
97105
end
98106

99-
function RegionIterator(problem; sweep, sweep_kwargs...)
100-
plan = region_plan(problem; sweep, sweep_kwargs...)
101-
return RegionIterator(problem, plan, sweep)
102-
end
103-
104107
region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...)
105108

106109
#
107110
# SweepIterator
108111
#
109112

110-
mutable struct SweepIterator{Problem} <: AbstractNetworkIterator
111-
sweep_kws
113+
mutable struct SweepIterator{Problem,Iter} <: AbstractNetworkIterator
112114
region_iter::RegionIterator{Problem}
115+
sweep_kwargs::Iterators.Stateful{Iter}
113116
which_sweep::Int
114-
function SweepIterator(problem, sweep_kws)
115-
sweep_kws = Iterators.Stateful(sweep_kws)
116-
first_kwargs, _ = Iterators.peel(sweep_kws)
117+
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob,Iter}
118+
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)
119+
first_kwargs, _ = Iterators.peel(stateful_sweep_kwargs)
117120
region_iter = RegionIterator(problem; sweep=1, first_kwargs...)
118-
return new{typeof(problem)}(sweep_kws, region_iter, 1)
121+
return new{Prob,Iter}(region_iter, stateful_sweep_kwargs, 1)
119122
end
120123
end
121124

122-
laststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kws))
125+
laststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))
123126

124127
region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter
125128
problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))
126129

127130
state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
128-
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kws)
131+
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)
129132
function increment!(sweep_iter::SweepIterator)
130133
sweep_iter.which_sweep += 1
131-
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kws)
132-
sweep_iter.region_iter = RegionIterator(
133-
problem(sweep_iter); sweep=state(sweep_iter), sweep_kwargs...
134-
)
134+
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
135+
update_region_iterator!(sweep_iter; sweep_kwargs...)
135136
return sweep_iter
136137
end
137138

139+
function update_region_iterator!(iterator::SweepIterator; kwargs...)
140+
sweep = state(iterator)
141+
iterator.region_iter = new_region_iterator(iterator.region_iter; sweep, kwargs...)
142+
return iterator
143+
end
144+
138145
function compute!(sweep_iter::SweepIterator)
139146
for _ in sweep_iter.region_iter
140147
# TODO: Is it sensible to execute the default region callback function?

0 commit comments

Comments
 (0)