Skip to content

Commit c89d4f6

Browse files
committed
Merge pull request #28 from torbjoernk/feature/overloads
introducing overloaded functions
2 parents ef384e3 + 839a38f commit c89d4f6

File tree

5 files changed

+161
-73
lines changed

5 files changed

+161
-73
lines changed

examples/advection_diffusion/advection_diffusion_sweeper.hpp

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <complex>
99
#include <vector>
10+
#include <cassert>
1011

1112
#include <pfasst/encap/imex_sweeper.hpp>
1213
#include "fft.hpp"
@@ -50,7 +51,9 @@ class AdvectionDiffusionSweeper : public pfasst::encap::IMEXSweeper<time>
5051

5152
void exact(Encapsulation* q, time t)
5253
{
53-
exact(*dynamic_cast<DVectorT*>(q), t);
54+
DVectorT* q_cast = dynamic_cast<DVectorT*>(q);
55+
assert(q_cast != nullptr);
56+
this->exact(*q_cast, t);
5457
}
5558

5659
void exact(DVectorT& q, time t)
@@ -72,17 +75,18 @@ class AdvectionDiffusionSweeper : public pfasst::encap::IMEXSweeper<time>
7275

7376
void echo_error(time t, bool predict = false)
7477
{
75-
auto& qend = *dynamic_cast<DVectorT*>(this->get_state(this->get_nodes().size() - 1));
76-
auto qex = DVectorT(qend.size());
78+
DVectorT* qend = dynamic_cast<DVectorT*>(this->get_state(this->get_nodes().size() - 1));
79+
assert(qend != nullptr);
80+
DVectorT qex = DVectorT(qend->size());
7781

7882
exact(qex, t);
7983

8084
double max = 0.0;
81-
for (size_t i = 0; i < qend.size(); i++) {
82-
double d = abs(qend[i] - qex[i]);
85+
for (size_t i = 0; i < qend->size(); i++) {
86+
double d = abs(qend->at(i) - qex[i]);
8387
if (d > max) { max = d; }
8488
}
85-
cout << "err: " << scientific << max << " (" << qend.size() << ", " << predict << ")" << endl;
89+
cout << "err: " << scientific << max << " (" << qend->size() << ", " << predict << ")" << endl;
8690
}
8791

8892
void predict(time t, time dt, bool initial)
@@ -97,50 +101,72 @@ class AdvectionDiffusionSweeper : public pfasst::encap::IMEXSweeper<time>
97101
echo_error(t + dt);
98102
}
99103

100-
void f1eval(Encapsulation* F, Encapsulation* Q, time t)
104+
void f1eval(Encapsulation* f, Encapsulation* q, time t)
101105
{
102-
auto& f = *dynamic_cast<DVectorT*>(F);
103-
auto& q = *dynamic_cast<DVectorT*>(Q);
106+
DVectorT* f_cast = dynamic_cast<DVectorT*>(f);
107+
assert(f_cast != nullptr);
108+
DVectorT* q_cast = dynamic_cast<DVectorT*>(q);
109+
assert(q_cast != nullptr);
104110

105-
double c = -v / double(q.size());
111+
this->f1eval(f_cast, q_cast, t);
112+
}
113+
114+
void f1eval(DVectorT* f, DVectorT* q, time t)
115+
{
116+
double c = -v / double(q->size());
106117

107-
auto* z = fft.forward(q);
108-
for (size_t i = 0; i < q.size(); i++) {
118+
auto* z = fft.forward(*q);
119+
for (size_t i = 0; i < q->size(); i++) {
109120
z[i] *= c * ddx[i];
110121
}
111-
fft.backward(f);
122+
fft.backward(*f);
112123

113124
nf1evals++;
114125
}
115126

116-
void f2eval(Encapsulation* F, Encapsulation* Q, time t)
127+
void f2eval(Encapsulation* f, Encapsulation* q, time t)
117128
{
118-
auto& f = *dynamic_cast<DVectorT*>(F);
119-
auto& q = *dynamic_cast<DVectorT*>(Q);
129+
DVectorT* f_cast = dynamic_cast<DVectorT*>(f);
130+
assert(f_cast != nullptr);
131+
DVectorT* q_cast = dynamic_cast<DVectorT*>(q);
132+
assert(q_cast != nullptr);
133+
134+
this->f2eval(f_cast, q_cast, t);
135+
}
120136

121-
double c = nu / double(q.size());
137+
void f2eval(DVectorT* f, DVectorT* q, time t)
138+
{
139+
double c = nu / double(q->size());
122140

123-
auto* z = fft.forward(q);
124-
for (size_t i = 0; i < q.size(); i++) {
141+
auto* z = fft.forward(*q);
142+
for (size_t i = 0; i < q->size(); i++) {
125143
z[i] *= c * lap[i];
126144
}
127-
fft.backward(f);
145+
fft.backward(*f);
128146
}
129147

130-
void f2comp(Encapsulation* F, Encapsulation* Q, time t, time dt, Encapsulation* RHS)
148+
void f2comp(Encapsulation* f, Encapsulation* q, time t, time dt, Encapsulation* rhs)
131149
{
132-
auto& f = *dynamic_cast<DVectorT*>(F);
133-
auto& q = *dynamic_cast<DVectorT*>(Q);
134-
auto& rhs = *dynamic_cast<DVectorT*>(RHS);
150+
DVectorT* f_cast = dynamic_cast<DVectorT*>(f);
151+
assert(f_cast != nullptr);
152+
DVectorT* q_cast = dynamic_cast<DVectorT*>(q);
153+
assert(q_cast != nullptr);
154+
DVectorT* rhs_cast = dynamic_cast<DVectorT*>(rhs);
155+
assert(rhs_cast != nullptr);
156+
157+
this->f2comp(f_cast, q_cast, t, dt, rhs_cast);
158+
}
135159

136-
auto* z = fft.forward(rhs);
137-
for (size_t i = 0; i < q.size(); i++) {
138-
z[i] /= (1.0 - nu * double(dt) * lap[i]) * double(q.size());
160+
void f2comp(DVectorT* f, DVectorT* q, time t, time dt, DVectorT* rhs)
161+
{
162+
auto* z = fft.forward(*rhs);
163+
for (size_t i = 0; i < q->size(); i++) {
164+
z[i] /= (1.0 - nu * double(dt) * lap[i]) * double(q->size());
139165
}
140-
fft.backward(q);
166+
fft.backward(*q);
141167

142-
for (size_t i = 0; i < q.size(); i++) {
143-
f[i] = (q[i] - rhs[i]) / double(dt);
168+
for (size_t i = 0; i < q->size(); i++) {
169+
f->at(i) = (q->at(i) - rhs->at(i)) / double(dt);
144170
}
145171
}
146172

examples/advection_diffusion/spectral_transfer_1d.hpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#ifndef _SPECTRAL_TRANSFER_1D_HPP_
66
#define _SPECTRAL_TRANSFER_1D_HPP_
77

8+
#include <cassert>
9+
810
#include <pfasst/encap/vector.hpp>
911
#include <pfasst/encap/poly_interp.hpp>
1012

@@ -21,38 +23,52 @@ class SpectralTransfer1D : public pfasst::encap::PolyInterpMixin<time>
2123
public:
2224
void interpolate(Encapsulation* dst, const Encapsulation* src)
2325
{
24-
auto& crse = *dynamic_cast<const DVectorT*>(src);
25-
auto& fine = *dynamic_cast<DVectorT*>(dst);
26+
DVectorT* fine = dynamic_cast<DVectorT*>(dst);
27+
assert(fine != nullptr);
28+
const DVectorT* crse = dynamic_cast<const DVectorT*>(src);
29+
assert(crse != nullptr);
30+
31+
this->interpolate(fine, crse);
32+
}
2633

27-
auto* crse_z = fft.forward(crse);
28-
auto* fine_z = fft.get_workspace(fine.size())->z;
34+
void interpolate(DVectorT* fine, const DVectorT* crse)
35+
{
36+
auto* crse_z = fft.forward(*crse);
37+
auto* fine_z = fft.get_workspace(fine->size())->z;
2938

30-
for (size_t i = 0; i < fine.size(); i++) {
39+
for (size_t i = 0; i < fine->size(); i++) {
3140
fine_z[i] = 0.0;
3241
}
3342

34-
double c = 1.0 / crse.size();
43+
double c = 1.0 / crse->size();
3544

36-
for (size_t i = 0; i < crse.size() / 2; i++) {
45+
for (size_t i = 0; i < crse->size() / 2; i++) {
3746
fine_z[i] = c * crse_z[i];
3847
}
3948

40-
for (size_t i = 1; i < crse.size() / 2; i++) {
41-
fine_z[fine.size() - crse.size() / 2 + i] = c * crse_z[crse.size() / 2 + i];
49+
for (size_t i = 1; i < crse->size() / 2; i++) {
50+
fine_z[fine->size() - crse->size() / 2 + i] = c * crse_z[crse->size() / 2 + i];
4251
}
4352

44-
fft.backward(fine);
53+
fft.backward(*fine);
4554
}
4655

4756
void restrict(Encapsulation* dst, const Encapsulation* src)
4857
{
49-
auto& crse = *dynamic_cast<DVectorT*>(dst);
50-
auto& fine = *dynamic_cast<const DVectorT*>(src);
58+
DVectorT* crse = dynamic_cast<DVectorT*>(dst);
59+
assert(crse != nullptr);
60+
const DVectorT* fine = dynamic_cast<const DVectorT*>(src);
61+
assert(fine != nullptr);
5162

52-
size_t xrat = fine.size() / crse.size();
63+
this->restrict(crse, fine);
64+
}
65+
66+
void restrict(DVectorT* crse, const DVectorT* fine)
67+
{
68+
size_t xrat = fine->size() / crse->size();
5369

54-
for (size_t i = 0; i < crse.size(); i++) {
55-
crse[i] = fine[xrat * i];
70+
for (size_t i = 0; i < crse->size(); i++) {
71+
crse->at(i) = fine->at(xrat * i);
5672
}
5773
}
5874

include/pfasst/encap/automagic.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace pfasst
1919
pfasst::ITransfer<time>*,
2020
pfasst::encap::EncapFactory<time>*>;
2121

22-
template<typename time = time_precision, typename ControllerT, typename BuildT>
22+
template<typename ControllerT, typename BuildT, typename time = time_precision>
2323
void auto_build(ControllerT& c, vector<pair<size_t, string>> nodes, BuildT build)
2424
{
2525
for (size_t l = 0; l < nodes.size(); l++) {
@@ -34,7 +34,7 @@ namespace pfasst
3434
}
3535
}
3636

37-
template<typename time = time_precision, typename ControllerT, typename initialT>
37+
template<typename ControllerT, typename initialT, typename time = time_precision>
3838
void auto_setup(ControllerT& c, initialT initial)
3939
{
4040
c.setup();

include/pfasst/encap/poly_interp.hpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <cassert>
1010

1111
#include "../interfaces.hpp"
12+
#include "encap_sweeper.hpp"
1213

1314
namespace pfasst
1415
{
@@ -27,9 +28,18 @@ namespace pfasst
2728
bool interp_delta_from_initial,
2829
bool interp_initial)
2930
{
30-
auto* fine = dynamic_cast<EncapSweeper<time>*>(dst);
31-
auto* crse = dynamic_cast<const EncapSweeper<time>*>(src);
31+
EncapSweeper<time>* fine = dynamic_cast<EncapSweeper<time>*>(dst);
32+
assert(fine != nullptr);
33+
const EncapSweeper<time>* crse = dynamic_cast<const EncapSweeper<time>*>(src);
34+
assert(crse != nullptr);
3235

36+
this->interpolate(fine, crse, interp_delta_from_initial, interp_initial);
37+
}
38+
39+
virtual void interpolate(EncapSweeper<time>* fine, const EncapSweeper<time>* crse,
40+
bool interp_delta_from_initial,
41+
bool interp_initial)
42+
{
3343
if (tmat.size1() == 0) {
3444
tmat = pfasst::compute_interp<time>(fine->get_nodes(), crse->get_nodes());
3545
}
@@ -76,9 +86,16 @@ namespace pfasst
7686

7787
virtual void restrict(ISweeper<time>* dst, const ISweeper<time>* src, bool restrict_initial)
7888
{
79-
auto* crse = dynamic_cast<EncapSweeper<time>*>(dst);
80-
auto* fine = dynamic_cast<const EncapSweeper<time>*>(src);
89+
EncapSweeper<time>* crse = dynamic_cast<EncapSweeper<time>*>(dst);
90+
assert(crse != nullptr);
91+
const EncapSweeper<time>* fine = dynamic_cast<const EncapSweeper<time>*>(src);
92+
assert(fine != nullptr);
8193

94+
this->restrict(crse, fine, restrict_initial);
95+
}
96+
97+
virtual void restrict(EncapSweeper<time>* crse, const EncapSweeper<time>* fine, bool restrict_initial)
98+
{
8299
auto dnodes = crse->get_nodes();
83100
auto snodes = fine->get_nodes();
84101

@@ -101,9 +118,16 @@ namespace pfasst
101118

102119
virtual void fas(time dt, ISweeper<time>* dst, const ISweeper<time>* src)
103120
{
104-
auto* crse = dynamic_cast<EncapSweeper<time>*>(dst);
105-
auto* fine = dynamic_cast<const EncapSweeper<time>*>(src);
121+
EncapSweeper<time>* crse = dynamic_cast<EncapSweeper<time>*>(dst);
122+
assert(crse != nullptr);
123+
const EncapSweeper<time>* fine = dynamic_cast<const EncapSweeper<time>*>(src);
124+
assert(fine != nullptr);
106125

126+
this->fas(dt, crse, fine);
127+
}
128+
129+
virtual void fas(time dt, EncapSweeper<time>* crse, const EncapSweeper<time>* fine)
130+
{
107131
size_t ncrse = crse->get_nodes().size();
108132
assert(ncrse > 1);
109133
size_t nfine = fine->get_nodes().size();

include/pfasst/encap/vector.hpp

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <algorithm>
99
#include <vector>
10+
#include <cassert>
1011

1112
#include <cstdio>
1213
#include <cstdlib>
@@ -46,45 +47,66 @@ namespace pfasst
4647
std::fill(this->begin(), this->end(), 0.0);
4748
}
4849

49-
void copy(const Encapsulation<time>* X)
50+
void copy(const Encapsulation<time>* x)
51+
{
52+
const VectorEncapsulation<scalar, time>* x_cast = dynamic_cast<const VectorEncapsulation<scalar, time>*>(x);
53+
assert(x_cast != nullptr);
54+
this->copy(x_cast);
55+
}
56+
57+
void copy(const VectorEncapsulation<scalar, time>* x)
5058
{
51-
const auto* x = dynamic_cast<const VectorEncapsulation*>(X);
5259
std::copy(x->begin(), x->end(), this->begin());
5360
}
5461
//! @}
5562

5663
//! @{
57-
void saxpy(time a, const Encapsulation<time>* X)
64+
void saxpy(time a, const Encapsulation<time>* x)
5865
{
59-
const auto& x = *dynamic_cast<const VectorEncapsulation*>(X);
60-
auto& y = *this;
66+
const VectorEncapsulation<scalar, time>* x_cast = dynamic_cast<const VectorEncapsulation<scalar, time>*>(x);
67+
assert(x_cast != nullptr);
6168

62-
for (int i = 0; i < y.size(); i++)
63-
{ y[i] += a * x[i]; }
69+
this->saxpy(a, x_cast);
6470
}
6571

66-
void mat_apply(vector<Encapsulation<time>*> DST, time a, matrix<time> mat,
67-
vector<Encapsulation<time>*> SRC, bool zero = true)
72+
void saxpy(time a, const VectorEncapsulation<scalar, time>* x)
6873
{
74+
for (size_t i = 0; i < this->size(); i++)
75+
{ this->at(i) += a * x->at(i); }
76+
}
6977

70-
int ndst = DST.size();
71-
int nsrc = SRC.size();
78+
void mat_apply(vector<Encapsulation<time>*> dst, time a, matrix<time> mat,
79+
vector<Encapsulation<time>*> src, bool zero = true)
80+
{
81+
size_t ndst = dst.size();
82+
size_t nsrc = src.size();
7283

73-
vector<VectorEncapsulation<scalar>*> dst(ndst), src(nsrc);
84+
vector<VectorEncapsulation<scalar, time>*> dst_cast(ndst), src_cast(nsrc);
7485
for (int n = 0; n < ndst; n++) {
75-
dst[n] = dynamic_cast<VectorEncapsulation<scalar>*>(DST[n]);
86+
dst_cast[n] = dynamic_cast<VectorEncapsulation<scalar, time>*>(dst[n]);
87+
assert(dst_cast[n] != nullptr);
7688
}
7789
for (int m = 0; m < nsrc; m++) {
78-
src[m] = dynamic_cast<VectorEncapsulation<scalar>*>(SRC[m]);
90+
src_cast[m] = dynamic_cast<VectorEncapsulation<scalar, time>*>(src[m]);
91+
assert(src_cast[m] != nullptr);
7992
}
8093

81-
if (zero) { for (int n = 0; n < ndst; n++) { dst[n]->zero(); } }
94+
this->mat_apply(dst_cast, a, mat, src_cast, zero);
95+
}
96+
97+
void mat_apply(vector<VectorEncapsulation<scalar, time>*> dst, time a, matrix<time> mat,
98+
vector<VectorEncapsulation<scalar, time>*> src, bool zero = true)
99+
{
100+
size_t ndst = dst.size();
101+
size_t nsrc = src.size();
102+
103+
if (zero) { for (size_t n = 0; n < ndst; n++) { dst[n]->zero(); } }
82104

83-
int ndofs = (*dst[0]).size();
84-
for (int i = 0; i < ndofs; i++) {
85-
for (int n = 0; n < ndst; n++) {
86-
for (int m = 0; m < nsrc; m++) {
87-
dst[n]->data()[i] += a * mat(n, m) * src[m]->data()[i];
105+
size_t ndofs = dst[0]->size();
106+
for (size_t i = 0; i < ndofs; i++) {
107+
for (size_t n = 0; n < ndst; n++) {
108+
for (size_t m = 0; m < nsrc; m++) {
109+
dst[n]->at(i) += a * mat(n, m) * src[m]->at(i);
88110
}
89111
}
90112
}

0 commit comments

Comments
 (0)