Skip to content

Commit f3e1e79

Browse files
committed
Tidy up, add more namespaces.
1 parent 501ba25 commit f3e1e79

File tree

3 files changed

+192
-151
lines changed

3 files changed

+192
-151
lines changed

examples/ad/main.cpp

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,46 +29,49 @@ const int trat = 2;
2929
const int nsteps = 32;
3030
const double dt = 0.01;
3131

32+
typedef double scalar;
33+
3234
using namespace std;
33-
using pfasst::Encapsulation;
34-
using dvector = pfasst::VectorEncapsulation<double>;
35+
using pfasst::encap::Encapsulation;
36+
using dvector = pfasst::encap::VectorEncapsulation<scalar>;
3537

3638

3739
//
3840
// advection/diffusion sweeper
3941
//
4042

41-
template<typename timeT>
42-
class ADIMEX : public pfasst::IMEX<timeT> {
43+
template<typename time>
44+
class ADIMEX : public pfasst::imex::IMEX<time> {
4345

4446
fftw_plan ffft;
4547
fftw_plan ifft;
4648
fftw_complex* wk;
4749

48-
vector<complex<double>> ddx, lap;
50+
vector<complex<scalar>> ddx, lap;
4951

50-
double v = 1.0;
51-
double t0 = 1.0;
52-
double nu = 0.02;
52+
scalar v = 1.0;
53+
scalar t0 = 1.0;
54+
scalar nu = 0.02;
5355

5456
public:
5557

56-
ADIMEX(vector<timeT> nodes, pfasst::VectorFactory<timeT> *factory)
58+
ADIMEX(vector<time> nodes, pfasst::encap::VectorFactory<time> *factory)
5759
{
5860
this->set_nodes(nodes);
5961
this->set_factory(factory);
6062

6163
int nvars = factory->dofs();
6264

65+
// XXX: this fft stuff almost certainly DOES NOT work when 'scalar' is not 'double'
6366
wk = fftw_alloc_complex(nvars);
6467
ffft = fftw_plan_dft_1d(nvars, wk, wk, FFTW_FORWARD, FFTW_ESTIMATE);
6568
ifft = fftw_plan_dft_1d(nvars, wk, wk, FFTW_BACKWARD, FFTW_ESTIMATE);
6669

6770
ddx.resize(nvars);
6871
lap.resize(nvars);
6972
for (int i=0; i<nvars; i++) {
70-
double kx = two_pi * ( i <= nvars/2 ? i : i-nvars );
71-
ddx[i] = complex<double>(0.0, 1.0) * kx;
73+
scalar kx = two_pi * ( i <= nvars/2 ? i : i-nvars );
74+
ddx[i] = complex<scalar>(0.0, 1.0) * kx;
7275
lap[i] = (kx*kx < 1e-13) ? 0.0 : -kx*kx;
7376
}
7477

@@ -80,47 +83,47 @@ class ADIMEX : public pfasst::IMEX<timeT> {
8083
fftw_free(wk);
8184
}
8285

83-
void exact(dvector& q, double t)
86+
void exact(dvector& q, scalar t)
8487
{
8588
int n = q.size();
86-
double a = 1.0/sqrt(4*pi*nu*(t+t0));
89+
scalar a = 1.0/sqrt(4*pi*nu*(t+t0));
8790

8891
for (int i=0; i<n; i++)
8992
q[i] = 0.0;
9093

9194
for (int ii=-2; ii<3; ii++) {
9295
for (int i=0; i<n; i++) {
93-
double x = double(i)/n - 0.5 + ii - t * v;
96+
scalar x = scalar(i)/n - 0.5 + ii - t * v;
9497
q[i] += a * exp(-x*x/(4*nu*(t+t0)));
9598
}
9699
}
97100
}
98101

99-
void sweep(timeT t, timeT dt)
102+
void sweep(time t, time dt)
100103
{
101-
pfasst::IMEX<timeT>::sweep(t, dt);
104+
pfasst::imex::IMEX<time>::sweep(t, dt);
102105

103106
auto& qend = *dynamic_cast<dvector*>(this->get_qend());
104107
auto qex = dvector(qend.size());
105108

106109
exact(qex, t+dt);
107110

108-
double max = 0.0;
111+
scalar max = 0.0;
109112
for (int i=0; i<qend.size(); i++) {
110-
double d = abs(qend[i]-qex[i]);
113+
scalar d = abs(qend[i]-qex[i]);
111114
if (d > max)
112115
max = d;
113116
}
114117
cout << "err: " << max << endl;
115118
}
116119

117-
void f1eval(Encapsulation *F, Encapsulation *Q, timeT t)
120+
void f1eval(Encapsulation *F, Encapsulation *Q, time t)
118121
{
119122
auto& f = *dynamic_cast<dvector*>(F);
120123
auto& q = *dynamic_cast<dvector*>(Q);
121124

122-
complex<double>* z = reinterpret_cast<complex<double>*>(wk);
123-
double c = -v / double(q.size());
125+
complex<scalar>* z = reinterpret_cast<complex<scalar>*>(wk);
126+
scalar c = -v / scalar(q.size());
124127

125128
copy(q.begin(), q.end(), z);
126129
fftw_execute_dft(ffft, wk, wk);
@@ -132,13 +135,13 @@ class ADIMEX : public pfasst::IMEX<timeT> {
132135
f[i] = real(z[i]);
133136
}
134137

135-
void f2eval(Encapsulation *F, Encapsulation *Q, timeT t)
138+
void f2eval(Encapsulation *F, Encapsulation *Q, time t)
136139
{
137140
auto& f = *dynamic_cast<dvector*>(F);
138141
auto& q = *dynamic_cast<dvector*>(Q);
139142

140-
complex<double>* z = reinterpret_cast<complex<double>*>(wk);
141-
double c = nu / double(q.size());
143+
complex<scalar>* z = reinterpret_cast<complex<scalar>*>(wk);
144+
scalar c = nu / scalar(q.size());
142145

143146
copy(q.begin(), q.end(), z);
144147
fftw_execute_dft(ffft, wk, wk);
@@ -150,18 +153,18 @@ class ADIMEX : public pfasst::IMEX<timeT> {
150153
f[i] = real(z[i]);
151154
}
152155

153-
void f2comp(Encapsulation *F, Encapsulation *Q, timeT t, timeT dt, Encapsulation *RHS)
156+
void f2comp(Encapsulation *F, Encapsulation *Q, time t, time dt, Encapsulation *RHS)
154157
{
155158
auto& f = *dynamic_cast<dvector*>(F);
156159
auto& q = *dynamic_cast<dvector*>(Q);
157160
auto& rhs = *dynamic_cast<dvector*>(RHS);
158161

159-
complex<double>* z = reinterpret_cast<complex<double>*>(wk);
162+
complex<scalar>* z = reinterpret_cast<complex<scalar>*>(wk);
160163

161164
copy(rhs.begin(), rhs.end(), z);
162165
fftw_execute_dft(ffft, wk, wk);
163166
for (int i=0; i<q.size(); i++)
164-
z[i] /= (1.0 - nu * dt * lap[i]) * double(q.size());
167+
z[i] /= (1.0 - nu * dt * lap[i]) * scalar(q.size());
165168
fftw_execute_dft(ifft, wk, wk);
166169

167170
for (int i=0; i<q.size(); i++) {
@@ -184,11 +187,11 @@ int main(int argc, char **argv)
184187
int nnodes = 5;
185188

186189
if (nlevs == 1) {
187-
pfasst::SDC<double> sdc;
190+
pfasst::SDC<scalar> sdc;
188191

189-
auto nodes = pfasst::compute_nodes<double>(nnodes, "gauss-lobatto");
190-
auto* factory = new pfasst::VectorFactory<double>(ndofs);
191-
auto* sweeper = new ADIMEX<double>(nodes, factory);
192+
auto nodes = pfasst::compute_nodes<scalar>(nnodes, "gauss-lobatto");
193+
auto* factory = new pfasst::encap::VectorFactory<scalar>(ndofs);
194+
auto* sweeper = new ADIMEX<scalar>(nodes, factory);
192195

193196
sdc.add_level(sweeper);
194197
sdc.set_duration(dt, nsteps, 4);

src/pfasst-encapsulated.hpp

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
using namespace std;
1515

16-
namespace pfasst {
16+
namespace pfasst { namespace encap {
1717

1818
typedef enum EncapType { solution, function } EncapType;
1919

@@ -26,19 +26,38 @@ namespace pfasst {
2626
virtual ~Encapsulation() { }
2727

2828
// required for interp/restrict helpers
29-
virtual void interpolate(const Encapsulation *) { }
30-
virtual void restrict(const Encapsulation *) { }
29+
virtual void interpolate(const Encapsulation *) {
30+
throw NotImplementedYet("mlsdc/pfasst");
31+
}
32+
virtual void restrict(const Encapsulation *) {
33+
throw NotImplementedYet("mlsdc/pfasst");
34+
}
3135

3236
// required for time-parallel communications
33-
virtual unsigned int nbytes() { return -1; }
34-
virtual void pack(char *buf) { }
35-
virtual void unpack(char *buf) { }
37+
virtual unsigned int nbytes() {
38+
throw NotImplementedYet("pfasst");
39+
}
40+
virtual void send() {
41+
throw NotImplementedYet("pfasst");
42+
}
43+
virtual void recv() {
44+
throw NotImplementedYet("pfasst");
45+
}
3646

3747
// required for host based encap helpers
38-
virtual void setval(double) { throw NotImplementedYet("encap setval"); }
39-
virtual void copy(const Encapsulation *) { throw NotImplementedYet("encap copy"); }
40-
virtual void saxpy(double a, const Encapsulation *) { throw NotImplementedYet("encap saxpy"); }
41-
// virtual void mat_apply(encapsulation dst[], double a, matrix m, const encapsulation src[]) { }
48+
virtual void setval(double) {
49+
throw NotImplementedYet("encap");
50+
}
51+
virtual void copy(const Encapsulation *) {
52+
throw NotImplementedYet("encap");
53+
}
54+
virtual void saxpy(double a, const Encapsulation *) {
55+
throw NotImplementedYet("encap");
56+
}
57+
// virtual void mat_apply(Encapsulation dst[], double a, matrix m,
58+
// const Encapsulation src[]) {
59+
// throw NotImplementedYet("encap");
60+
// }
4261
};
4362

4463
class EncapsulationFactory {
@@ -52,17 +71,31 @@ namespace pfasst {
5271
shared_ptr<EncapsulationFactory> factory;
5372

5473
public:
55-
void set_nodes(vector<T> nodes) { this->nodes = nodes; }
56-
const vector<T> get_nodes() const { return nodes; }
57-
58-
void set_factory(EncapsulationFactory* factory) { this->factory = shared_ptr<EncapsulationFactory>(factory); }
59-
EncapsulationFactory* get_factory() const { return factory.get(); }
74+
void set_nodes(vector<T> nodes) {
75+
this->nodes = nodes;
76+
}
77+
const vector<T> get_nodes() const {
78+
return nodes;
79+
}
6080

61-
virtual void set_q0(const Encapsulation* q0) { throw NotImplementedYet("sweeper"); }
62-
virtual Encapsulation* get_qend() { throw NotImplementedYet("sweeper"); return NULL; }
81+
void set_factory(EncapsulationFactory* factory) {
82+
this->factory = shared_ptr<EncapsulationFactory>(factory);
83+
}
84+
EncapsulationFactory* get_factory() const {
85+
return factory.get();
86+
}
6387

64-
virtual void advance() { }
88+
virtual void set_q0(const Encapsulation* q0) {
89+
throw NotImplementedYet("sweeper");
90+
}
91+
virtual Encapsulation* get_qend() {
92+
throw NotImplementedYet("sweeper");
93+
return NULL;
94+
}
6595

96+
virtual void advance() {
97+
throw NotImplementedYet("sweeper");
98+
}
6699
};
67100

68101
template<class T>
@@ -104,6 +137,6 @@ namespace pfasst {
104137
}
105138
};
106139

107-
}
140+
} }
108141

109142
#endif

0 commit comments

Comments
 (0)