Skip to content

Commit 6199696

Browse files
committed
Merge pull request #77 from torbjoernk/feature/sweeper-tweaks
LG
2 parents fd8d2e2 + 8d51457 commit 6199696

File tree

4 files changed

+123
-105
lines changed

4 files changed

+123
-105
lines changed

examples/advection_diffusion/advection_diffusion_sweeper.hpp

Lines changed: 54 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -73,60 +73,6 @@ class AdvectionDiffusionSweeper
7373
}
7474
//! @}
7575

76-
//! @{
77-
void exact(shared_ptr<Encapsulation<time>> q, time t)
78-
{
79-
this->exact(as_vector<double, time>(q), t);
80-
}
81-
82-
void exact(DVectorT& q, time t)
83-
{
84-
size_t n = q.size();
85-
double a = 1.0 / sqrt(4 * PI * nu * (t + t0));
86-
87-
for (size_t i = 0; i < n; i++) {
88-
q[i] = 0.0;
89-
}
90-
91-
for (int ii = -2; ii < 3; ii++) {
92-
for (size_t i = 0; i < n; i++) {
93-
double x = double(i) / n - 0.5 + ii - t * v;
94-
q[i] += a * exp(-x * x / (4 * nu * (t + t0)));
95-
}
96-
}
97-
}
98-
99-
void echo_error(time t, bool predict = false)
100-
{
101-
auto& qend = as_vector<double, time>(this->get_end_state());
102-
DVectorT qex(qend.size());
103-
104-
this->exact(qex, t);
105-
106-
double max = 0.0;
107-
for (size_t i = 0; i < qend.size(); i++) {
108-
double d = abs(qend[i] - qex[i]);
109-
if (d > max) { max = d; }
110-
}
111-
112-
auto n = this->get_controller()->get_step();
113-
auto k = this->get_controller()->get_iteration();
114-
cout << "err: " << n << " " << k << " " << scientific << max
115-
<< " (" << qend.size() << ", " << predict << ")"
116-
<< endl;
117-
118-
this->errors.insert(pair<pair<size_t, size_t>, double>(pair<size_t, size_t>(n, k), max));
119-
}
120-
121-
/**
122-
* retrieve errors at iterations and time nodes
123-
*/
124-
error_map get_errors()
125-
{
126-
return this->errors;
127-
}
128-
//! @}
129-
13076
//! @{
13177
/**
13278
* @copybrief pfasst::encap::IMEXSweeper::predict()
@@ -220,6 +166,60 @@ class AdvectionDiffusionSweeper
220166
}
221167
}
222168
//! @}
169+
170+
//! @{
171+
void exact(shared_ptr<Encapsulation<time>> q, time t)
172+
{
173+
this->exact(as_vector<double, time>(q), t);
174+
}
175+
176+
void exact(DVectorT& q, time t)
177+
{
178+
size_t n = q.size();
179+
double a = 1.0 / sqrt(4 * PI * nu * (t + t0));
180+
181+
for (size_t i = 0; i < n; i++) {
182+
q[i] = 0.0;
183+
}
184+
185+
for (int ii = -2; ii < 3; ii++) {
186+
for (size_t i = 0; i < n; i++) {
187+
double x = double(i) / n - 0.5 + ii - t * v;
188+
q[i] += a * exp(-x * x / (4 * nu * (t + t0)));
189+
}
190+
}
191+
}
192+
193+
void echo_error(time t, bool predict = false)
194+
{
195+
auto& qend = as_vector<double, time>(this->get_end_state());
196+
DVectorT qex(qend.size());
197+
198+
this->exact(qex, t);
199+
200+
double max = 0.0;
201+
for (size_t i = 0; i < qend.size(); i++) {
202+
double d = abs(qend[i] - qex[i]);
203+
if (d > max) { max = d; }
204+
}
205+
206+
auto n = this->get_controller()->get_step();
207+
auto k = this->get_controller()->get_iteration();
208+
cout << "err: " << n << " " << k << " " << scientific << max
209+
<< " (" << qend.size() << ", " << predict << ")"
210+
<< endl;
211+
212+
this->errors.insert(pair<pair<size_t, size_t>, double>(pair<size_t, size_t>(n, k), max));
213+
}
214+
215+
/**
216+
* retrieve errors at iterations and time nodes
217+
*/
218+
error_map get_errors()
219+
{
220+
return this->errors;
221+
}
222+
//! @}
223223
};
224224

225225
#endif

include/pfasst/encap/encap_sweeper.hpp

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,17 @@ namespace pfasst
2323
class EncapSweeper
2424
: public ISweeper<time>
2525
{
26+
public:
27+
//! @{
28+
typedef Encapsulation<time> encap_type;
29+
typedef EncapFactory<time> factory_type;
30+
//! @}
31+
32+
private:
2633
//! @{
2734
vector<time> nodes;
2835
vector<bool> is_proper;
29-
shared_ptr<EncapFactory<time>> factory;
36+
shared_ptr<factory_type> factory;
3037
//! @}
3138

3239
public:
@@ -36,46 +43,14 @@ namespace pfasst
3643
//! @}
3744

3845
//! @{
39-
virtual void spread() override
40-
{
41-
for (size_t m = 1; m < nodes.size(); m++) {
42-
this->get_state(m)->copy(this->get_state(0));
43-
}
44-
}
45-
46-
virtual void post(ICommunicator* comm, int tag) override
47-
{
48-
this->get_state(0)->post(comm, tag);
49-
}
50-
51-
virtual void send(ICommunicator* comm, int tag, bool blocking) override
52-
{
53-
this->get_state(this->get_nodes().size() - 1)->send(comm, tag, blocking);
54-
}
55-
56-
virtual void recv(ICommunicator* comm, int tag, bool blocking) override
57-
{
58-
this->get_state(0)->recv(comm, tag, blocking);
59-
}
60-
61-
virtual void broadcast(ICommunicator* comm) override
62-
{
63-
if (comm->rank() == comm->size() - 1) {
64-
this->get_state(0)->copy(this->get_state(this->get_nodes().size() - 1));
65-
}
66-
this->get_state(0)->broadcast(comm);
67-
}
68-
//! @}
69-
70-
//! @{
71-
void set_nodes(vector<time> nodes)
46+
virtual void set_nodes(vector<time> nodes)
7247
{
7348
auto augmented = pfasst::augment_nodes(nodes);
7449
this->nodes = get<0>(augmented);
7550
this->is_proper = get<1>(augmented);
7651
}
7752

78-
const vector<time> get_nodes() const
53+
virtual const vector<time> get_nodes() const
7954
{
8055
return nodes;
8156
}
@@ -85,12 +60,12 @@ namespace pfasst
8560
return is_proper;
8661
}
8762

88-
void set_factory(shared_ptr<EncapFactory<time>> factory)
63+
virtual void set_factory(shared_ptr<factory_type> factory)
8964
{
90-
this->factory = shared_ptr<EncapFactory<time>>(factory);
65+
this->factory = factory;
9166
}
9267

93-
shared_ptr<EncapFactory<time>> get_factory() const
68+
virtual shared_ptr<factory_type> get_factory() const
9469
{
9570
return factory;
9671
}
@@ -103,7 +78,7 @@ namespace pfasst
10378
*
10479
* @note This method must be implemented in derived sweepers.
10580
*/
106-
virtual void set_state(shared_ptr<const Encapsulation<time>> u0, size_t m)
81+
virtual void set_state(shared_ptr<const encap_type> u0, size_t m)
10782
{
10883
UNUSED(u0); UNUSED(m);
10984
throw NotImplementedYet("sweeper");
@@ -168,6 +143,13 @@ namespace pfasst
168143
throw NotImplementedYet("sweeper");
169144
}
170145

146+
virtual void spread() override
147+
{
148+
for (size_t m = 1; m < nodes.size(); m++) {
149+
this->get_state(m)->copy(this->get_state(0));
150+
}
151+
}
152+
171153
/**
172154
* evaluates the right hand side at given time node
173155
*
@@ -193,12 +175,37 @@ namespace pfasst
193175
*
194176
* @note This method must be implemented in derived sweepers.
195177
*/
196-
virtual void integrate(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const
178+
virtual void integrate(time dt, vector<shared_ptr<encap_type>> dst) const
197179
{
198180
UNUSED(dt); UNUSED(dst);
199181
throw NotImplementedYet("sweeper");
200182
}
201183
//! @}
184+
185+
//! @{
186+
virtual void post(ICommunicator* comm, int tag) override
187+
{
188+
this->get_state(0)->post(comm, tag);
189+
}
190+
191+
virtual void send(ICommunicator* comm, int tag, bool blocking) override
192+
{
193+
this->get_state(this->get_nodes().size() - 1)->send(comm, tag, blocking);
194+
}
195+
196+
virtual void recv(ICommunicator* comm, int tag, bool blocking) override
197+
{
198+
this->get_state(0)->recv(comm, tag, blocking);
199+
}
200+
201+
virtual void broadcast(ICommunicator* comm) override
202+
{
203+
if (comm->rank() == comm->size() - 1) {
204+
this->get_state(0)->copy(this->get_state(this->get_nodes().size() - 1));
205+
}
206+
this->get_state(0)->broadcast(comm);
207+
}
208+
//! @}
202209
};
203210

204211

include/pfasst/encap/imex_sweeper.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,16 @@ namespace pfasst
233233
}
234234
}
235235

236-
virtual void predict(bool initial)
236+
/**
237+
* Compute low-order provisional solution.
238+
*
239+
* This does not simply copy the initial value to all time nodes but carries out a few
240+
* forward/backward IMEX Euler steps between the nodes.
241+
*
242+
* @param[in] initial if `true` the explicit and implicit part of the right hand side of the
243+
* ODE get evaluated with the initial value
244+
*/
245+
virtual void predict(bool initial) override
237246
{
238247
const auto nodes = this->get_nodes();
239248
const size_t nnodes = nodes.size();
@@ -259,7 +268,7 @@ namespace pfasst
259268
if (this->last_node_is_virtual()) { this->integrate_end_state(dt); }
260269
}
261270

262-
virtual void sweep()
271+
virtual void sweep() override
263272
{
264273
const auto nodes = this->get_nodes();
265274
const size_t nnodes = nodes.size();

include/pfasst/interfaces.hpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ namespace pfasst
3131
* Used by PFASST to mark methods that are required for a particular algorithm (SDC/MLSDC/PFASST)
3232
* that may not be necessary for all others.
3333
*/
34-
class NotImplementedYet : public exception
34+
class NotImplementedYet
35+
: public exception
3536
{
3637
string msg;
3738
public:
@@ -51,7 +52,8 @@ namespace pfasst
5152
*
5253
* Thrown when a PFASST routine is passed an invalid value.
5354
*/
54-
class ValueError : public exception
55+
class ValueError
56+
: public exception
5557
{
5658
string msg;
5759
public:
@@ -159,12 +161,12 @@ namespace pfasst
159161
virtual void save(bool initial_only=false)
160162
{
161163
UNUSED(initial_only);
162-
NotImplementedYet("mlsdc/pfasst");
164+
throw NotImplementedYet("mlsdc/pfasst");
163165
}
164166

165167
virtual void spread()
166168
{
167-
NotImplementedYet("pfasst");
169+
throw NotImplementedYet("pfasst");
168170
}
169171
//! @}
170172

@@ -177,19 +179,19 @@ namespace pfasst
177179
virtual void send(ICommunicator* comm, int tag, bool blocking)
178180
{
179181
UNUSED(comm); UNUSED(tag); UNUSED(blocking);
180-
NotImplementedYet("pfasst");
182+
throw NotImplementedYet("pfasst");
181183
}
182184

183185
virtual void recv(ICommunicator* comm, int tag, bool blocking)
184186
{
185187
UNUSED(comm); UNUSED(tag); UNUSED(blocking);
186-
NotImplementedYet("pfasst");
188+
throw NotImplementedYet("pfasst");
187189
}
188190

189191
virtual void broadcast(ICommunicator* comm)
190192
{
191193
UNUSED(comm);
192-
NotImplementedYet("pfasst");
194+
throw NotImplementedYet("pfasst");
193195
}
194196
//! @}
195197

@@ -217,7 +219,7 @@ namespace pfasst
217219
shared_ptr<const ISweeper<time>> src)
218220
{
219221
UNUSED(dst); UNUSED(src);
220-
NotImplementedYet("pfasst");
222+
throw NotImplementedYet("pfasst");
221223
}
222224

223225
/**
@@ -238,7 +240,7 @@ namespace pfasst
238240
shared_ptr<const ISweeper<time>> src)
239241
{
240242
UNUSED(dst); UNUSED(src);
241-
NotImplementedYet("pfasst");
243+
throw NotImplementedYet("pfasst");
242244
}
243245

244246

0 commit comments

Comments
 (0)