1414
1515#include " fft.hpp"
1616
17- #define PI 3.1415926535897932385
18- #define TWO_PI 6.2831853071795864769
17+ #ifndef PI
18+ #define PI 3.1415926535897932385
19+ #endif
1920
2021using namespace std ;
22+ using pfasst::encap::Encapsulation;
23+ using pfasst::encap::as_vector;
24+
2125
2226template <typename time = pfasst::time_precision>
23- class AdvectionDiffusionSweeper
27+ class AdvectionDiffusionSweeper
2428 : public pfasst::encap::IMEXSweeper<time>
2529{
26- typedef pfasst::encap::Encapsulation<time> Encapsulation;
27- typedef pfasst::encap::VectorEncapsulation<double > DVectorT;
28-
2930 FFT fft;
3031
3132 vector<complex <double >> ddx, lap;
@@ -41,7 +42,7 @@ class AdvectionDiffusionSweeper
4142 ddx.resize (nvars);
4243 lap.resize (nvars);
4344 for (size_t i = 0 ; i < nvars; i++) {
44- double kx = TWO_PI * ((i <= nvars / 2 ) ? int (i) : int (i) - int (nvars));
45+ double kx = 2 * PI * ((i <= nvars / 2 ) ? int (i) : int (i) - int (nvars));
4546 ddx[i] = complex <double >(0.0 , 1.0 ) * kx;
4647 lap[i] = (kx * kx < 1e-13 ) ? 0.0 : -kx * kx;
4748 }
@@ -52,45 +53,42 @@ class AdvectionDiffusionSweeper
5253 cout << " number of f1 evals: " << nf1evals << endl;
5354 }
5455
55- void exact (shared_ptr<Encapsulation> q, time t)
56+ void exact (shared_ptr<Encapsulation<time> > q, time t)
5657 {
57- shared_ptr<DVectorT> q_cast = dynamic_pointer_cast<DVectorT>(q);
58- assert (q_cast);
59- this ->exact (q_cast, t);
58+ this ->exact (as_vector<double ,time>(q), t);
6059 }
6160
62- void exact (shared_ptr< DVectorT> q, time t)
61+ void exact (DVectorT& q, time t)
6362 {
64- size_t n = q-> size ();
63+ size_t n = q. size ();
6564 double a = 1.0 / sqrt (4 * PI * nu * (t + t0));
6665
6766 for (size_t i = 0 ; i < n; i++) {
68- q-> data () [i] = 0.0 ;
67+ q[i] = 0.0 ;
6968 }
7069
7170 for (int ii = -2 ; ii < 3 ; ii++) {
7271 for (size_t i = 0 ; i < n; i++) {
7372 double x = double (i) / n - 0.5 + ii - t * v;
74- q-> data () [i] += a * exp (-x * x / (4 * nu * (t + t0)));
73+ q[i] += a * exp (-x * x / (4 * nu * (t + t0)));
7574 }
7675 }
7776 }
7877
7978 void echo_error (time t, bool predict = false )
8079 {
81- shared_ptr<DVectorT> qend = dynamic_pointer_cast<DVectorT>(this ->get_state (this ->get_nodes ().size () - 1 ));
82- assert (qend);
83- shared_ptr<DVectorT> qex = make_shared<DVectorT>(qend->size ());
80+ auto & qend = as_vector<double ,time>(this ->get_end_state ());
81+ DVectorT qex (qend.size ());
8482
8583 exact (qex, t);
8684
8785 double max = 0.0 ;
88- for (size_t i = 0 ; i < qend-> size (); i++) {
89- double d = abs (qend-> data () [i] - qex-> data () [i]);
86+ for (size_t i = 0 ; i < qend. size (); i++) {
87+ double d = abs (qend[i] - qex[i]);
9088 if (d > max) { max = d; }
9189 }
92- cout << " err: " << scientific << max
93- << " (" << qend-> size () << " , " << predict << " )"
90+ cout << " err: " << scientific << max
91+ << " (" << qend. size () << " , " << predict << " )"
9492 << endl;
9593 }
9694
@@ -106,76 +104,55 @@ class AdvectionDiffusionSweeper
106104 echo_error (t + dt);
107105 }
108106
109- void f1eval (shared_ptr<Encapsulation> f , shared_ptr<Encapsulation> q , time t )
107+ void f1eval (shared_ptr<Encapsulation<time>> _f , shared_ptr<Encapsulation<time>> _q , time /* t */ )
110108 {
111- shared_ptr<DVectorT> f_cast = dynamic_pointer_cast<DVectorT>(f);
112- assert (f_cast);
113- shared_ptr<DVectorT> q_cast = dynamic_pointer_cast<DVectorT>(q);
114- assert (q_cast);
109+ auto & q = as_vector<double ,time>(_q);
110+ auto & f = as_vector<double ,time>(_f);
115111
116- this ->f1eval (f_cast, q_cast, t);
117- }
118-
119- void f1eval (shared_ptr<DVectorT> f, shared_ptr<DVectorT> q, time t)
120- {
121- double c = -v / double (q->size ());
112+ double c = -v / double (q.size ());
122113
123114 auto * z = fft.forward (q);
124- for (size_t i = 0 ; i < q-> size (); i++) {
115+ for (size_t i = 0 ; i < q. size (); i++) {
125116 z[i] *= c * ddx[i];
126117 }
127118 fft.backward (f);
128119
129120 nf1evals++;
130121 }
131122
132- void f2eval (shared_ptr<Encapsulation> f , shared_ptr<Encapsulation> q , time t )
123+ void f2eval (shared_ptr<Encapsulation<time>> _f , shared_ptr<Encapsulation<time>> _q , time /* t */ )
133124 {
134- shared_ptr<DVectorT> f_cast = dynamic_pointer_cast<DVectorT>(f);
135- assert (f_cast);
136- shared_ptr<DVectorT> q_cast = dynamic_pointer_cast<DVectorT>(q);
137- assert (q_cast);
125+ auto & q = as_vector<double ,time>(_q);
126+ auto & f = as_vector<double ,time>(_f);
138127
139- this ->f2eval (f_cast, q_cast, t);
140- }
141-
142- void f2eval (shared_ptr<DVectorT> f, shared_ptr<DVectorT> q, time t)
143- {
144- double c = nu / double (q->size ());
128+ double c = nu / double (q.size ());
145129
146130 auto * z = fft.forward (q);
147- for (size_t i = 0 ; i < q-> size (); i++) {
131+ for (size_t i = 0 ; i < q. size (); i++) {
148132 z[i] *= c * lap[i];
149133 }
150134 fft.backward (f);
151135 }
152136
153- void f2comp (shared_ptr<Encapsulation> f , shared_ptr<Encapsulation> q , time t , time dt,
154- shared_ptr<Encapsulation> rhs )
137+ void f2comp (shared_ptr<Encapsulation<time>> _f , shared_ptr<Encapsulation<time>> _q , time /* t */ , time dt,
138+ shared_ptr<Encapsulation<time>> _rhs )
155139 {
156- shared_ptr<DVectorT> f_cast = dynamic_pointer_cast<DVectorT>(f);
157- assert (f_cast);
158- shared_ptr<DVectorT> q_cast = dynamic_pointer_cast<DVectorT>(q);
159- assert (q_cast);
160- shared_ptr<DVectorT> rhs_cast = dynamic_pointer_cast<DVectorT>(rhs);
161- assert (rhs_cast);
162-
163- this ->f2comp (f_cast, q_cast, t, dt, rhs_cast);
164- }
140+ auto & q = as_vector<double ,time>(_q);
141+ auto & f = as_vector<double ,time>(_f);
142+ auto & rhs = as_vector<double ,time>(_rhs);
165143
166- void f2comp (shared_ptr<DVectorT> f, shared_ptr<DVectorT> q, time t, time dt,
167- shared_ptr<DVectorT> rhs)
168- {
169144 auto * z = fft.forward (rhs);
170- for (size_t i = 0 ; i < q-> size (); i++) {
171- z[i] /= (1.0 - nu * double (dt) * lap[i]) * double (q-> size ());
145+ for (size_t i = 0 ; i < q. size (); i++) {
146+ z[i] /= (1.0 - nu * double (dt) * lap[i]) * double (q. size ());
172147 }
173148 fft.backward (q);
174149
175- for (size_t i = 0 ; i < q-> size (); i++) {
176- f-> data () [i] = (q-> data () [i] - rhs-> data () [i]) / double (dt);
150+ for (size_t i = 0 ; i < q. size (); i++) {
151+ f[i] = (q[i] - rhs[i]) / double (dt);
177152 }
153+
178154 }
155+
179156};
180157
181158#endif
0 commit comments