Skip to content

Commit e491dab

Browse files
committed
internals: splitting up decls/defs for pfasst controller
towards fixing #84
1 parent 64902bb commit e491dab

File tree

2 files changed

+221
-189
lines changed

2 files changed

+221
-189
lines changed

include/pfasst/pfasst.hpp

Lines changed: 15 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
1-
/*
2-
* PFASST controller.
3-
*/
4-
51
#ifndef _PFASST_PFASST_HPP_
62
#define _PFASST_PFASST_HPP_
73

8-
#include "logging.hpp"
94
#include "mlsdc.hpp"
105

11-
using namespace std;
126

137
namespace pfasst
148
{
15-
169
/**
1710
* implementation of the PFASST algorithm as described in \cite emmett_pfasst_2012
1811
*/
@@ -26,20 +19,7 @@ namespace pfasst
2619

2720
bool predict; //<! whether to use a 'predict' sweep
2821

29-
void perform_sweeps(size_t level)
30-
{
31-
auto sweeper = this->get_level(level);
32-
for (size_t s = 0; s < this->nsweeps[level]; s++) {
33-
if (predict) {
34-
sweeper->predict(predict);
35-
sweeper->post_predict();
36-
predict = false;
37-
} else {
38-
sweeper->sweep();
39-
sweeper->post_sweep();
40-
}
41-
}
42-
}
22+
virtual void perform_sweeps(size_t level);
4323

4424
public:
4525
/**
@@ -50,75 +30,15 @@ namespace pfasst
5030
*
5131
* Currently uses "block mode" PFASST with the standard predictor.
5232
*/
53-
void run()
54-
{
55-
if (this->comm->size() == 1) {
56-
pfasst::MLSDC<time>::run();
57-
return;
58-
}
59-
60-
int nblocks = int(this->get_end_time() / this->get_time_step()) / comm->size();
61-
62-
if (nblocks == 0) {
63-
throw ValueError("invalid duration: there are more time processors than time steps");
64-
}
65-
66-
for (int nblock = 0; nblock < nblocks; nblock++) {
67-
this->set_step(nblock * comm->size() + comm->rank());
68-
69-
if (this->comm->size() == 1) {
70-
predict = true;
71-
} else {
72-
predictor();
73-
}
74-
75-
for (this->set_iteration(0);
76-
this->get_iteration() < this->get_max_iterations() && this->comm->status->keep_iterating();
77-
this->advance_iteration()) {
78-
79-
if (this->comm->status->previous_is_iterating()) {
80-
post();
81-
}
82-
cycle_v(this->finest());
83-
}
84-
85-
for (auto l = this->finest(); l >= this->coarsest(); --l) {
86-
l.current()->post_step();
87-
}
88-
89-
if (nblock < nblocks - 1) {
90-
broadcast();
91-
}
92-
93-
this->comm->status->clear();
94-
}
95-
}
33+
virtual void run();
34+
35+
virtual void set_comm(ICommunicator* comm);
9636

9737
private:
9838
/**
9939
* Cycle down: sweep on current (fine), restrict to coarse.
10040
*/
101-
LevelIter cycle_down(LevelIter l)
102-
{
103-
auto fine = l.current();
104-
auto crse = l.coarse();
105-
auto trns = l.transfer();
106-
107-
perform_sweeps(l.level);
108-
109-
if (l == this->finest() && fine->converged()) {
110-
this->comm->status->set_converged(true);
111-
}
112-
113-
fine->send(comm, tag(l), false);
114-
115-
trns->restrict(crse, fine, true);
116-
117-
trns->fas(this->get_time_step(), crse, fine);
118-
crse->save();
119-
120-
return l - 1;
121-
}
41+
virtual LevelIter cycle_down(LevelIter l);
12242

12343
/**
12444
* Cycle up: interpolate coarse correction to fine, sweep on
@@ -128,125 +48,31 @@ namespace pfasst
12848
* level, we don't perform a sweep. In this case the only
12949
* operation that is performed here is interpolation.
13050
*/
131-
LevelIter cycle_up(LevelIter l)
132-
{
133-
auto fine = l.current();
134-
auto crse = l.coarse();
135-
auto trns = l.transfer();
136-
137-
trns->interpolate(fine, crse, true);
138-
139-
if (this->comm->status->previous_is_iterating()) {
140-
fine->recv(comm, tag(l), false);
141-
trns->interpolate_initial(fine, crse);
142-
}
143-
144-
if (l < this->finest()) {
145-
perform_sweeps(l.level);
146-
}
147-
148-
return l + 1;
149-
}
51+
virtual LevelIter cycle_up(LevelIter l);
15052

15153
/**
15254
* Cycle bottom: sweep on the current (coarsest) level.
15355
*/
154-
LevelIter cycle_bottom(LevelIter l)
155-
{
156-
auto crse = l.current();
157-
158-
if (this->comm->status->previous_is_iterating()) {
159-
crse->recv(comm, tag(l), true);
160-
}
161-
this->comm->status->recv();
162-
this->perform_sweeps(l.level);
163-
crse->send(comm, tag(l), true);
164-
this->comm->status->send();
165-
return l + 1;
166-
}
56+
virtual LevelIter cycle_bottom(LevelIter l);
16757

16858
/**
16959
* Perform an MLSDC V-cycle.
17060
*/
171-
LevelIter cycle_v(LevelIter l)
172-
{
173-
if (l.level == 0) {
174-
l = cycle_bottom(l);
175-
} else {
176-
l = cycle_down(l);
177-
l = cycle_v(l);
178-
l = cycle_up(l);
179-
}
180-
return l;
181-
}
61+
virtual LevelIter cycle_v(LevelIter l);
18262

18363
/**
18464
* Predictor: restrict initial down, preform coarse sweeps, return to finest.
18565
*/
186-
void predictor()
187-
{
188-
this->get_finest()->spread();
189-
190-
// restrict fine initial condition
191-
for (auto l = this->finest() - 1; l >= this->coarsest(); --l) {
192-
auto crse = l.current();
193-
auto fine = l.fine();
194-
auto trns = l.transfer();
195-
trns->restrict_initial(crse, fine);
196-
crse->spread();
197-
crse->save();
198-
}
199-
200-
// perform sweeps on the coarse level based on rank
201-
predict = true;
202-
auto crse = this->coarsest().current();
203-
for (int nstep = 0; nstep < comm->rank() + 1; nstep++) {
204-
// XXX: set iteration and step?
205-
perform_sweeps(0);
206-
if (nstep < comm->rank()) {
207-
crse->advance();
208-
}
209-
}
210-
211-
// return to finest level, sweeping as we go
212-
for (auto l = this->coarsest() + 1; l <= this->finest(); ++l) {
213-
auto crse = l.coarse();
214-
auto fine = l.current();
215-
auto trns = l.transfer();
216-
217-
trns->interpolate(fine, crse, true);
218-
if (l < this->finest()) {
219-
perform_sweeps(l.level);
220-
}
221-
}
222-
}
223-
224-
void broadcast()
225-
{
226-
this->get_finest()->broadcast(comm);
227-
}
228-
229-
int tag(LevelIter l)
230-
{
231-
return l.level * 10000 + this->get_iteration() + 10;
232-
}
233-
234-
void post()
235-
{
236-
this->comm->status->post();
237-
for (auto l = this->coarsest() + 1; l <= this->finest(); ++l) {
238-
l.current()->post(comm, tag(l));
239-
}
240-
}
66+
virtual void predictor();
24167

242-
public:
243-
void set_comm(ICommunicator* comm)
244-
{
245-
this->comm = comm;
246-
}
68+
virtual void broadcast();
24769

248-
};
70+
virtual int tag(LevelIter l);
24971

72+
virtual void post();
73+
};
25074
} // ::pfasst
25175

76+
#include "pfasst_impl.hpp"
77+
25278
#endif

0 commit comments

Comments
 (0)