Skip to content

Commit 2de0367

Browse files
committed
Merge pull request #118 from memmett/feature/movestate-merge
Use `shared_ptr` for quadrature, move `state` etc to `encap_sweeper` (2nd try).
2 parents 707ffdd + 937cbfb commit 2de0367

File tree

4 files changed

+137
-150
lines changed

4 files changed

+137
-150
lines changed

include/pfasst/encap/encap_sweeper.hpp

Lines changed: 105 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,137 @@ namespace pfasst
1919
{
2020
namespace encap
2121
{
22+
using namespace pfasst::quadrature;
23+
2224
template<typename time = time_precision>
2325
class EncapSweeper
2426
: public ISweeper<time>
2527
{
2628
protected:
2729
//! @{
28-
quadrature::IQuadrature<time>* quad;
30+
shared_ptr<IQuadrature<time>> quadrature;
2931
shared_ptr<EncapFactory<time>> factory;
3032
shared_ptr<Encapsulation<time>> start_state;
3133
shared_ptr<Encapsulation<time>> end_state;
3234
vector<shared_ptr<Encapsulation<time>>> residuals;
3335
//! @}
3436

37+
//! @{
38+
/**
39+
* Solution values \\( U \\) at all time nodes of the current iteration.
40+
*/
41+
vector<shared_ptr<Encapsulation<time>>> state;
42+
43+
/**
44+
* Solution values \\( U \\) at all time nodes of the previous iteration.
45+
*/
46+
vector<shared_ptr<Encapsulation<time>>> saved_state;
47+
48+
/**
49+
* FAS corrections \\( \\tau \\) at all time nodes of the current iteration.
50+
*/
51+
vector<shared_ptr<Encapsulation<time>>> fas_corrections;
52+
//! @}
53+
3554
int residual_norm_order;
3655
time abs_residual_tol, rel_residual_tol;
3756

3857
public:
39-
//! @{
58+
4059
EncapSweeper()
41-
: quad(nullptr), abs_residual_tol(0.0), rel_residual_tol(0.0)
60+
: quadrature(nullptr), abs_residual_tol(0.0), rel_residual_tol(0.0)
4261
{}
4362

44-
virtual ~EncapSweeper()
63+
//! @{
64+
/**
65+
* Retrieve solution values of current iteration at time node index `m`.
66+
*
67+
* @param[in] m 0-based index of time node
68+
*/
69+
virtual shared_ptr<Encapsulation<time>> get_state(size_t m) const
4570
{
46-
if (this->quad) delete this->quad;
71+
return this->state[m];
72+
}
73+
74+
/**
75+
* Retrieve FAS correction of current iteration at time node index `m`.
76+
*
77+
* @param[in] m 0-based index of time node
78+
*/
79+
virtual shared_ptr<Encapsulation<time>> get_tau(size_t m) const
80+
{
81+
return this->fas_corrections[m];
82+
}
83+
84+
/**
85+
* Retrieve solution values of previous iteration at time node index `m`.
86+
*
87+
* @param[in] m 0-based index of time node
88+
*/
89+
virtual shared_ptr<Encapsulation<time>> get_saved_state(size_t m) const
90+
{
91+
return this->saved_state[m];
4792
}
4893
//! @}
4994

95+
virtual void setup(bool coarse) override
96+
{
97+
auto const nodes = this->quadrature->get_nodes();
98+
auto const num_nodes = this->quadrature->get_num_nodes();
99+
100+
this->start_state = this->get_factory()->create(pfasst::encap::solution);
101+
this->end_state = this->get_factory()->create(pfasst::encap::solution);
102+
103+
for (size_t m = 0; m < num_nodes; m++) {
104+
this->state.push_back(this->get_factory()->create(pfasst::encap::solution));
105+
if (coarse) {
106+
this->saved_state.push_back(this->get_factory()->create(pfasst::encap::solution));
107+
}
108+
}
109+
110+
if (coarse) {
111+
size_t num_fas = this->quadrature->left_is_node() ? num_nodes -1 : num_nodes;
112+
for (size_t m = 0; m < num_fas; m++) {
113+
this->fas_corrections.push_back(this->get_factory()->create(pfasst::encap::solution));
114+
}
115+
}
116+
117+
}
118+
50119
//! @{
51120
virtual void spread() override
52121
{
53-
for (size_t m = 1; m < this->quad->get_num_nodes(); m++) {
122+
for (size_t m = 1; m < this->quadrature->get_num_nodes(); m++) {
54123
// this->get_state(m)->copy(this->start_state);
55-
this->get_state(m)->copy(this->get_state(0));
124+
this->state[m]->copy(this->state[0]);
56125
}
57126
}
58127
//! @}
59128

129+
/**
130+
* Save current solution states.
131+
*/
132+
virtual void save(bool initial_only) override
133+
{
134+
// XXX: if !left_is_node, this is a problem...
135+
if (initial_only) {
136+
this->saved_state[0]->copy(state[0]);
137+
} else {
138+
for (size_t m = 0; m < this->saved_state.size(); m++) {
139+
this->saved_state[m]->copy(state[m]);
140+
}
141+
}
142+
}
143+
60144
//! @{
61-
void set_quadrature(quadrature::IQuadrature<time>* quad)
145+
void set_quadrature(shared_ptr<IQuadrature<time>> quadrature)
62146
{
63-
this->quad = quad;
147+
this->quadrature = quadrature;
64148
}
65149

66-
const quadrature::IQuadrature<time>* get_quadrature() const
150+
shared_ptr<const IQuadrature<time>> get_quadrature() const
67151
{
68-
return this->quad;
152+
return this->quadrature;
69153
}
70154

71155
shared_ptr<Encapsulation<time>> get_start_state() const
@@ -75,7 +159,7 @@ namespace pfasst
75159

76160
const vector<time> get_nodes() const
77161
{
78-
return this->quad->get_nodes();
162+
return this->quadrature->get_nodes();
79163
}
80164

81165
void set_factory(shared_ptr<EncapFactory<time>> factory)
@@ -88,48 +172,6 @@ namespace pfasst
88172
return factory;
89173
}
90174

91-
/**
92-
* retrieve solution values of current iteration at time node index `m`
93-
*
94-
* @param[in] m 0-based index of time node
95-
*
96-
* @note This method must be implemented in derived sweepers.
97-
*/
98-
virtual shared_ptr<Encapsulation<time>> get_state(size_t m) const
99-
{
100-
UNUSED(m);
101-
throw NotImplementedYet("sweeper");
102-
return NULL;
103-
}
104-
105-
/**
106-
* retrieves FAS correction of current iteration at time node index `m`
107-
*
108-
* @param[in] m 0-based index of time node
109-
*
110-
* @note This method must be implemented in derived sweepers.
111-
*/
112-
virtual shared_ptr<Encapsulation<time>> get_tau(size_t m) const
113-
{
114-
UNUSED(m);
115-
throw NotImplementedYet("sweeper");
116-
return NULL;
117-
}
118-
119-
/**
120-
* retrieves solution values of previous iteration at time node index `m`
121-
*
122-
* @param[in] m 0-based index of time node
123-
*
124-
* @note This method must be implemented in derived sweepers.
125-
*/
126-
virtual shared_ptr<Encapsulation<time>> get_saved_state(size_t m) const
127-
{
128-
UNUSED(m);
129-
throw NotImplementedYet("sweeper");
130-
return NULL;
131-
}
132-
133175
virtual shared_ptr<Encapsulation<time>> get_end_state()
134176
{
135177
return this->end_state;
@@ -229,25 +271,29 @@ namespace pfasst
229271
//! @{
230272
virtual void post(ICommunicator* comm, int tag) override
231273
{
232-
this->get_state(0)->post(comm, tag);
274+
this->start_state->post(comm, tag);
233275
}
234276

235277
virtual void send(ICommunicator* comm, int tag, bool blocking) override
236278
{
237-
this->get_state(this->get_nodes().size() - 1)->send(comm, tag, blocking);
279+
this->end_state->send(comm, tag, blocking);
238280
}
239281

240282
virtual void recv(ICommunicator* comm, int tag, bool blocking) override
241283
{
242-
this->get_state(0)->recv(comm, tag, blocking);
284+
this->start_state->recv(comm, tag, blocking);
285+
// XXX
286+
this->state.front()->copy(this->start_state);
243287
}
244288

245289
virtual void broadcast(ICommunicator* comm) override
246290
{
247291
if (comm->rank() == comm->size() - 1) {
248-
this->get_state(0)->copy(this->get_state(this->get_nodes().size() - 1));
292+
this->start_state->copy(this->end_state);
249293
}
250-
this->get_state(0)->broadcast(comm);
294+
this->start_state->broadcast(comm);
295+
// XXX
296+
this->state.front()->copy(this->start_state);
251297
}
252298
//! @}
253299
};

0 commit comments

Comments
 (0)