Skip to content

Commit 2765254

Browse files
committed
controller: Change step/iteration state management.
Rework tracking of current time, step, and iteration. Signed-off-by: Matthew Emmett <[email protected]>
1 parent 6e62617 commit 2765254

File tree

7 files changed

+51
-37
lines changed

7 files changed

+51
-37
lines changed

examples/advection_diffusion/serial_mlsdc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ error_map run_serial_mlsdc()
7070
/*
7171
* run mlsdc!
7272
*/
73-
mlsdc.set_duration(dt, nsteps, niters);
73+
mlsdc.set_duration(0.0, nsteps*dt, dt, niters);
7474
mlsdc.run();
7575

7676
fftw_cleanup();

examples/advection_diffusion/serial_mlsdc_autobuild.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ int main(int /*argc*/, char** /*argv*/)
6565

6666
auto_build(mlsdc, nodes, build_level);
6767
auto_setup(mlsdc, initial);
68-
mlsdc.set_duration(dt, nsteps, niters);
68+
mlsdc.set_duration(0.0, nsteps*dt, dt, niters);
6969
mlsdc.run();
7070

7171
fftw_cleanup();

examples/advection_diffusion/vanilla_sdc.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ error_map run_vanilla_sdc()
3232
sweeper->set_factory(factory);
3333

3434
sdc.add_level(sweeper);
35-
sdc.set_duration(dt, nsteps, niters);
35+
sdc.set_duration(0.0, nsteps*dt, dt, niters);
3636
sdc.setup();
3737

3838
auto q0 = sweeper->get_state(0);
@@ -47,7 +47,7 @@ error_map run_vanilla_sdc()
4747

4848

4949
#ifndef PFASST_UNIT_TESTING
50-
int main(int argc, char** argv)
50+
int main(int /*argc*/, char** /*argv*/)
5151
{
5252
run_vanilla_sdc();
5353
}

include/pfasst/controller.hpp

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ namespace pfasst
2626
deque<shared_ptr<ISweeper<time>>> levels;
2727
deque<shared_ptr<ITransfer<time>>> transfer;
2828

29-
time dt;
29+
int step, iteration, max_iterations;
30+
time t, dt, tend;
3031

3132
public:
3233
//! @{
@@ -38,11 +39,15 @@ namespace pfasst
3839
}
3940
}
4041

41-
void set_duration(time dt, size_t nsteps, size_t niters)
42+
// XXX
43+
void set_duration(time t0, time tend, time dt, int niters)
4244
{
45+
this->t = t0;
46+
this->tend = tend;
4347
this->dt = dt;
44-
steps.set_size(nsteps);
45-
iterations.set_size(niters);
48+
this->step = 0;
49+
this->iteration = 0;
50+
this->max_iterations = niters;
4651
}
4752

4853
void add_level(shared_ptr<ISweeper<time>> swpr,
@@ -176,36 +181,52 @@ namespace pfasst
176181

177182

178183
/**
179-
* Simple range iterator.
184+
* Get current time step number.
180185
*/
181-
class RangeIter {
182-
friend Controller;
183-
size_t i, n;
184-
public:
185-
void set_size(size_t n) { this->n = n; }
186-
void reset(size_t i = 0) { this->i = i; }
187-
bool valid() { return i < n; }
188-
void next() { i++; }
189-
} steps, iterations;
190-
191-
size_t get_step()
186+
int get_step()
192187
{
193-
return steps.i;
188+
return step;
194189
}
195190

196-
size_t get_iteration()
191+
time get_time_step()
197192
{
198-
return iterations.i;
193+
return dt;
199194
}
200195

201196
time get_time()
202197
{
203-
return get_step() * get_time_step();
198+
return t;
204199
}
205200

206-
time get_time_step()
201+
void advance_time(int nsteps=1)
207202
{
208-
return dt;
203+
step += nsteps;
204+
t += nsteps*dt;
205+
}
206+
207+
time get_end_time()
208+
{
209+
return tend;
210+
}
211+
212+
int get_iteration()
213+
{
214+
return iteration;
215+
}
216+
217+
void set_iteration(int iter)
218+
{
219+
this->iteration = iter;
220+
}
221+
222+
void advance_iteration()
223+
{
224+
iteration++;
225+
}
226+
227+
int get_max_iteration()
228+
{
229+
return max_iterations;
209230
}
210231

211232
};

include/pfasst/mlsdc.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,12 @@ namespace pfasst
5858
*/
5959
void run()
6060
{
61-
auto& steps = this->steps;
62-
auto& iters = this->iterations;
63-
64-
for (steps.reset(); steps.valid(); steps.next()) {
61+
for (; this->get_time() < this->get_end_time(); this->advance_time()) {
6562
predict = true; // use predictor for first fine sweep of each step
6663
initial = this->get_step() == 0; // only evaluate node 0 functions on first step
6764

6865
// iterate by performing v-cycles
69-
for (iters.reset(); iters.valid(); iters.next()) {
66+
for (this->set_iteration(0); this->get_iteration() < this->get_max_iteration(); this->advance_iteration()) {
7067
cycle_v(this->finest());
7168
}
7269

include/pfasst/sdc.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@ namespace pfasst
2222
void run()
2323
{
2424
auto sweeper = this->get_level(0);
25-
auto& steps = this->steps;
26-
auto& iters = this->iterations;
2725

28-
for (steps.reset(); steps.valid(); steps.next()) {
26+
for (; this->get_time() < this->get_end_time(); this->advance_time()) {
2927
bool initial = this->get_step() == 0;
30-
for (iters.reset(); iters.valid(); iters.next()) {
28+
for (this->set_iteration(0); this->get_iteration() < this->get_max_iteration(); this->advance_iteration()) {
3129
bool predict = this->get_iteration() == 0;
3230
if (predict) {
3331
sweeper->predict(initial);

tests/test-advection-diffusion.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ TEST(ErrorTest, VanillaSDC)
3434

3535
auto errors = run_vanilla_sdc();
3636
auto get_iter = [](const vtype x) { return get<1>(get<0>(x)); };
37-
auto get_step = [](const vtype x) { return get<0>(get<0>(x)); };
3837
auto get_error = [](const vtype x) { return get<1>(x); };
3938

4039
auto max_iter = get_iter(*std::max_element(errors.begin(), errors.end(),
@@ -57,7 +56,6 @@ TEST(ErrorTest, SerialMLSDC)
5756

5857
auto errors = run_serial_mlsdc();
5958
auto get_iter = [](const vtype x) { return get<1>(get<0>(x)); };
60-
auto get_step = [](const vtype x) { return get<0>(get<0>(x)); };
6159
auto get_error = [](const vtype x) { return get<1>(x); };
6260

6361
auto max_iter = get_iter(*std::max_element(errors.begin(), errors.end(),

0 commit comments

Comments
 (0)