Skip to content

Commit 8b0fd93

Browse files
committed
internals: splitting up decls/defs for encapsulation related stuff
towards fixing #84
1 parent e4ae6fb commit 8b0fd93

15 files changed

+1260
-929
lines changed

examples/boris/boris_sweeper.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ namespace pfasst
178178
virtual void set_state(shared_ptr<const encap_type> u0, size_t m);
179179
virtual void set_start_state(shared_ptr<const encap_type> u0);
180180
virtual shared_ptr<Encapsulation<time>> get_state(size_t m) const override;
181-
virtual shared_ptr<encap_type> get_start_state() const;
181+
virtual shared_ptr<Encapsulation<time>> get_start_state() const;
182182
virtual shared_ptr<acceleration_type> get_tau_q_as_force(size_t m) const;
183183
virtual shared_ptr<acceleration_type> get_tau_qq_as_force(size_t m) const;
184184
virtual shared_ptr<Encapsulation<time>> get_saved_state(size_t m) const override;

examples/boris/boris_sweeper_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ namespace pfasst
269269
}
270270

271271
template<typename scalar, typename time>
272-
shared_ptr<typename BorisSweeper<scalar, time>::encap_type> BorisSweeper<scalar, time>::get_start_state() const
272+
shared_ptr<Encapsulation<time>> BorisSweeper<scalar, time>::get_start_state() const
273273
{
274274
return this->start_particles;
275275
}

examples/boris/injective_transfer.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ namespace pfasst
132132
auto fine = dynamic_pointer_cast<const BorisSweeper<scalar, time>>(src);
133133
assert(fine);
134134
CVLOG(5, "BorisTransfer") << "fine: " << fine->get_start_state();
135-
coarse->set_start_state(fine->get_start_state());
135+
coarse->set_start_state(dynamic_pointer_cast<typename BorisSweeper<scalar, time>::encap_type>(fine->get_start_state()));
136136
CVLOG(5, "BorisTransfer") << "restricted: " << coarse->get_start_state();
137137
}
138138

include/pfasst/encap/encap_sweeper.hpp

Lines changed: 30 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,15 @@
55
#ifndef _PFASST_ENCAP_ENCAP_SWEEPER_HPP_
66
#define _PFASST_ENCAP_ENCAP_SWEEPER_HPP_
77

8-
#include <algorithm>
9-
#include <cstdlib>
108
#include <memory>
119
#include <vector>
1210
using namespace std;
1311

14-
#include "../globals.hpp"
15-
#include "../config.hpp"
1612
#include "../interfaces.hpp"
1713
#include "../quadrature.hpp"
1814
#include "encapsulation.hpp"
1915

16+
2017
namespace pfasst
2118
{
2219
namespace encap
@@ -57,132 +54,53 @@ namespace pfasst
5754
time abs_residual_tol, rel_residual_tol;
5855

5956
public:
60-
61-
EncapSweeper()
62-
: quadrature(nullptr), abs_residual_tol(0.0), rel_residual_tol(0.0)
63-
{}
57+
EncapSweeper();
6458

6559
//! @{
6660
/**
6761
* Retrieve solution values of current iteration at time node index `m`.
6862
*
6963
* @param[in] m 0-based index of time node
7064
*/
71-
virtual shared_ptr<Encapsulation<time>> get_state(size_t m) const
72-
{
73-
return this->state[m];
74-
}
65+
virtual shared_ptr<Encapsulation<time>> get_state(size_t m) const;
7566

7667
/**
7768
* Retrieve FAS correction of current iteration at time node index `m`.
7869
*
7970
* @param[in] m 0-based index of time node
8071
*/
81-
virtual shared_ptr<Encapsulation<time>> get_tau(size_t m) const
82-
{
83-
return this->fas_corrections[m];
84-
}
72+
virtual shared_ptr<Encapsulation<time>> get_tau(size_t m) const;
8573

8674
/**
8775
* Retrieve solution values of previous iteration at time node index `m`.
8876
*
8977
* @param[in] m 0-based index of time node
9078
*/
91-
virtual shared_ptr<Encapsulation<time>> get_saved_state(size_t m) const
92-
{
93-
return this->saved_state[m];
94-
}
79+
virtual shared_ptr<Encapsulation<time>> get_saved_state(size_t m) const;
9580
//! @}
9681

9782
//! @{
98-
virtual void set_options() override
99-
{
100-
this->abs_residual_tol = time(config::get_value<double>("abs_res_tol", this->abs_residual_tol));
101-
this->rel_residual_tol = time(config::get_value<double>("rel_res_tol", this->rel_residual_tol));
102-
}
103-
104-
virtual void setup(bool coarse) override
105-
{
106-
auto const nodes = this->quadrature->get_nodes();
107-
auto const num_nodes = this->quadrature->get_num_nodes();
108-
109-
this->start_state = this->get_factory()->create(pfasst::encap::solution);
110-
this->end_state = this->get_factory()->create(pfasst::encap::solution);
111-
112-
for (size_t m = 0; m < num_nodes; m++) {
113-
this->state.push_back(this->get_factory()->create(pfasst::encap::solution));
114-
if (coarse) {
115-
this->saved_state.push_back(this->get_factory()->create(pfasst::encap::solution));
116-
}
117-
}
118-
119-
if (coarse) {
120-
for (size_t m = 0; m < num_nodes; m++) {
121-
this->fas_corrections.push_back(this->get_factory()->create(pfasst::encap::solution));
122-
}
123-
}
124-
}
83+
virtual void set_options() override;
84+
virtual void setup(bool coarse) override;
12585
//! @}
12686

12787
//! @{
128-
virtual void spread() override
129-
{
130-
for (size_t m = 1; m < this->quadrature->get_num_nodes(); m++) {
131-
this->state[m]->copy(this->state[0]);
132-
}
133-
}
88+
virtual void spread() override;
13489

13590
/**
13691
* Save current solution states.
13792
*/
138-
virtual void save(bool initial_only) override
139-
{
140-
// XXX: if !left_is_node, this is a problem...
141-
if (initial_only) {
142-
this->saved_state[0]->copy(state[0]);
143-
} else {
144-
for (size_t m = 0; m < this->saved_state.size(); m++) {
145-
this->saved_state[m]->copy(state[m]);
146-
}
147-
}
148-
}
93+
virtual void save(bool initial_only) override;
14994
//! @}
15095

15196
//! @{
152-
void set_quadrature(shared_ptr<IQuadrature<time>> quadrature)
153-
{
154-
this->quadrature = quadrature;
155-
}
156-
157-
shared_ptr<const IQuadrature<time>> get_quadrature() const
158-
{
159-
return this->quadrature;
160-
}
161-
162-
shared_ptr<Encapsulation<time>> get_start_state() const
163-
{
164-
return this->start_state;
165-
}
166-
167-
const vector<time> get_nodes() const
168-
{
169-
return this->quadrature->get_nodes();
170-
}
171-
172-
void set_factory(shared_ptr<EncapFactory<time>> factory)
173-
{
174-
this->factory = factory;
175-
}
176-
177-
virtual shared_ptr<EncapFactory<time>> get_factory() const
178-
{
179-
return factory;
180-
}
181-
182-
virtual shared_ptr<Encapsulation<time>> get_end_state()
183-
{
184-
return this->end_state;
185-
}
97+
virtual void set_quadrature(shared_ptr<IQuadrature<time>> quadrature);
98+
virtual shared_ptr<const IQuadrature<time>> get_quadrature() const;
99+
virtual shared_ptr<Encapsulation<time>> get_start_state() const;
100+
virtual const vector<time> get_nodes() const;
101+
virtual void set_factory(shared_ptr<EncapFactory<time>> factory);
102+
virtual shared_ptr<EncapFactory<time>> get_factory() const;
103+
virtual shared_ptr<Encapsulation<time>> get_end_state();
186104
//! @}
187105

188106
//! @{
@@ -191,21 +109,14 @@ namespace pfasst
191109
*
192110
* @note This method must be implemented in derived sweepers.
193111
*/
194-
virtual void advance() override
195-
{
196-
throw NotImplementedYet("sweeper");
197-
}
112+
virtual void advance() override;
198113

199114
/**
200115
* Re-evaluate function values.
201116
*
202117
* @note This method must be implemented in derived sweepers.
203118
*/
204-
virtual void reevaluate(bool initial_only=false)
205-
{
206-
UNUSED(initial_only);
207-
throw NotImplementedYet("sweeper");
208-
}
119+
virtual void reevaluate(bool initial_only = false);
209120

210121
/**
211122
* integrates values of right hand side at all time nodes \\( t \\in [0,M-1] \\)
@@ -216,110 +127,45 @@ namespace pfasst
216127
*
217128
* @note This method must be implemented in derived sweepers.
218129
*/
219-
virtual void integrate(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const
220-
{
221-
UNUSED(dt); UNUSED(dst);
222-
throw NotImplementedYet("sweeper");
223-
}
130+
virtual void integrate(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const;
224131
//! @}
225132

226133
//! @{
227134
/**
228135
* Set residual tolerances for convergence checking.
229136
*/
230-
void set_residual_tolerances(time abs_residual_tol, time rel_residual_tol, int order=0)
231-
{
232-
this->abs_residual_tol = abs_residual_tol;
233-
this->rel_residual_tol = rel_residual_tol;
234-
this->residual_norm_order = order;
235-
}
137+
void set_residual_tolerances(time abs_residual_tol, time rel_residual_tol, int order = 0);
236138

237139
/**
238140
* Compute residual at each SDC node (including FAS corrections).
239141
*/
240-
virtual void residual(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const
241-
{
242-
UNUSED(dt); UNUSED(dst);
243-
throw NotImplementedYet("residual");
244-
}
142+
virtual void residual(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const;
245143

246144
/**
247145
* Return convergence status.
248146
*
249147
* This is used by controllers to shortcircuit iterations.
250148
*/
251-
virtual bool converged() override
252-
{
253-
if (this->abs_residual_tol > 0.0 || this->rel_residual_tol > 0.0) {
254-
if (this->residuals.size() == 0) {
255-
for (size_t m = 0; m < this->get_nodes().size(); m++) {
256-
this->residuals.push_back(this->get_factory()->create(pfasst::encap::solution));
257-
}
258-
}
259-
this->residual(this->get_controller()->get_time_step(), this->residuals);
260-
vector<time> anorms, rnorms;
261-
for (size_t m = 0; m < this->get_nodes().size(); m++) {
262-
anorms.push_back(this->residuals[m]->norm0());
263-
rnorms.push_back(anorms.back() / this->get_state(m)->norm0());
264-
}
265-
auto amax = *std::max_element(anorms.begin(), anorms.end());
266-
auto rmax = *std::max_element(rnorms.begin(), rnorms.end());
267-
if (amax < this->abs_residual_tol || rmax < this->rel_residual_tol) {
268-
return true;
269-
}
270-
}
271-
return false;
272-
}
149+
virtual bool converged() override;
273150
//! @}
274151

275152
//! @{
276-
virtual void post(ICommunicator* comm, int tag) override
277-
{
278-
this->start_state->post(comm, tag);
279-
}
280-
281-
virtual void send(ICommunicator* comm, int tag, bool blocking) override
282-
{
283-
this->end_state->send(comm, tag, blocking);
284-
}
285-
286-
virtual void recv(ICommunicator* comm, int tag, bool blocking) override
287-
{
288-
this->start_state->recv(comm, tag, blocking);
289-
if (this->quadrature->left_is_node()) {
290-
this->state[0]->copy(this->start_state);
291-
}
292-
}
293-
294-
virtual void broadcast(ICommunicator* comm) override
295-
{
296-
if (comm->rank() == comm->size() - 1) {
297-
this->start_state->copy(this->end_state);
298-
}
299-
this->start_state->broadcast(comm);
300-
}
153+
virtual void post(ICommunicator* comm, int tag) override;
154+
virtual void send(ICommunicator* comm, int tag, bool blocking) override;
155+
virtual void recv(ICommunicator* comm, int tag, bool blocking) override;
156+
virtual void broadcast(ICommunicator* comm) override;
301157
//! @}
302158
};
303159

304160

305161
template<typename time>
306-
EncapSweeper<time>& as_encap_sweeper(shared_ptr<ISweeper<time>> x)
307-
{
308-
shared_ptr<EncapSweeper<time>> y = dynamic_pointer_cast<EncapSweeper<time>>(x);
309-
assert(y);
310-
return *y.get();
311-
}
312-
162+
EncapSweeper<time>& as_encap_sweeper(shared_ptr<ISweeper<time>> x);
313163

314164
template<typename time>
315-
const EncapSweeper<time>& as_encap_sweeper(shared_ptr<const ISweeper<time>> x)
316-
{
317-
shared_ptr<const EncapSweeper<time>> y = dynamic_pointer_cast<const EncapSweeper<time>>(x);
318-
assert(y);
319-
return *y.get();
320-
}
321-
165+
const EncapSweeper<time>& as_encap_sweeper(shared_ptr<const ISweeper<time>> x);
322166
} // ::pfasst::encap
323167
} // ::pfasst
324168

169+
#include "encap_sweeper_impl.hpp"
170+
325171
#endif

0 commit comments

Comments
 (0)