Skip to content

Commit 8e8bdc2

Browse files
committed
Break apart pfasst.hpp into pfasst-interfaces.hpp and pfasst-controller.hpp.
Also add pfasst-imex.hpp, pfasst-encapsulated.hpp, pfasst-sdc.hpp and pfasst-pfasst.hpp.
1 parent 0bb33e6 commit 8e8bdc2

File tree

8 files changed

+475
-228
lines changed

8 files changed

+475
-228
lines changed

src/pfasst-controller.hpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Base controller (see also SDC, MLSDC, and PFASST controllers).
3+
*/
4+
5+
#ifndef _PFASST_CONTROLLER_HPP_
6+
#define _PFASST_CONTROLLER_HPP_
7+
8+
#include "pfasst-interfaces.hpp"
9+
10+
namespace pfasst {
11+
12+
template<typename timeT>
13+
class Controller {
14+
protected:
15+
deque<shared_ptr<ISweeper>> levels;
16+
17+
int nsteps, niters;
18+
timeT dt;
19+
20+
public:
21+
22+
void setup() {
23+
for (auto l=coarsest(); l<=finest(); ++l) {
24+
l.current()->setup();
25+
}
26+
}
27+
28+
void set_duration(timeT dt, int nsteps, int niters) {
29+
this->dt = dt; this->nsteps = nsteps; this->niters = niters;
30+
}
31+
32+
void add_level(ISweeper *sweeper, bool coarse=true) {
33+
if (coarse)
34+
levels.push_front(shared_ptr<ISweeper>(sweeper));
35+
else
36+
levels.push_back(shared_ptr<ISweeper>(sweeper));
37+
}
38+
39+
template<typename R=ISweeper> R* get_level(int level) {
40+
return dynamic_cast<R*>(levels[level].get());
41+
}
42+
43+
int nlevels() {
44+
return levels.size();
45+
}
46+
47+
struct leveliter {
48+
int level;
49+
Controller *ts;
50+
51+
leveliter(int level, Controller *ts) : level(level), ts(ts) {}
52+
53+
template<typename R=ISweeper> R* current() {
54+
return ts->get_level<R>(level);
55+
}
56+
template<typename R=ISweeper> R* fine() {
57+
return ts->get_level<R>(level+1);
58+
}
59+
template<typename R=ISweeper> R* coarse() {
60+
return ts->get_level<R>(level-1);
61+
}
62+
63+
ISweeper *operator*() { return current(); }
64+
bool operator==(leveliter i) { return level == i.level; }
65+
bool operator!=(leveliter i) { return level != i.level; }
66+
bool operator<=(leveliter i) { return level <= i.level; }
67+
bool operator>=(leveliter i) { return level >= i.level; }
68+
bool operator< (leveliter i) { return level < i.level; }
69+
bool operator> (leveliter i) { return level > i.level; }
70+
leveliter operator- (int i) { return leveliter(level-1, ts); }
71+
leveliter operator+ (int i) { return leveliter(level+1, ts); }
72+
void operator++() { level++; }
73+
void operator--() { level--; }
74+
};
75+
76+
leveliter finest() { return leveliter(nlevels()-1, this); }
77+
leveliter coarsest() { return leveliter(0, this); }
78+
79+
};
80+
81+
}
82+
83+
#endif

src/pfasst-encapsulated.hpp

Lines changed: 37 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,91 +7,69 @@
77

88
#include <vector>
99

10-
#include "pfasst.hpp"
10+
#include "pfasst-interfaces.hpp"
11+
#include "pfasst-quadrature.hpp"
1112

1213
using namespace std;
1314

1415
namespace pfasst {
1516

16-
template<typename T>
17-
vector<T> compute_nodes(unsigned int nnodes, string qtype) {
18-
vector<T> nodes(nnodes);
19-
20-
// ...
21-
22-
return nodes;
23-
}
24-
25-
typedef enum encaptype { solution, function } encaptype;
26-
27-
template<typename T>
28-
class matrix : public vector<T> {
29-
30-
public:
31-
unsigned int n, m;
32-
matrix() { }
33-
matrix(unsigned int n, unsigned int m) {
34-
zeros(n, m);
35-
}
36-
void zeros(unsigned int n, unsigned int m) {
37-
this->n = n; this->m = m;
38-
this->resize(n*m);
39-
// ...
40-
}
41-
T& operator()(unsigned int i, unsigned int j) {
42-
return (*this)[i*m+j];
43-
}
44-
};
17+
typedef enum EncapType { solution, function } EncapType;
4518

4619
//
4720
// encapsulation
4821
//
4922

50-
struct encapsulation {
51-
virtual ~encapsulation() { }
23+
class Encapsulation {
24+
public:
25+
virtual ~Encapsulation() { }
26+
27+
// required for interp/restrict helpers
28+
virtual void interpolate(const Encapsulation *) { }
29+
virtual void restrict(const Encapsulation *) { }
5230

5331
// required for time-parallel communications
54-
virtual unsigned int nbytes() { }
32+
virtual unsigned int nbytes() { return -1; }
5533
virtual void pack(char *buf) { }
5634
virtual void unpack(char *buf) { }
5735

58-
// required for interp/restrict helpers
59-
virtual void interpolate(const encapsulation *) { }
60-
virtual void restrict(const encapsulation *) { }
61-
6236
// required for host based encap helpers
6337
virtual void setval(double) { }
64-
virtual void copy(const encapsulation *) { }
65-
virtual void mat_apply(encapsulation dst[], double a, matrix m, const encapsulation src[]) { }
38+
virtual void copy(const Encapsulation *) { }
39+
virtual void saxpy(double a, const Encapsulation *) { }
40+
// virtual void mat_apply(encapsulation dst[], double a, matrix m, const encapsulation src[]) { }
6641
};
6742

68-
struct encapsulation_factory {
69-
virtual encapsulation* create(const encaptype) = 0;
43+
class EncapsulationFactory {
44+
public:
45+
virtual Encapsulation* create(const EncapType) = 0;
7046
};
7147

72-
7348
template<typename T>
74-
class encapsulated_sweeper_mixin : public isweeper {
75-
shared_ptr<vector<T>> nodes;
76-
shared_ptr<encapsulation_factory> encap;
49+
class EncapsulatedSweeperMixin : public ISweeper {
50+
vector<T> nodes;
51+
shared_ptr<EncapsulationFactory> factory;
7752

7853
public:
79-
vector<encapsulation*> q;
80-
vector<T>* get_nodes() { return nodes.get(); }
54+
void set_nodes(vector<T> nodes) { this->nodes = nodes; }
55+
const vector<T> get_nodes() const { return nodes; }
56+
57+
void set_factory(EncapsulationFactory* factory) { this->factory = shared_ptr<EncapsulationFactory>(factory); }
58+
EncapsulationFactory* get_factory() const { return factory.get(); }
8159

82-
virtual void set_q0(const encapsulation* q0) { }
83-
virtual encapsulation* get_qend() { }
60+
virtual void set_q0(const Encapsulation* q0) { throw NotImplementedYet("sweeper"); }
61+
virtual Encapsulation* get_qend() { throw NotImplementedYet("sweeper"); return NULL; }
8462
};
8563

8664
template<class T>
87-
class poly_interp_mixin : public T {
88-
virtual void interpolate(const isweeper*) { }
89-
virtual void restrict(const isweeper*) { }
65+
class PolyInterpMixin : public T {
66+
virtual void interpolate(const ISweeper*) { }
67+
virtual void restrict(const ISweeper*) { }
9068
};
9169

9270
template<typename T>
93-
struct vector_encapsulation : public vector<T>, public encapsulation {
94-
vector_encapsulation(int size) : vector<T>(size) { }
71+
struct VectorEncapsulation : public vector<T>, public Encapsulation {
72+
VectorEncapsulation(int size) : vector<T>(size) { }
9573
virtual unsigned int nbytes() const {
9674
return sizeof(T) * this->size();
9775
}
@@ -103,12 +81,13 @@ namespace pfasst {
10381
};
10482

10583
template<typename T>
106-
class vector_factory : public pfasst::encapsulation_factory {
84+
class VectorFactory : public EncapsulationFactory {
10785
int size;
10886
public:
109-
vector_factory(const int size) : size(size) { }
110-
encapsulation* create(const pfasst::encap_type) {
111-
return new vector_encapsulation<T>(size);
87+
int dofs() { return size; }
88+
VectorFactory(const int size) : size(size) { }
89+
Encapsulation* create(const EncapType) {
90+
return new VectorEncapsulation<T>(size);
11291
}
11392
};
11493

src/pfasst-imex.hpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
2+
#ifndef _PFASST_IMEX_HPP_
3+
#define _PFASST_IMEX_HPP_
4+
5+
#include <iostream>
6+
7+
#include "pfasst-encapsulated.hpp"
8+
#include "pfasst-quadrature.hpp"
9+
10+
using namespace std;
11+
12+
namespace pfasst {
13+
14+
template<typename timeT>
15+
class IMEX : public EncapsulatedSweeperMixin<timeT> {
16+
vector<Encapsulation*> Q, S, Fe, Fi;
17+
matrix<timeT> Smat, SEmat, SImat;
18+
19+
public:
20+
21+
~IMEX() {
22+
for (int m=0; m<Q.size(); m++) delete Q[m];
23+
for (int m=0; m<S.size(); m++) delete S[m];
24+
for (int m=0; m<Fe.size(); m++) delete Fe[m];
25+
for (int m=0; m<Fi.size(); m++) delete Fi[m];
26+
}
27+
28+
void set_q0(Encapsulation *q0)
29+
{
30+
Q[0]->copy(q0);
31+
}
32+
33+
void setup() {
34+
auto nodes = this->get_nodes();
35+
36+
Smat = compute_quadrature(nodes, nodes, 's');
37+
38+
SEmat = Smat; SImat = Smat;
39+
for (int m=0; m<nodes.size()-1; m++) {
40+
timeT dt = nodes[m+1] - nodes[m];
41+
SEmat(m, m) -= dt;
42+
SImat(m, m+1) -= dt;
43+
}
44+
45+
for (int m=0; m<nodes.size(); m++) {
46+
this->Q.push_back(this->get_factory()->create(solution));
47+
this->Fe.push_back(this->get_factory()->create(function));
48+
this->Fi.push_back(this->get_factory()->create(function));
49+
}
50+
51+
for (int m=0; m<nodes.size()-1; m++) {
52+
S.push_back(this->get_factory()->create(solution));
53+
}
54+
}
55+
56+
virtual void integrate(timeT t0, timeT dt0)
57+
{
58+
throw NotImplementedYet("imex integrate");
59+
}
60+
virtual void residual(timeT t0, timeT dt0)
61+
{
62+
throw NotImplementedYet("imex residual");
63+
}
64+
65+
virtual void sweep(timeT t0, timeT dt0)
66+
{
67+
const auto nodes = this->get_nodes();
68+
const int nnodes = nodes.size();
69+
70+
// integrate
71+
for (int n=0; n<nnodes-1; n++) {
72+
this->S[n]->setval(0.0);
73+
for (int m=0; m<nnodes; m++) {
74+
this->S[n]->saxpy(dt0 * this->SEmat(n,m), this->Fe[m]);
75+
this->S[n]->saxpy(dt0 * this->SImat(n,m), this->Fi[m]);
76+
}
77+
}
78+
79+
// sweep
80+
Encapsulation *rhs = this->get_factory()->create(solution);
81+
82+
timeT t = t0;
83+
for (int m=0; m<nnodes-1; m++) {
84+
timeT dt = dt0 * ( nodes[m+1] - nodes[m] );
85+
86+
rhs->copy(this->Q[m]);
87+
rhs->saxpy(dt, this->Fe[m]);
88+
rhs->saxpy(1.0, this->S[m]);
89+
this->f2comp(this->Fi[m+1], this->Q[m+1], t, dt, rhs);
90+
this->f1eval(this->Fe[m+1], this->Q[m+1], t+dt);
91+
92+
t += dt;
93+
}
94+
95+
delete rhs;
96+
}
97+
98+
virtual void predict(timeT t0, timeT dt0) {
99+
const auto nodes = this->get_nodes();
100+
const int nnodes = nodes.size();
101+
102+
this->f1eval(this->Fe[0], this->Q[0], t0);
103+
this->f2eval(this->Fi[0], this->Q[0], t0);
104+
105+
Encapsulation *rhs = this->get_factory()->create(solution);
106+
107+
timeT t = t0;
108+
for (int m=0; m<nnodes-1; m++) {
109+
timeT dt = dt0 * ( nodes[m+1] - nodes[m] );
110+
rhs->copy(this->Q[m]);
111+
rhs->saxpy(dt, this->Fe[m]);
112+
this->f2comp(this->Fi[m+1], this->Q[m+1], t, dt, rhs);
113+
this->f1eval(this->Fe[m+1], this->Q[m+1], t+dt);
114+
115+
t += dt;
116+
}
117+
118+
delete rhs;
119+
}
120+
121+
virtual void f1eval(Encapsulation *F, Encapsulation *Q, timeT t)
122+
{
123+
throw NotImplementedYet("imex (f1eval)");
124+
}
125+
126+
virtual void f2eval(Encapsulation *F, Encapsulation *Q, timeT t)
127+
{
128+
throw NotImplementedYet("imex (f2eval)");
129+
}
130+
131+
virtual void f2comp(Encapsulation *F, Encapsulation *Q, timeT t, timeT dt, Encapsulation* rhs)
132+
{
133+
throw NotImplementedYet("imex (f2comp)");
134+
}
135+
136+
};
137+
138+
}
139+
140+
#endif

0 commit comments

Comments
 (0)