Skip to content

Commit 1f34441

Browse files
committed
Add advance() interface, flesh out SDC, add example (it works!).
1 parent e60f63a commit 1f34441

File tree

5 files changed

+244
-11
lines changed

5 files changed

+244
-11
lines changed

examples/ad/main.cpp

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/*
2+
* Advection/diffusion example using the encapsulated IMEX sweeper.
3+
*/
4+
5+
#include <algorithm>
6+
#include <cmath>
7+
#include <complex>
8+
9+
#include <pfasst.hpp>
10+
#include <pfasst-imex.hpp>
11+
12+
#include <fftw3.h>
13+
14+
#define pi 3.1415926535897932385
15+
#define two_pi 6.2831853071795864769
16+
17+
using namespace std;
18+
19+
//
20+
// config
21+
//
22+
23+
const int nlevs = 1;
24+
const int ndofs = 512;
25+
const int nnodes = 5;
26+
const int xrat = 2;
27+
const int trat = 2;
28+
29+
const int nsteps = 32;
30+
const double dt = 0.01;
31+
32+
using namespace std;
33+
using pfasst::Encapsulation;
34+
using dvector = pfasst::VectorEncapsulation<double>;
35+
36+
37+
//
38+
// advection/diffusion sweeper
39+
//
40+
41+
template<typename timeT>
42+
class ADIMEX : public pfasst::IMEX<timeT> {
43+
44+
fftw_plan ffft;
45+
fftw_plan ifft;
46+
fftw_complex* wk;
47+
48+
vector<complex<double>> ddx, lap;
49+
50+
double v = 1.0;
51+
double t0 = 1.0;
52+
double nu = 0.02;
53+
54+
public:
55+
56+
ADIMEX(vector<timeT> nodes, pfasst::VectorFactory<timeT> *factory)
57+
{
58+
this->set_nodes(nodes);
59+
this->set_factory(factory);
60+
61+
int nvars = factory->dofs();
62+
63+
wk = fftw_alloc_complex(nvars);
64+
ffft = fftw_plan_dft_1d(nvars, wk, wk, FFTW_FORWARD, FFTW_ESTIMATE);
65+
ifft = fftw_plan_dft_1d(nvars, wk, wk, FFTW_BACKWARD, FFTW_ESTIMATE);
66+
67+
ddx.resize(nvars);
68+
lap.resize(nvars);
69+
for (int i=0; i<nvars; i++) {
70+
double kx = two_pi * ( i <= nvars/2 ? i : i-nvars );
71+
ddx[i] = complex<double>(0.0, 1.0) * kx;
72+
lap[i] = (kx*kx < 1e-13) ? 0.0 : -kx*kx;
73+
}
74+
75+
}
76+
77+
~ADIMEX() {
78+
fftw_destroy_plan(ffft);
79+
fftw_destroy_plan(ifft);
80+
fftw_free(wk);
81+
}
82+
83+
void exact(dvector& q, double t)
84+
{
85+
int n = q.size();
86+
double a = 1.0/sqrt(4*pi*nu*(t+t0));
87+
88+
for (int i=0; i<n; i++)
89+
q[i] = 0.0;
90+
91+
for (int ii=-2; ii<3; ii++) {
92+
for (int i=0; i<n; i++) {
93+
double x = double(i)/n - 0.5 + ii - t * v;
94+
q[i] += a * exp(-x*x/(4*nu*(t+t0)));
95+
}
96+
}
97+
}
98+
99+
void sweep(timeT t, timeT dt)
100+
{
101+
pfasst::IMEX<timeT>::sweep(t, dt);
102+
103+
auto& qend = *dynamic_cast<dvector*>(this->get_qend());
104+
auto qex = dvector(qend.size());
105+
106+
exact(qex, t+dt);
107+
108+
double max = 0.0;
109+
for (int i=0; i<qend.size(); i++) {
110+
double d = abs(qend[i]-qex[i]);
111+
if (d > max)
112+
max = d;
113+
}
114+
cout << "err: " << max << endl;
115+
}
116+
117+
void f1eval(Encapsulation *F, Encapsulation *Q, timeT t)
118+
{
119+
auto& f = *dynamic_cast<dvector*>(F);
120+
auto& q = *dynamic_cast<dvector*>(Q);
121+
122+
complex<double>* z = reinterpret_cast<complex<double>*>(wk);
123+
double c = -v / double(q.size());
124+
125+
copy(q.begin(), q.end(), z);
126+
fftw_execute_dft(ffft, wk, wk);
127+
for (int i=0; i<q.size(); i++)
128+
z[i] *= c * ddx[i];
129+
fftw_execute_dft(ifft, wk, wk);
130+
131+
for (int i=0; i<q.size(); i++)
132+
f[i] = real(z[i]);
133+
}
134+
135+
void f2eval(Encapsulation *F, Encapsulation *Q, timeT t)
136+
{
137+
auto& f = *dynamic_cast<dvector*>(F);
138+
auto& q = *dynamic_cast<dvector*>(Q);
139+
140+
complex<double>* z = reinterpret_cast<complex<double>*>(wk);
141+
double c = nu / double(q.size());
142+
143+
copy(q.begin(), q.end(), z);
144+
fftw_execute_dft(ffft, wk, wk);
145+
for (int i=0; i<q.size(); i++)
146+
z[i] *= c * lap[i];
147+
fftw_execute_dft(ifft, wk, wk);
148+
149+
for (int i=0; i<q.size(); i++)
150+
f[i] = real(z[i]);
151+
}
152+
153+
void f2comp(Encapsulation *F, Encapsulation *Q, timeT t, timeT dt, Encapsulation *RHS)
154+
{
155+
auto& f = *dynamic_cast<dvector*>(F);
156+
auto& q = *dynamic_cast<dvector*>(Q);
157+
auto& rhs = *dynamic_cast<dvector*>(RHS);
158+
159+
complex<double>* z = reinterpret_cast<complex<double>*>(wk);
160+
161+
copy(rhs.begin(), rhs.end(), z);
162+
fftw_execute_dft(ffft, wk, wk);
163+
for (int i=0; i<q.size(); i++)
164+
z[i] /= (1.0 - nu * dt * lap[i]) * double(q.size());
165+
fftw_execute_dft(ifft, wk, wk);
166+
167+
for (int i=0; i<q.size(); i++) {
168+
q[i] = real(z[i]);
169+
f[i] = (q[i] - rhs[i]) / dt;
170+
}
171+
172+
}
173+
174+
};
175+
176+
177+
//
178+
// main
179+
//
180+
181+
int main(int argc, char **argv)
182+
{
183+
int ndofs = 512;
184+
int nnodes = 5;
185+
186+
if (nlevs == 1) {
187+
pfasst::SDC<double> sdc;
188+
189+
auto nodes = pfasst::compute_nodes<double>(nnodes, "gauss-lobatto");
190+
auto* factory = new pfasst::VectorFactory<double>(ndofs);
191+
auto* sweeper = new ADIMEX<double>(nodes, factory);
192+
193+
sdc.add_level(sweeper);
194+
sdc.set_duration(dt, nsteps, 4);
195+
sdc.setup();
196+
197+
dvector q0(ndofs);
198+
sweeper->exact(q0, 0.0);
199+
sweeper->set_q0(&q0);
200+
201+
sdc.run();
202+
} else {
203+
// ...
204+
}
205+
206+
return 0;
207+
}

src/pfasst-encapsulated.hpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#ifndef _PFASST_ENCAPSULATED_HPP_
66
#define _PFASST_ENCAPSULATED_HPP_
77

8+
#include <algorithm>
89
#include <vector>
910

1011
#include "pfasst-interfaces.hpp"
@@ -34,9 +35,9 @@ namespace pfasst {
3435
virtual void unpack(char *buf) { }
3536

3637
// required for host based encap helpers
37-
virtual void setval(double) { }
38-
virtual void copy(const Encapsulation *) { }
39-
virtual void saxpy(double a, const Encapsulation *) { }
38+
virtual void setval(double) { throw NotImplementedYet("encap setval"); }
39+
virtual void copy(const Encapsulation *) { throw NotImplementedYet("encap copy"); }
40+
virtual void saxpy(double a, const Encapsulation *) { throw NotImplementedYet("encap saxpy"); }
4041
// virtual void mat_apply(encapsulation dst[], double a, matrix m, const encapsulation src[]) { }
4142
};
4243

@@ -59,6 +60,9 @@ namespace pfasst {
5960

6061
virtual void set_q0(const Encapsulation* q0) { throw NotImplementedYet("sweeper"); }
6162
virtual Encapsulation* get_qend() { throw NotImplementedYet("sweeper"); return NULL; }
63+
64+
virtual void advance() { }
65+
6266
};
6367

6468
template<class T>
@@ -67,17 +71,26 @@ namespace pfasst {
6771
virtual void restrict(const ISweeper*) { }
6872
};
6973

70-
template<typename T>
71-
struct VectorEncapsulation : public vector<T>, public Encapsulation {
72-
VectorEncapsulation(int size) : vector<T>(size) { }
74+
template<typename scalar>
75+
struct VectorEncapsulation : public vector<scalar>, public Encapsulation {
76+
VectorEncapsulation(int size) : vector<scalar>(size) { }
7377
virtual unsigned int nbytes() const {
74-
return sizeof(T) * this->size();
78+
return sizeof(scalar) * this->size();
79+
}
80+
void setval(scalar v) {
81+
for (int i=0; i<this->size(); i++)
82+
(*this)[i] = v;
83+
}
84+
void copy(const Encapsulation* X) {
85+
const auto* x = dynamic_cast<const VectorEncapsulation*>(X);
86+
for (int i=0; i<this->size(); i++)
87+
(*this)[i] = (*x)[i];
7588
}
76-
void setval(double v) {
89+
void saxpy(double a, const Encapsulation *X) {
90+
const auto* x = dynamic_cast<const VectorEncapsulation*>(X);
7791
for (int i=0; i<this->size(); i++)
78-
this->data()[i] = v;
92+
(*this)[i] += a * (*x)[i];
7993
}
80-
// ...
8194
};
8295

8396
template<typename T>

src/pfasst-imex.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ namespace pfasst {
3030
Q[0]->copy(q0);
3131
}
3232

33+
Encapsulation* get_qend() {
34+
return Q[Q.size()-1];
35+
}
36+
37+
void advance() {
38+
set_q0(get_qend());
39+
}
40+
41+
3342
void setup() {
3443
auto nodes = this->get_nodes();
3544

src/pfasst-interfaces.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ namespace pfasst {
3434
virtual void predict(double t, double dt) = 0;
3535
virtual void integrate(double t, double dt) = 0;
3636
virtual void residual(double t, double dt) = 0;
37+
virtual void advance() = 0;
3738
};
3839

3940
class ITransfer {

src/pfasst-sdc.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@ namespace pfasst {
1515
void run() {
1616
ISweeper& swp = *this->get_level(0);
1717
for (int nstep=0; nstep<this->nsteps; nstep++) {
18+
timeT t = nstep * this->dt;
19+
swp.predict(t, this->dt);
1820
for (int niter=0; niter<this->niters; niter++) {
19-
swp.sweep(0.0, this->dt);
21+
swp.sweep(t, this->dt);
2022
}
23+
swp.advance();
2124
}
2225
}
2326
};

0 commit comments

Comments
 (0)