Skip to content

Commit 1c6d1c0

Browse files
committed
Merge pull request #37 from memmett/feature/state
Add unit tests for AdvectionDiffusion example and change how state is passed to sweepers.
2 parents 3f9df5f + e396b0c commit 1c6d1c0

File tree

11 files changed

+274
-77
lines changed

11 files changed

+274
-77
lines changed

examples/advection_diffusion/advection_diffusion_sweeper.hpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
#ifndef _ADVECTION_DIFFUSION_SWEEPER_HPP_
66
#define _ADVECTION_DIFFUSION_SWEEPER_HPP_
77

8-
#include <complex>
9-
#include <vector>
108
#include <cassert>
9+
#include <complex>
10+
#include <map>
1111
#include <ostream>
12+
#include <vector>
1213

1314
#include <pfasst/encap/imex_sweeper.hpp>
1415

@@ -23,13 +24,16 @@ using pfasst::encap::Encapsulation;
2324
using pfasst::encap::as_vector;
2425

2526

27+
typedef map<pair<size_t, size_t>, double> error_map;
28+
2629
template<typename time = pfasst::time_precision>
2730
class AdvectionDiffusionSweeper
2831
: public pfasst::encap::IMEXSweeper<time>
2932
{
3033
FFT fft;
3134

3235
vector<complex<double>> ddx, lap;
36+
error_map errors;
3337

3438
double v = 1.0;
3539
time t0 = 1.0;
@@ -87,20 +91,35 @@ class AdvectionDiffusionSweeper
8791
double d = abs(qend[i] - qex[i]);
8892
if (d > max) { max = d; }
8993
}
90-
cout << "err: " << scientific << max
94+
95+
auto n = this->get_controller()->get_step();
96+
auto k = this->get_controller()->get_iteration();
97+
cout << "err: " << n << " " << k << " " << scientific << max
9198
<< " (" << qend.size() << ", " << predict << ")"
9299
<< endl;
100+
101+
errors.insert(pair<pair<size_t, size_t>, double>
102+
(pair<size_t, size_t>(n, k), max));
103+
}
104+
105+
error_map get_errors()
106+
{
107+
return errors;
93108
}
94109

95-
void predict(time t, time dt, bool initial)
110+
void predict(bool initial)
96111
{
97-
pfasst::encap::IMEXSweeper<time>::predict(t, dt, initial);
112+
pfasst::encap::IMEXSweeper<time>::predict(initial);
113+
time t = this->get_controller()->get_time();
114+
time dt = this->get_controller()->get_time_step();
98115
echo_error(t + dt, true);
99116
}
100117

101-
void sweep(time t, time dt)
118+
void sweep()
102119
{
103-
pfasst::encap::IMEXSweeper<time>::sweep(t, dt);
120+
pfasst::encap::IMEXSweeper<time>::sweep();
121+
time t = this->get_controller()->get_time();
122+
time dt = this->get_controller()->get_time_step();
104123
echo_error(t + dt);
105124
}
106125

examples/advection_diffusion/serial_mlsdc.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
using namespace pfasst;
1919
using namespace pfasst::encap;
2020

21-
int main(int /*argc*/, char** /*argv*/)
21+
error_map run_serial_mlsdc()
2222
{
2323
MLSDC<> mlsdc;
2424

@@ -70,8 +70,17 @@ int main(int /*argc*/, char** /*argv*/)
7070
/*
7171
* run mlsdc!
7272
*/
73-
mlsdc.set_duration(dt, nsteps, niters);
73+
mlsdc.set_duration(0.0, nsteps*dt, dt, niters);
7474
mlsdc.run();
7575

7676
fftw_cleanup();
77+
78+
return sweeper->get_errors();
79+
}
80+
81+
#ifndef PFASST_UNIT_TESTING
82+
int main(int /*argc*/, char** /*argv*/)
83+
{
84+
run_serial_mlsdc();
7785
}
86+
#endif

examples/advection_diffusion/serial_mlsdc_autobuild.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ int main(int /*argc*/, char** /*argv*/)
6565

6666
auto_build(mlsdc, nodes, build_level);
6767
auto_setup(mlsdc, initial);
68-
mlsdc.set_duration(dt, nsteps, niters);
68+
mlsdc.set_duration(0.0, nsteps*dt, dt, niters);
6969
mlsdc.run();
7070

7171
fftw_cleanup();

examples/advection_diffusion/vanilla_sdc.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#include "advection_diffusion_sweeper.hpp"
1616

17-
int main(int /*argc*/, char** /*argv*/)
17+
error_map run_vanilla_sdc()
1818
{
1919
pfasst::SDC<> sdc;
2020

@@ -32,7 +32,7 @@ int main(int /*argc*/, char** /*argv*/)
3232
sweeper->set_factory(factory);
3333

3434
sdc.add_level(sweeper);
35-
sdc.set_duration(dt, nsteps, niters);
35+
sdc.set_duration(0.0, nsteps*dt, dt, niters);
3636
sdc.setup();
3737

3838
auto q0 = sweeper->get_state(0);
@@ -41,4 +41,14 @@ int main(int /*argc*/, char** /*argv*/)
4141
sdc.run();
4242

4343
fftw_cleanup();
44+
45+
return sweeper->get_errors();
46+
}
47+
48+
49+
#ifndef PFASST_UNIT_TESTING
50+
int main(int /*argc*/, char** /*argv*/)
51+
{
52+
run_vanilla_sdc();
4453
}
54+
#endif

include/pfasst/controller.hpp

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,30 @@ namespace pfasst
2626
deque<shared_ptr<ISweeper<time>>> levels;
2727
deque<shared_ptr<ITransfer<time>>> transfer;
2828

29-
size_t nsteps, niters;
30-
time dt;
29+
size_t step, iteration, max_iterations;
30+
time t, dt, tend;
3131

3232
public:
3333
//! @{
3434
void setup()
3535
{
3636
for (auto l = coarsest(); l <= finest(); ++l) {
37+
l.current()->set_controller(this);
3738
l.current()->setup();
3839
}
3940
}
4041

41-
void set_duration(time dt, size_t nsteps, size_t niters)
42+
void set_duration(time t0, time tend, time dt, size_t niters)
4243
{
44+
this->t = t0;
45+
this->tend = tend;
4346
this->dt = dt;
44-
this->nsteps = nsteps;
45-
this->niters = niters;
47+
this->step = 0;
48+
this->iteration = 0;
49+
this->max_iterations = niters;
4650
}
4751

48-
void add_level(shared_ptr<ISweeper<time>> swpr,
52+
void add_level(shared_ptr<ISweeper<time>> swpr,
4953
shared_ptr<ITransfer<time>> trnsfr = shared_ptr<ITransfer<time>>(nullptr),
5054
bool coarse = true)
5155
{
@@ -63,32 +67,26 @@ namespace pfasst
6367
template<typename R = ISweeper<time>>
6468
shared_ptr<R> get_level(size_t level)
6569
{
66-
shared_ptr<R> r = dynamic_pointer_cast<R>(levels[level]);
67-
assert(r);
70+
shared_ptr<R> r = dynamic_pointer_cast<R>(levels[level]); assert(r);
6871
return r;
6972
}
7073

7174
template<typename R = ISweeper<time>>
7275
shared_ptr<R> get_finest()
7376
{
74-
shared_ptr<R> r = dynamic_pointer_cast<R>(levels.back());
75-
assert(r);
76-
return r;
77+
return get_level<R>(nlevels()-1);
7778
}
7879

7980
template<typename R = ISweeper<time>>
8081
shared_ptr<R> get_coarsest()
8182
{
82-
shared_ptr<R> r = dynamic_pointer_cast<R>(levels.front());
83-
assert(r);
84-
return r;
83+
return get_level<R>(0);
8584
}
8685

8786
template<typename R = ITransfer<time>>
8887
shared_ptr<R> get_transfer(size_t level)
8988
{
90-
shared_ptr<R> r = dynamic_pointer_cast<R>(transfer[level]);
91-
assert(r);
89+
shared_ptr<R> r = dynamic_pointer_cast<R>(transfer[level]); assert(r);
9290
return r;
9391
}
9492

@@ -100,13 +98,13 @@ namespace pfasst
10098

10199
/**
102100
* level (MLSDC/PFASST) iterator.
103-
*
101+
*
104102
* This iterator is used to walk through the MLSDC/PFASST hierarchy of sweepers.
105-
* It keeps track of the _current_ level, and has convenience routines to return the
103+
* It keeps track of the _current_ level, and has convenience routines to return the
106104
* LevelIter::current(), LevelIter::fine() (i.e. `current+1`), and LevelIter::coarse()
107105
* (`current-1`) sweepers.
108-
*
109-
* Under the hood it satisfies the requirements of std::random_access_iterator_tag, thus
106+
*
107+
* Under the hood it satisfies the requirements of std::random_access_iterator_tag, thus
110108
* implementing a `RandomAccessIterator`.
111109
*/
112110
class LevelIter
@@ -179,6 +177,57 @@ namespace pfasst
179177
LevelIter finest() { return LevelIter(nlevels() - 1, this); }
180178
LevelIter coarsest() { return LevelIter(0, this); }
181179
//! @}
180+
181+
182+
/**
183+
* Get current time step number.
184+
*/
185+
size_t get_step()
186+
{
187+
return step;
188+
}
189+
190+
time get_time_step()
191+
{
192+
return dt;
193+
}
194+
195+
time get_time()
196+
{
197+
return t;
198+
}
199+
200+
void advance_time(size_t nsteps=1)
201+
{
202+
step += nsteps;
203+
t += nsteps*dt;
204+
}
205+
206+
time get_end_time()
207+
{
208+
return tend;
209+
}
210+
211+
size_t get_iteration()
212+
{
213+
return iteration;
214+
}
215+
216+
void set_iteration(size_t iter)
217+
{
218+
this->iteration = iter;
219+
}
220+
221+
void advance_iteration()
222+
{
223+
iteration++;
224+
}
225+
226+
size_t get_max_iterations()
227+
{
228+
return max_iterations;
229+
}
230+
182231
};
183232
}
184233

include/pfasst/encap/imex_sweeper.hpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,15 @@ namespace pfasst
9494
}
9595
}
9696

97-
virtual void sweep(time t0, time dt)
97+
virtual void sweep()
9898
{
9999
const auto nodes = this->get_nodes();
100100
const size_t nnodes = nodes.size();
101101
assert(nnodes >= 1);
102102

103+
time dt = this->get_controller()->get_time_step();
104+
time t = this->get_controller()->get_time();
105+
103106
// integrate
104107
S[0]->mat_apply(S, dt, SEmat, Fe, true);
105108
S[0]->mat_apply(S, dt, SImat, Fi, false);
@@ -112,7 +115,6 @@ namespace pfasst
112115
// sweep
113116
shared_ptr<Encapsulation<time>> rhs = this->get_factory()->create(pfasst::encap::solution);
114117

115-
time t = t0;
116118
for (size_t m = 0; m < nnodes - 1; m++) {
117119
time ds = dt * (nodes[m + 1] - nodes[m]);
118120

@@ -126,20 +128,22 @@ namespace pfasst
126128
}
127129
}
128130

129-
virtual void predict(time t0, time dt, bool initial)
131+
virtual void predict(bool initial)
130132
{
131133
const auto nodes = this->get_nodes();
132134
const size_t nnodes = nodes.size();
133135
assert(nnodes >= 1);
134136

137+
time dt = this->get_controller()->get_time_step();
138+
time t = this->get_controller()->get_time();
139+
135140
if (initial) {
136-
f1eval(Fe[0], Q[0], t0);
137-
f2eval(Fi[0], Q[0], t0);
141+
f1eval(Fe[0], Q[0], t);
142+
f2eval(Fi[0], Q[0], t);
138143
}
139144

140145
shared_ptr<Encapsulation<time>> rhs = this->get_factory()->create(pfasst::encap::solution);
141146

142-
time t = t0;
143147
for (size_t m = 0; m < nnodes - 1; m++) {
144148
time ds = dt * (nodes[m + 1] - nodes[m]);
145149
rhs->copy(Q[m]);
@@ -165,19 +169,19 @@ namespace pfasst
165169
f2eval(Fi[m], Q[m], t);
166170
}
167171

168-
virtual void f1eval(shared_ptr<Encapsulation<time>> /*q*/, shared_ptr<Encapsulation<time>> /*f*/,
172+
virtual void f1eval(shared_ptr<Encapsulation<time>> /*f*/, shared_ptr<Encapsulation<time>> /*q*/,
169173
time /*t*/)
170174
{
171175
throw NotImplementedYet("imex (f1eval)");
172176
}
173177

174-
virtual void f2eval(shared_ptr<Encapsulation<time>> /*q*/, shared_ptr<Encapsulation<time>> /*f*/,
178+
virtual void f2eval(shared_ptr<Encapsulation<time>> /*f*/, shared_ptr<Encapsulation<time>> /*q*/,
175179
time /*t*/)
176180
{
177181
throw NotImplementedYet("imex (f2eval)");
178182
}
179183

180-
virtual void f2comp(shared_ptr<Encapsulation<time>> /*q*/, shared_ptr<Encapsulation<time>> /*f*/,
184+
virtual void f2comp(shared_ptr<Encapsulation<time>> /*f*/, shared_ptr<Encapsulation<time>> /*q*/,
181185
time /*t*/, time /*dt*/,
182186
shared_ptr<Encapsulation<time>> /*rhs*/)
183187
{

0 commit comments

Comments
 (0)