Skip to content

Commit 4fa799c

Browse files
committed
First stab at a class structure.
1 parent c3820f5 commit 4fa799c

File tree

3 files changed

+229
-1
lines changed

3 files changed

+229
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
PFASST
22
======
33

4-
C++ PFASST
4+
Proposed class structure for a C++ PFASST implementation.

src/pfasst-encapsulated.hpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Host based encapsulated base sweeper.
3+
*/
4+
5+
#ifndef _PFASST_ENCAPSULATED_HPP_
6+
#define _PFASST_ENCAPSULATED_HPP_
7+
8+
#include <vector>
9+
10+
#include "pfasst.hpp"
11+
12+
using namespace std;
13+
14+
namespace pfasst {
15+
16+
template<typename T>
17+
vector<T> compute_nodes(unsigned int nnodes, string qtype) {
18+
vector<T> nodes(nnodes);
19+
20+
// ...
21+
22+
return nodes;
23+
}
24+
25+
typedef enum encaptype { solution, function } encaptype;
26+
27+
template<typename T>
28+
class matrix : public vector<T> {
29+
30+
public:
31+
unsigned int n, m;
32+
matrix() { }
33+
matrix(unsigned int n, unsigned int m) {
34+
zeros(n, m);
35+
}
36+
void zeros(unsigned int n, unsigned int m) {
37+
this->n = n; this->m = m;
38+
this->resize(n*m);
39+
// ...
40+
}
41+
T& operator()(unsigned int i, unsigned int j) {
42+
return (*this)[i*m+j];
43+
}
44+
};
45+
46+
//
47+
// encapsulation
48+
//
49+
50+
struct encapsulation {
51+
virtual ~encapsulation() { }
52+
53+
// required for time-parallel communications
54+
virtual unsigned int nbytes() { }
55+
virtual void pack(char *buf) { }
56+
virtual void unpack(char *buf) { }
57+
58+
// required for interp/restrict helpers
59+
virtual void interpolate(const encapsulation *) { }
60+
virtual void restrict(const encapsulation *) { }
61+
62+
// required for host based encap helpers
63+
virtual void setval(double) { }
64+
virtual void copy(const encapsulation *) { }
65+
virtual void mat_apply(encapsulation dst[], double a, matrix m, const encapsulation src[]) { }
66+
};
67+
68+
struct encapsulation_factory {
69+
virtual encapsulation* create(const encaptype) = 0;
70+
};
71+
72+
73+
template<typename T>
74+
class encapsulated_sweeper_mixin : public isweeper {
75+
shared_ptr<vector<T>> nodes;
76+
shared_ptr<encapsulation_factory> encap;
77+
78+
public:
79+
vector<encapsulation*> q;
80+
vector<T>* get_nodes() { return nodes.get(); }
81+
82+
virtual void set_q0(const encapsulation* q0) { }
83+
virtual encapsulation* get_qend() { }
84+
};
85+
86+
template<class T>
87+
class poly_interp_mixin : public T {
88+
virtual void interpolate(const isweeper*) { }
89+
virtual void restrict(const isweeper*) { }
90+
};
91+
92+
template<typename T>
93+
struct vector_encapsulation : public vector<T>, public encapsulation {
94+
vector_encapsulation(int size) : vector<T>(size) { }
95+
virtual unsigned int nbytes() const {
96+
return sizeof(T) * this->size();
97+
}
98+
void setval(double v) {
99+
for (int i=0; i<this->size(); i++)
100+
this->data()[i] = v;
101+
}
102+
// ...
103+
};
104+
105+
template<typename T>
106+
class vector_factory : public pfasst::encapsulation_factory {
107+
int size;
108+
public:
109+
vector_factory(const int size) : size(size) { }
110+
encapsulation* create(const pfasst::encap_type) {
111+
return new vector_encapsulation<T>(size);
112+
}
113+
};
114+
115+
}
116+
117+
#endif

src/pfasst.hpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Interfaces for SDC/MLSDC/PFASST algorithms.
3+
*/
4+
5+
#ifndef _PFASST_HPP_
6+
#define _PFASST_HPP_
7+
8+
#include <deque>
9+
#include <exception>
10+
#include <iostream>
11+
#include <iterator>
12+
#include <memory>
13+
#include <string>
14+
15+
using namespace std;
16+
17+
namespace pfasst {
18+
19+
//
20+
// sdc sweeper interface
21+
//
22+
23+
struct isweeper {
24+
virtual void setup() { }
25+
virtual ~isweeper() { }
26+
27+
// required for all sdc schemes
28+
virtual void sweep(double t, double dt) = 0;
29+
virtual void predict(double t, double dt) = 0;
30+
31+
// required for multi-level sdc and pfasst schemes
32+
virtual void interpolate(const isweeper*) { }
33+
virtual void restrict(const isweeper*) { }
34+
35+
// required for pfasst schemes
36+
virtual void post() { }
37+
virtual void send() { }
38+
virtual void recv() { }
39+
};
40+
41+
//
42+
// pfasst controller
43+
//
44+
45+
class pfasst {
46+
deque<shared_ptr<isweeper>> levels;
47+
48+
public:
49+
int nstep, niter;
50+
double dt, t;
51+
52+
void add_level(isweeper *sweeper, bool coarse) {
53+
if (coarse)
54+
levels.push_front(shared_ptr<isweeper>(sweeper));
55+
else
56+
levels.push_back(shared_ptr<isweeper>(sweeper));
57+
}
58+
59+
template<typename R=isweeper> R* get_level(int level) {
60+
return dynamic_cast<R*>(levels[level].get());
61+
}
62+
63+
int nlevels() { return levels.size(); }
64+
void setup();
65+
void run();
66+
67+
void predictor();
68+
void iteration();
69+
70+
struct leveliter {
71+
int level;
72+
pfasst *pf;
73+
74+
leveliter(int level, pfasst *pf) : level(level), pf(pf) {}
75+
76+
template<typename R=isweeper> R* current() {
77+
return pf->get_level<R>(level);
78+
}
79+
template<typename R=isweeper> R* fine() {
80+
return pf->get_level<R>(level+1);
81+
}
82+
template<typename R=isweeper> R* coarse() {
83+
return pf->get_level<R>(level-1);
84+
}
85+
86+
isweeper *operator*() { return current(); }
87+
bool operator==(leveliter i) { return level == i.level; }
88+
bool operator!=(leveliter i) { return level != i.level; }
89+
bool operator<=(leveliter i) { return level <= i.level; }
90+
bool operator>=(leveliter i) { return level >= i.level; }
91+
bool operator< (leveliter i) { return level < i.level; }
92+
bool operator> (leveliter i) { return level > i.level; }
93+
leveliter operator- (int i) { return leveliter(level-1, pf); }
94+
leveliter operator+ (int i) { return leveliter(level+1, pf); }
95+
void operator++() { level++; }
96+
void operator--() { level--; }
97+
};
98+
99+
leveliter finest() { return leveliter(nlevels()-1, this); }
100+
leveliter coarsest() { return leveliter(0, this); }
101+
102+
leveliter cycle_down(leveliter levels, double t, double dt);
103+
leveliter cycle_up(leveliter levels, double t, double dt);
104+
leveliter cycle_bottom(leveliter levels, double t, double dt);
105+
leveliter cycle_top(leveliter levels, double t, double dt);
106+
leveliter cycle_v(leveliter levels, double t, double dt);
107+
};
108+
109+
}
110+
111+
#endif

0 commit comments

Comments
 (0)