Skip to content

Commit f053346

Browse files
committed
internals: splitting up decls/defs for mlsdc controller
towards fixing #84
1 parent 3d2ec2b commit f053346

File tree

2 files changed

+144
-121
lines changed

2 files changed

+144
-121
lines changed

include/pfasst/mlsdc.hpp

Lines changed: 11 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
1-
/*
2-
* Multi-level SDC controller.
3-
*/
4-
51
#ifndef _PFASST_MLSDC_HPP_
62
#define _PFASST_MLSDC_HPP_
73

8-
#include <algorithm>
9-
#include <iostream>
104
#include <vector>
5+
using namespace std;
116

127
#include "controller.hpp"
13-
#include "logging.hpp"
148

15-
using namespace std;
169

1710
namespace pfasst
1811
{
19-
2012
/**
2113
* Multilevel SDC controller.
2214
*/
@@ -33,21 +25,7 @@ namespace pfasst
3325
bool initial; //<! whether we're sweeping from a new initial condition
3426
bool converged; //<! whether we've converged
3527

36-
void perform_sweeps(size_t level)
37-
{
38-
auto sweeper = this->get_level(level);
39-
CLOG(INFO, "Controller") << "on level " << level + 1 << "/" << this->nlevels();
40-
for (size_t s = 0; s < this->nsweeps[level]; s++) {
41-
if (predict) {
42-
sweeper->predict(initial & predict);
43-
sweeper->post_predict();
44-
predict = false;
45-
} else {
46-
sweeper->sweep();
47-
sweeper->post_sweep();
48-
}
49-
}
50-
}
28+
virtual void perform_sweeps(size_t level);
5129

5230
public:
5331
/**
@@ -56,56 +34,15 @@ namespace pfasst
5634
* This assumes that the user has set initial conditions on the finest level.
5735
* Currently uses a fixed number of iterations per step.
5836
*/
59-
void run()
60-
{
61-
for (; this->get_time() < this->get_end_time(); this->advance_time()) {
62-
predict = true;
63-
initial = true;
64-
converged = false;
65-
66-
for (this->set_iteration(0);
67-
this->get_iteration() < this->get_max_iterations() && !converged;
68-
this->advance_iteration()) {
69-
cycle_v(this->finest());
70-
initial = false;
71-
}
72-
73-
perform_sweeps(this->finest().level);
74-
75-
for (auto l = this->finest(); l >= this->coarsest(); --l) {
76-
l.current()->post_step();
77-
}
78-
79-
if (this->get_time() + this->get_time_step() < this->get_end_time()) {
80-
this->get_finest()->advance();
81-
}
82-
}
83-
}
37+
virtual void setup() override;
38+
virtual void set_nsweeps(vector<size_t> nsweeps);
39+
virtual void run();
8440

8541
private:
8642
/**
8743
* Cycle down: sweep on current (fine), restrict to coarse.
8844
*/
89-
LevelIter cycle_down(LevelIter l)
90-
{
91-
auto fine = l.current();
92-
auto crse = l.coarse();
93-
auto trns = l.transfer();
94-
95-
perform_sweeps(l.level);
96-
97-
if (l == this->finest() && fine->converged()) {
98-
converged = true;
99-
return l;
100-
}
101-
102-
CVLOG(1, "Controller") << "Cycle down onto level " << l.level << "/" << this->nlevels();
103-
trns->restrict(crse, fine, initial);
104-
trns->fas(this->get_time_step(), crse, fine);
105-
crse->save();
106-
107-
return l - 1;
108-
}
45+
virtual LevelIter cycle_down(LevelIter l);
10946

11047
/**
11148
* Cycle up: interpolate coarse correction to fine, sweep on current (fine).
@@ -114,67 +51,20 @@ namespace pfasst
11451
* sweep.
11552
* In this case the only operation that is performed here is interpolation.
11653
*/
117-
LevelIter cycle_up(LevelIter l)
118-
{
119-
auto fine = l.current();
120-
auto crse = l.coarse();
121-
auto trns = l.transfer();
122-
123-
CVLOG(1, "Controller") << "Cycle up onto level " << l.level + 1 << "/" << this->nlevels();
124-
trns->interpolate(fine, crse);
125-
126-
if (l < this->finest()) {
127-
perform_sweeps(l.level);
128-
}
129-
130-
return l + 1;
131-
}
54+
virtual LevelIter cycle_up(LevelIter l);
13255

13356
/**
13457
* Cycle bottom: sweep on the current (coarsest) level.
13558
*/
136-
LevelIter cycle_bottom(LevelIter l)
137-
{
138-
perform_sweeps(l.level);
139-
return l + 1;
140-
}
59+
virtual LevelIter cycle_bottom(LevelIter l);
14160

14261
/**
14362
* Perform an MLSDC V-cycle.
14463
*/
145-
LevelIter cycle_v(LevelIter l)
146-
{
147-
if (l.level == 0) {
148-
l = cycle_bottom(l);
149-
} else {
150-
l = cycle_down(l);
151-
if (converged) {
152-
return l;
153-
}
154-
l = cycle_v(l);
155-
l = cycle_up(l);
156-
}
157-
return l;
158-
}
159-
160-
public:
161-
virtual void setup() override
162-
{
163-
nsweeps.resize(this->nlevels());
164-
fill(nsweeps.begin(), nsweeps.end(), 1);
165-
for (auto leviter = this->coarsest(); leviter <= this->finest(); ++leviter) {
166-
leviter.current()->set_controller(this);
167-
leviter.current()->setup(leviter != this->finest());
168-
}
169-
}
170-
171-
void set_nsweeps(vector<size_t> nsweeps)
172-
{
173-
this->nsweeps = nsweeps;
174-
}
175-
64+
virtual LevelIter cycle_v(LevelIter l);
17665
};
177-
17866
} // ::pfasst
17967

68+
#include "mlsdc_impl.hpp"
69+
18070
#endif

include/pfasst/mlsdc_impl.hpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#include "mlsdc.hpp"
2+
3+
#include <algorithm>
4+
using namespace std;
5+
6+
#include "logging.hpp"
7+
8+
9+
namespace pfasst
10+
{
11+
template<typename time>
12+
void MLSDC<time>::perform_sweeps(size_t level)
13+
{
14+
auto sweeper = this->get_level(level);
15+
CLOG(INFO, "Controller") << "on level " << level + 1 << "/" << this->nlevels();
16+
for (size_t s = 0; s < this->nsweeps[level]; s++) {
17+
if (predict) {
18+
sweeper->predict(initial & predict);
19+
sweeper->post_predict();
20+
predict = false;
21+
} else {
22+
sweeper->sweep();
23+
sweeper->post_sweep();
24+
}
25+
}
26+
}
27+
28+
template<typename time>
29+
void MLSDC<time>::setup()
30+
{
31+
nsweeps.resize(this->nlevels());
32+
fill(nsweeps.begin(), nsweeps.end(), 1);
33+
for (auto leviter = this->coarsest(); leviter <= this->finest(); ++leviter) {
34+
leviter.current()->set_controller(this);
35+
leviter.current()->setup(leviter != this->finest());
36+
}
37+
}
38+
39+
template<typename time>
40+
void MLSDC<time>::set_nsweeps(vector<size_t> nsweeps)
41+
{
42+
this->nsweeps = nsweeps;
43+
}
44+
45+
template<typename time>
46+
void MLSDC<time>::run()
47+
{
48+
for (; this->get_time() < this->get_end_time(); this->advance_time()) {
49+
predict = true;
50+
initial = true;
51+
converged = false;
52+
53+
for (this->set_iteration(0);
54+
this->get_iteration() < this->get_max_iterations() && !converged;
55+
this->advance_iteration()) {
56+
cycle_v(this->finest());
57+
initial = false;
58+
}
59+
60+
perform_sweeps(this->finest().level);
61+
62+
for (auto l = this->finest(); l >= this->coarsest(); --l) {
63+
l.current()->post_step();
64+
}
65+
66+
if (this->get_time() + this->get_time_step() < this->get_end_time()) {
67+
this->get_finest()->advance();
68+
}
69+
}
70+
}
71+
72+
template<typename time>
73+
typename MLSDC<time>::LevelIter MLSDC<time>::cycle_down(typename MLSDC<time>::LevelIter l)
74+
{
75+
auto fine = l.current();
76+
auto crse = l.coarse();
77+
auto trns = l.transfer();
78+
79+
perform_sweeps(l.level);
80+
81+
if (l == this->finest() && fine->converged()) {
82+
converged = true;
83+
return l;
84+
}
85+
86+
CVLOG(1, "Controller") << "Cycle down onto level " << l.level << "/" << this->nlevels();
87+
trns->restrict(crse, fine, initial);
88+
trns->fas(this->get_time_step(), crse, fine);
89+
crse->save();
90+
91+
return l - 1;
92+
}
93+
94+
template<typename time>
95+
typename MLSDC<time>::LevelIter MLSDC<time>::cycle_up(typename MLSDC<time>::LevelIter l)
96+
{
97+
auto fine = l.current();
98+
auto crse = l.coarse();
99+
auto trns = l.transfer();
100+
101+
CVLOG(1, "Controller") << "Cycle up onto level " << l.level + 1 << "/" << this->nlevels();
102+
trns->interpolate(fine, crse);
103+
104+
if (l < this->finest()) {
105+
perform_sweeps(l.level);
106+
}
107+
108+
return l + 1;
109+
}
110+
111+
template<typename time>
112+
typename MLSDC<time>::LevelIter MLSDC<time>::cycle_bottom(typename MLSDC<time>::LevelIter l)
113+
{
114+
perform_sweeps(l.level);
115+
return l + 1;
116+
}
117+
118+
template<typename time>
119+
typename MLSDC<time>::LevelIter MLSDC<time>::cycle_v(typename MLSDC<time>::LevelIter l)
120+
{
121+
if (l.level == 0) {
122+
l = cycle_bottom(l);
123+
} else {
124+
l = cycle_down(l);
125+
if (converged) {
126+
return l;
127+
}
128+
l = cycle_v(l);
129+
l = cycle_up(l);
130+
}
131+
return l;
132+
}
133+
} // ::pfasst

0 commit comments

Comments
 (0)