Skip to content

Commit 1fa228e

Browse files
committed
Split out VectorEncapsulation stuff.
1 parent e36d335 commit 1fa228e

File tree

4 files changed

+153
-131
lines changed

4 files changed

+153
-131
lines changed

examples/ad/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <pfasst.hpp>
1010
#include <pfasst-imex.hpp>
11+
#include <pfasst-vector.hpp>
1112

1213
#include <fftw3.h>
1314

src/pfasst-encapsulated.hpp

Lines changed: 94 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -5,138 +5,106 @@
55
#ifndef _PFASST_ENCAPSULATED_HPP_
66
#define _PFASST_ENCAPSULATED_HPP_
77

8-
#include <algorithm>
98
#include <vector>
109

1110
#include "pfasst-interfaces.hpp"
1211
#include "pfasst-quadrature.hpp"
1312

1413
using namespace std;
1514

16-
namespace pfasst { namespace encap {
17-
18-
typedef enum EncapType { solution, function } EncapType;
19-
20-
//
21-
// encapsulation
22-
//
23-
24-
class Encapsulation {
25-
public:
26-
virtual ~Encapsulation() { }
27-
28-
// required for interp/restrict helpers
29-
virtual void interpolate(const Encapsulation *) {
30-
throw NotImplementedYet("mlsdc/pfasst");
31-
}
32-
virtual void restrict(const Encapsulation *) {
33-
throw NotImplementedYet("mlsdc/pfasst");
34-
}
35-
36-
// required for time-parallel communications
37-
virtual unsigned int nbytes() {
38-
throw NotImplementedYet("pfasst");
39-
}
40-
virtual void send() {
41-
throw NotImplementedYet("pfasst");
42-
}
43-
virtual void recv() {
44-
throw NotImplementedYet("pfasst");
45-
}
46-
47-
// required for host based encap helpers
48-
virtual void setval(double) {
49-
throw NotImplementedYet("encap");
50-
}
51-
virtual void copy(const Encapsulation *) {
52-
throw NotImplementedYet("encap");
53-
}
54-
virtual void saxpy(double a, const Encapsulation *) {
55-
throw NotImplementedYet("encap");
56-
}
57-
// virtual void mat_apply(Encapsulation dst[], double a, matrix m,
58-
// const Encapsulation src[]) {
59-
// throw NotImplementedYet("encap");
60-
// }
61-
};
62-
63-
class EncapsulationFactory {
64-
public:
65-
virtual Encapsulation* create(const EncapType) = 0;
66-
};
67-
68-
template<typename T>
69-
class EncapsulatedSweeperMixin : public ISweeper {
70-
vector<T> nodes;
71-
shared_ptr<EncapsulationFactory> factory;
72-
73-
public:
74-
void set_nodes(vector<T> nodes) {
75-
this->nodes = nodes;
76-
}
77-
const vector<T> get_nodes() const {
78-
return nodes;
79-
}
80-
81-
void set_factory(EncapsulationFactory* factory) {
82-
this->factory = shared_ptr<EncapsulationFactory>(factory);
83-
}
84-
EncapsulationFactory* get_factory() const {
85-
return factory.get();
86-
}
87-
88-
virtual void set_q0(const Encapsulation* q0) {
89-
throw NotImplementedYet("sweeper");
90-
}
91-
virtual Encapsulation* get_qend() {
92-
throw NotImplementedYet("sweeper");
93-
return NULL;
94-
}
95-
96-
virtual void advance() {
97-
throw NotImplementedYet("sweeper");
98-
}
99-
};
100-
101-
template<class T>
102-
class PolyInterpMixin : public T {
103-
virtual void interpolate(const ISweeper*) { }
104-
virtual void restrict(const ISweeper*) { }
105-
};
106-
107-
template<typename scalar>
108-
struct VectorEncapsulation : public vector<scalar>, public Encapsulation {
109-
VectorEncapsulation(int size) : vector<scalar>(size) { }
110-
virtual unsigned int nbytes() const {
111-
return sizeof(scalar) * this->size();
112-
}
113-
void setval(scalar v) {
114-
for (int i=0; i<this->size(); i++)
115-
(*this)[i] = v;
116-
}
117-
void copy(const Encapsulation* X) {
118-
const auto* x = dynamic_cast<const VectorEncapsulation*>(X);
119-
for (int i=0; i<this->size(); i++)
120-
(*this)[i] = (*x)[i];
121-
}
122-
void saxpy(double a, const Encapsulation *X) {
123-
const auto* x = dynamic_cast<const VectorEncapsulation*>(X);
124-
for (int i=0; i<this->size(); i++)
125-
(*this)[i] += a * (*x)[i];
126-
}
127-
};
128-
129-
template<typename T>
130-
class VectorFactory : public EncapsulationFactory {
131-
int size;
132-
public:
133-
int dofs() { return size; }
134-
VectorFactory(const int size) : size(size) { }
135-
Encapsulation* create(const EncapType) {
136-
return new VectorEncapsulation<T>(size);
137-
}
138-
};
139-
140-
} }
15+
namespace pfasst {
16+
namespace encap {
17+
18+
typedef enum EncapType { solution, function } EncapType;
19+
20+
//
21+
// encapsulation
22+
//
23+
24+
class Encapsulation {
25+
public:
26+
virtual ~Encapsulation() { }
27+
28+
// required for interp/restrict helpers
29+
virtual void interpolate(const Encapsulation *) {
30+
throw NotImplementedYet("mlsdc/pfasst");
31+
}
32+
virtual void restrict(const Encapsulation *) {
33+
throw NotImplementedYet("mlsdc/pfasst");
34+
}
35+
36+
// required for time-parallel communications
37+
virtual unsigned int nbytes() {
38+
throw NotImplementedYet("pfasst");
39+
}
40+
virtual void send() {
41+
throw NotImplementedYet("pfasst");
42+
}
43+
virtual void recv() {
44+
throw NotImplementedYet("pfasst");
45+
}
46+
47+
// required for host based encap helpers
48+
virtual void setval(double) {
49+
throw NotImplementedYet("encap");
50+
}
51+
virtual void copy(const Encapsulation *) {
52+
throw NotImplementedYet("encap");
53+
}
54+
virtual void saxpy(double a, const Encapsulation *) {
55+
throw NotImplementedYet("encap");
56+
}
57+
// virtual void mat_apply(Encapsulation dst[], double a, matrix m,
58+
// const Encapsulation src[]) {
59+
// throw NotImplementedYet("encap");
60+
// }
61+
};
62+
63+
class EncapsulationFactory {
64+
public:
65+
virtual Encapsulation* create(const EncapType) = 0;
66+
};
67+
68+
template<typename T>
69+
class EncapsulatedSweeperMixin : public ISweeper {
70+
vector<T> nodes;
71+
shared_ptr<EncapsulationFactory> factory;
72+
73+
public:
74+
void set_nodes(vector<T> nodes) {
75+
this->nodes = nodes;
76+
}
77+
const vector<T> get_nodes() const {
78+
return nodes;
79+
}
80+
81+
void set_factory(EncapsulationFactory* factory) {
82+
this->factory = shared_ptr<EncapsulationFactory>(factory);
83+
}
84+
EncapsulationFactory* get_factory() const {
85+
return factory.get();
86+
}
87+
88+
virtual void set_q0(const Encapsulation* q0) {
89+
throw NotImplementedYet("sweeper");
90+
}
91+
virtual Encapsulation* get_qend() {
92+
throw NotImplementedYet("sweeper");
93+
return NULL;
94+
}
95+
96+
virtual void advance() {
97+
throw NotImplementedYet("sweeper");
98+
}
99+
};
100+
101+
template<class T>
102+
class PolyInterpMixin : public T {
103+
virtual void interpolate(const ISweeper*) { }
104+
virtual void restrict(const ISweeper*) { }
105+
};
106+
107+
}
108+
}
141109

142110
#endif

src/pfasst-sdc.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,16 @@
99

1010
namespace pfasst {
1111

12-
template<typename timeT>
13-
class SDC : public Controller<timeT> {
12+
template<typename time>
13+
class SDC : public Controller<time> {
1414
public:
1515
void run() {
1616
ISweeper& swp = *this->get_level(0);
1717
for (int nstep=0; nstep<this->nsteps; nstep++) {
18-
timeT t = nstep * this->dt;
18+
time t = nstep * this->dt;
1919
swp.predict(t, this->dt);
20-
for (int niter=0; niter<this->niters; niter++) {
20+
for (int niter=0; niter<this->niters; niter++)
2121
swp.sweep(t, this->dt);
22-
}
2322
swp.advance();
2423
}
2524
}

src/pfasst-vector.hpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Host based encapsulated base sweeper.
3+
*/
4+
5+
#ifndef _PFASST_VECTOR_HPP_
6+
#define _PFASST_VECTOR_HPP_
7+
8+
#include <algorithm>
9+
#include <vector>
10+
11+
#include "pfasst-encapsulated.hpp"
12+
13+
using namespace std;
14+
15+
namespace pfasst {
16+
namespace encap {
17+
18+
template<typename scalar>
19+
struct VectorEncapsulation : public vector<scalar>, public Encapsulation {
20+
VectorEncapsulation(int size) : vector<scalar>(size) { }
21+
virtual unsigned int nbytes() const {
22+
return sizeof(scalar) * this->size();
23+
}
24+
void setval(scalar v) {
25+
for (int i=0; i<this->size(); i++)
26+
(*this)[i] = v;
27+
}
28+
void copy(const Encapsulation* X) {
29+
const auto* x = dynamic_cast<const VectorEncapsulation*>(X);
30+
for (int i=0; i<this->size(); i++)
31+
(*this)[i] = (*x)[i];
32+
}
33+
void saxpy(double a, const Encapsulation *X) {
34+
const auto* x = dynamic_cast<const VectorEncapsulation*>(X);
35+
for (int i=0; i<this->size(); i++)
36+
(*this)[i] += a * (*x)[i];
37+
}
38+
};
39+
40+
template<typename T>
41+
class VectorFactory : public EncapsulationFactory {
42+
int size;
43+
public:
44+
int dofs() { return size; }
45+
VectorFactory(const int size) : size(size) { }
46+
Encapsulation* create(const EncapType) {
47+
return new VectorEncapsulation<T>(size);
48+
}
49+
};
50+
51+
}
52+
}
53+
54+
#endif

0 commit comments

Comments
 (0)