Skip to content

Commit c05d7fa

Browse files
author
Matthew Emmett
committed
Merge branch 'development' into feature/test-mpi
2 parents beec217 + efba4be commit c05d7fa

File tree

6 files changed

+119
-78
lines changed

6 files changed

+119
-78
lines changed

examples/advection_diffusion/mpi_pfasst.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ error_map run_mpi_pfasst()
5959

6060
pf.set_comm(&comm);
6161
pf.set_duration(0.0, nsteps * dt, dt, niters);
62+
pf.set_nsweeps({2, 1});
6263
pf.run();
6364

6465
auto fine = pf.get_finest<AdvectionDiffusionSweeper<>>();

include/pfasst/encap/encap_sweeper.hpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ namespace pfasst
2525

2626
public:
2727

28+
virtual void spread()
29+
{
30+
for (size_t m = 1; m < nodes.size(); m++) {
31+
this->get_state(m)->copy(this->get_state(0));
32+
}
33+
}
34+
2835
virtual void post(ICommunicator* comm, int tag)
2936
{
3037
this->get_state(0)->post(comm, tag);
@@ -42,9 +49,9 @@ namespace pfasst
4249

4350
virtual void broadcast(ICommunicator* comm)
4451
{
45-
if (comm->rank() == comm->size() - 1) {
46-
this->get_state(0)->copy(this->get_state(this->get_nodes().size() - 1));
47-
}
52+
if (comm->rank() == comm->size() - 1) {
53+
this->get_state(0)->copy(this->get_state(this->get_nodes().size() - 1));
54+
}
4855
this->get_state(0)->broadcast(comm);
4956
}
5057

@@ -112,6 +119,23 @@ namespace pfasst
112119
}
113120
};
114121

122+
123+
template<typename time>
124+
EncapSweeper<time>& as_encap_sweeper(shared_ptr<ISweeper<time>> x)
125+
{
126+
shared_ptr<EncapSweeper<time>> y = dynamic_pointer_cast<EncapSweeper<time>>(x);
127+
assert(y);
128+
return *y.get();
129+
}
130+
131+
template<typename time>
132+
const EncapSweeper<time>& as_encap_sweeper(shared_ptr<const ISweeper<time>> x)
133+
{
134+
shared_ptr<const EncapSweeper<time>> y = dynamic_pointer_cast<const EncapSweeper<time>>(x);
135+
assert(y);
136+
return *y.get();
137+
}
138+
115139
}
116140

117141
}

include/pfasst/encap/poly_interp.hpp

Lines changed: 42 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,79 +27,74 @@ namespace pfasst
2727
public:
2828
virtual ~PolyInterpMixin() { }
2929

30-
virtual void interpolate(shared_ptr<ISweeper<time>> dst,
31-
shared_ptr<const ISweeper<time>> src,
32-
bool interp_delta_from_initial,
33-
bool interp_initial,
34-
bool interp_initial_only)
30+
virtual void interpolate_initial(shared_ptr<ISweeper<time>> dst,
31+
shared_ptr<const ISweeper<time>> src)
3532
{
36-
shared_ptr<EncapSweeper<time>> fine = dynamic_pointer_cast<EncapSweeper<time>>(dst);
37-
assert(fine);
38-
shared_ptr<const EncapSweeper<time>> crse = dynamic_pointer_cast<const EncapSweeper<time>>(src);
39-
assert(crse);
33+
auto& fine = as_encap_sweeper(dst);
34+
auto& crse = as_encap_sweeper(src);
35+
36+
auto crse_factory = crse.get_factory();
37+
auto fine_factory = fine.get_factory();
38+
39+
auto crse_delta = crse_factory->create(solution);
40+
this->restrict(crse_delta, fine.get_state(0));
41+
crse_delta->saxpy(-1.0, crse.get_state(0));
42+
43+
auto fine_delta = fine_factory->create(solution);
44+
this->interpolate(fine_delta, crse_delta);
45+
fine.get_state(0)->saxpy(-1.0, fine_delta);
4046

41-
this->interpolate(fine, crse, interp_delta_from_initial, interp_initial, interp_initial_only);
47+
fine.evaluate(0);
4248
}
4349

44-
virtual void interpolate(shared_ptr<EncapSweeper<time>> fine,
45-
shared_ptr<const EncapSweeper<time>> crse,
46-
bool interp_delta_from_initial,
47-
bool interp_initial,
48-
bool interp_initial_only)
50+
virtual void interpolate(shared_ptr<ISweeper<time>> dst,
51+
shared_ptr<const ISweeper<time>> src,
52+
bool interp_initial)
4953
{
54+
auto& fine = as_encap_sweeper(dst);
55+
auto& crse = as_encap_sweeper(src);
56+
5057
if (tmat.size1() == 0) {
51-
tmat = pfasst::compute_interp<time>(fine->get_nodes(), crse->get_nodes());
58+
tmat = pfasst::compute_interp<time>(fine.get_nodes(), crse.get_nodes());
5259
}
5360

54-
size_t nfine = fine->get_nodes().size();
55-
size_t ncrse = crse->get_nodes().size();
56-
57-
auto crse_factory = crse->get_factory();
58-
auto fine_factory = fine->get_factory();
59-
60-
if (interp_initial_only) {
61-
auto crse_delta = crse_factory->create(solution);
62-
this->restrict(crse_delta, fine->get_state(0));
63-
crse_delta->saxpy(-1.0, crse->get_state(0));
61+
size_t nfine = fine.get_nodes().size();
62+
size_t ncrse = crse.get_nodes().size();
6463

65-
auto fine_delta = fine_factory->create(solution);
66-
this->interpolate(fine_delta, crse_delta);
67-
fine->get_state(0)->saxpy(-1.0, fine_delta);
68-
69-
fine->evaluate(0);
70-
return;
71-
}
64+
auto crse_factory = crse.get_factory();
65+
auto fine_factory = fine.get_factory();
7266

7367
EncapVecT fine_state(nfine), fine_delta(ncrse);
7468

75-
for (size_t m = 0; m < nfine; m++) { fine_state[m] = fine->get_state(m); }
69+
for (size_t m = 0; m < nfine; m++) { fine_state[m] = fine.get_state(m); }
7670
for (size_t m = 0; m < ncrse; m++) { fine_delta[m] = fine_factory->create(solution); }
7771

78-
if (interp_delta_from_initial) {
79-
for (size_t m = 1; m < nfine; m++) {
80-
fine_state[m]->copy(fine_state[0]);
81-
}
82-
}
72+
// if (interp_delta_from_initial) {
73+
// for (size_t m = 1; m < nfine; m++) {
74+
// fine_state[m]->copy(fine_state[0]);
75+
// }
76+
// }
8377

8478
auto crse_delta = crse_factory->create(solution);
8579
size_t m0 = interp_initial ? 0 : 1;
8680
for (size_t m = m0; m < ncrse; m++) {
87-
crse_delta->copy(crse->get_state(m));
88-
if (interp_delta_from_initial) {
89-
crse_delta->saxpy(-1.0, crse->get_state(0));
90-
} else {
91-
crse_delta->saxpy(-1.0, crse->get_saved_state(m));
92-
}
81+
crse_delta->copy(crse.get_state(m));
82+
// if (interp_delta_from_initial) {
83+
// crse_delta->saxpy(-1.0, crse->get_saved_state(0));
84+
// // crse_delta->saxpy(-1.0, crse->get_state(0));
85+
// } else {
86+
crse_delta->saxpy(-1.0, crse.get_saved_state(m));
87+
// }
9388
interpolate(fine_delta[m], crse_delta);
9489
}
9590

9691
if (!interp_initial) {
9792
fine_delta[0]->zero();
9893
}
9994

100-
fine->get_state(0)->mat_apply(fine_state, 1.0, tmat, fine_delta, false);
95+
fine.get_state(0)->mat_apply(fine_state, 1.0, tmat, fine_delta, false);
10196

102-
for (size_t m = m0; m < nfine; m++) { fine->evaluate(m); }
97+
for (size_t m = m0; m < nfine; m++) { fine.evaluate(m); }
10398
}
10499

105100
virtual void restrict(shared_ptr<ISweeper<time>> dst,

include/pfasst/interfaces.hpp

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,29 +112,32 @@ namespace pfasst
112112
virtual void predict(bool initial) = 0;
113113

114114
/**
115-
* perform one SDC sweep/iteration.
116-
* Compute a correction and update solution values.
117-
* Note that this function can assume that valid function values exist from a previous
118-
* pfasst::ISweeper::sweep() or pfasst::ISweeper::predict().
115+
* Perform one SDC sweep/iteration.
116+
*
117+
* Compute a correction and update solution values. Note that this function can assume that
118+
* valid function values exist from a previous pfasst::ISweeper::sweep() or
119+
* pfasst::ISweeper::predict().
119120
*/
120121
virtual void sweep() = 0;
121122

122123
/**
123-
* advance from one time step to the next.
124+
* Advance from one time step to the next.
124125
*
125126
* Essentially this means copying the solution and function values from the last node to the
126127
* first node.
127128
*/
128129
virtual void advance() = 0;
129130

130131
/**
131-
* save solutions (and/or function values) at all nodes.
132+
* Save states (and/or function values) at all nodes.
132133
*
133-
* This is typically done in MLSDC/PFASST immediately after a call to restrict.
134-
* The saved states are used to compute deltas during interpolation.
134+
* This is typically done in MLSDC/PFASST immediately after a call to restrict. The saved
135+
* states are used to compute deltas during interpolation.
135136
*/
136137
virtual void save(bool initial_only = false) { (void) initial_only; NotImplementedYet("mlsdc/pfasst"); }
137138

139+
virtual void spread() { NotImplementedYet("pfasst"); }
140+
138141
virtual void post(ICommunicator* /*comm*/, int /*tag*/) { };
139142
virtual void send(ICommunicator* /*comm*/, int /*tag*/, bool /*blocking*/) { NotImplementedYet("pfasst"); }
140143
virtual void recv(ICommunicator* /*comm*/, int /*tag*/, bool /*blocking*/) { NotImplementedYet("pfasst"); }
@@ -151,24 +154,31 @@ namespace pfasst
151154
class ITransfer
152155
{
153156
public:
154-
// XXX: pass level iterator to these routines as well
155157
virtual ~ITransfer() { }
156158

159+
160+
/**
161+
* Interpolate initial condition (in space) from the coarse sweeper to the fine sweeper.
162+
*/
163+
virtual void interpolate_initial(shared_ptr<ISweeper<time>> dst,
164+
shared_ptr<const ISweeper<time>> src)
165+
{
166+
NotImplementedYet("pfasst");
167+
}
168+
169+
157170
/**
158-
* interpolate, in time and space, from the coarse sweeper to the fine sweeper.
159-
* @param[in] interp_delta_from_initial
160-
* `true` if the delta computed at each node should be relative to the initial condition.
171+
* Interpolate, in time and space, from the coarse sweeper to the fine sweeper.
161172
* @param[in] interp_initial
162173
* `true` if a delta for the initial condtion should also be computed (PFASST).
163174
*/
164175
virtual void interpolate(shared_ptr<ISweeper<time>> dst,
165176
shared_ptr<const ISweeper<time>> src,
166-
bool interp_delta_from_initial = false,
167-
bool interp_initial = false,
168-
bool interp_initial_only = false) = 0;
177+
bool interp_initial = false) = 0;
178+
169179

170180
/**
171-
* restrict, in time and space, from the fine sweeper to the coarse sweeper.
181+
* Restrict, in time and space, from the fine sweeper to the coarse sweeper.
172182
* @param[in] restrict_initial
173183
* `true` if the initial condition should also be restricted.
174184
*/
@@ -178,7 +188,7 @@ namespace pfasst
178188
bool restrict_initial_only = false) = 0;
179189

180190
/**
181-
* compute FAS correction between the coarse and fine sweepers.
191+
* Compute FAS correction between the coarse and fine sweepers.
182192
*/
183193
virtual void fas(time dt, shared_ptr<ISweeper<time>> dst,
184194
shared_ptr<const ISweeper<time>> src) = 0;

include/pfasst/mlsdc.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,16 @@ namespace pfasst
4646
nsweeps.resize(this->nlevels());
4747
fill(nsweeps.begin(), nsweeps.end(), 1);
4848
for (auto leviter = this->coarsest(); leviter <= this->finest(); ++leviter) {
49-
leviter.current()->set_controller(this);
49+
leviter.current()->set_controller(this);
5050
leviter.current()->setup(leviter != this->finest());
5151
}
5252
}
5353

54+
void set_nsweeps(vector<size_t> nsweeps)
55+
{
56+
this->nsweeps = nsweeps;
57+
}
58+
5459
/**
5560
* evolve ODE using MLSDC.
5661
*
@@ -64,7 +69,7 @@ namespace pfasst
6469
initial = this->get_step() == 0; // only evaluate node 0 functions on first step
6570

6671
// iterate by performing v-cycles
67-
for (this->set_iteration(0); this->get_iteration() < this->get_max_iterations(); this->advance_iteration()) {
72+
for (this->set_iteration(0); this->get_iteration() < this->get_max_iterations(); this->advance_iteration()) {
6873
cycle_v(this->finest());
6974
}
7075

include/pfasst/pfasst.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ namespace pfasst
1919

2020
typedef typename pfasst::Controller<time>::LevelIter LevelIter;
2121

22-
bool predict, initial;
22+
bool predict; //<! whether to use a 'predict' sweep
23+
bool initial; //<! whether we're sweeping from a new initial condition
2324

2425
void perform_sweeps(size_t level)
2526
{
@@ -46,13 +47,16 @@ namespace pfasst
4647
*/
4748
void predictor()
4849
{
50+
this->get_finest()->spread();
51+
4952
// restrict fine initial condition
5053
for (auto l = this->finest() - 1; l >= this->coarsest(); --l) {
5154
auto crse = l.current();
5255
auto fine = l.fine();
5356
auto trns = l.transfer();
54-
trns->restrict(crse, fine, false, true);
55-
crse->save(true);
57+
trns->restrict(crse, fine, true, true);
58+
crse->spread();
59+
crse->save();
5660
}
5761

5862
// perform sweeps on the coarse level based on rank
@@ -63,7 +67,9 @@ namespace pfasst
6367
// XXX: set iteration?
6468

6569
perform_sweeps(0);
66-
crse->advance();
70+
if (nstep < comm->rank()) {
71+
crse->advance();
72+
}
6773
}
6874

6975
// return to finest level, sweeping as we go
@@ -72,7 +78,7 @@ namespace pfasst
7278
auto fine = l.current();
7379
auto trns = l.transfer();
7480

75-
trns->interpolate(fine, crse, true, true);
81+
trns->interpolate(fine, crse, true);
7682
if (l < this->finest()) {
7783
perform_sweeps(l.level);
7884
}
@@ -128,9 +134,9 @@ namespace pfasst
128134

129135
cycle_v(this->finest() - 1);
130136

131-
trns->interpolate(fine, crse, false, true, false);
137+
trns->interpolate(fine, crse, true);
132138
fine->recv(comm, tag, false);
133-
trns->interpolate(fine, crse, false, false, true);
139+
trns->interpolate_initial(fine, crse);
134140
// XXX: call interpolate_q0(pf,F, G)
135141
}
136142

@@ -176,12 +182,12 @@ namespace pfasst
176182
auto crse = l.coarse();
177183
auto trns = l.transfer();
178184

179-
trns->interpolate(fine, crse, false, true, false);
185+
trns->interpolate(fine, crse, true);
180186

181187
int tag = l.level * 10000 + this->get_iteration() + 10;
182188
fine->recv(comm, tag, false);
183189
// XXX call interpolate_q0(pf,F, G)
184-
trns->interpolate(fine, crse, false, false, true);
190+
trns->interpolate_initial(fine, crse);
185191

186192
if (l < this->finest()) {
187193
perform_sweeps(l.level);

0 commit comments

Comments
 (0)