Skip to content

Commit 5f4e9d4

Browse files
committed
encap: Put various classes into their own include files.
Signed-off-by: Matthew Emmett <[email protected]>
1 parent 6603957 commit 5f4e9d4

File tree

8 files changed

+264
-221
lines changed

8 files changed

+264
-221
lines changed

examples/advection_diffusion/advection_diffusion_sweeper.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <complex>
99
#include <vector>
1010

11-
#include <pfasst/encap/imex.hpp>
11+
#include <pfasst/encap/imex_sweeper.hpp>
1212
#include "fft.hpp"
1313

1414
#define pi 3.1415926535897932385
@@ -20,7 +20,7 @@ using pfasst::encap::Encapsulation;
2020

2121

2222
template<typename scalar, typename time>
23-
class AdvectionDiffusionSweeper : public pfasst::imex::IMEX<scalar,time> {
23+
class AdvectionDiffusionSweeper : public pfasst::encap::IMEXSweeper<scalar,time> {
2424

2525
using dvector = pfasst::encap::VectorEncapsulation<scalar,time>;
2626
FFT<scalar,time> fft;
@@ -89,13 +89,13 @@ class AdvectionDiffusionSweeper : public pfasst::imex::IMEX<scalar,time> {
8989

9090
void predict(time t, time dt, bool initial)
9191
{
92-
pfasst::imex::IMEX<scalar,time>::predict(t, dt, initial);
92+
pfasst::encap::IMEXSweeper<scalar,time>::predict(t, dt, initial);
9393
echo_error(t+dt, true);
9494
}
9595

9696
void sweep(time t, time dt)
9797
{
98-
pfasst::imex::IMEX<scalar,time>::sweep(t, dt);
98+
pfasst::encap::IMEXSweeper<scalar,time>::sweep(t, dt);
9999
echo_error(t+dt);
100100
}
101101

examples/advection_diffusion/ex3.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <tuple>
1111

1212
#include <pfasst.hpp>
13+
#include <pfasst/mlsdc.hpp>
1314
#include <pfasst/encap/automagic.hpp>
1415
#include <pfasst/encap/vector.hpp>
1516

examples/advection_diffusion/spectral_transfer_1d.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#define _SPECTRAL_TRANSFER_1D_HPP_
77

88
#include <pfasst/encap/vector.hpp>
9+
#include <pfasst/encap/poly_interp.hpp>
10+
911
#include "fft.hpp"
1012

1113
template<typename scalar, typename time>

include/pfasst/encap/automagic.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <tuple>
66

77
#include "../quadrature.hpp"
8-
#include "encapsulation.hpp"
8+
#include "encap_sweeper.hpp"
99

1010
using namespace std;
1111

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Host based encapsulated base sweeper.
3+
*/
4+
5+
#ifndef _PFASST_ENCAP_ENCAP_SWEEPER_HPP_
6+
#define _PFASST_ENCAP_ENCAP_SWEEPER_HPP_
7+
8+
#include <vector>
9+
10+
#include "../interfaces.hpp"
11+
#include "../quadrature.hpp"
12+
#include "encapsulation.hpp"
13+
14+
namespace pfasst {
15+
namespace encap {
16+
17+
template<typename scalar, typename time>
18+
class EncapSweeper : public ISweeper {
19+
vector<scalar> nodes;
20+
shared_ptr<EncapFactory<scalar,time>> factory;
21+
22+
public:
23+
24+
void set_nodes(vector<time> nodes)
25+
{
26+
this->nodes = nodes;
27+
}
28+
29+
const vector<time> get_nodes() const
30+
{
31+
return nodes;
32+
}
33+
34+
void set_factory(EncapFactory<scalar,time>* factory)
35+
{
36+
this->factory = shared_ptr<EncapFactory<scalar,time>>(factory);
37+
}
38+
39+
EncapFactory<scalar,time>* get_factory() const
40+
{
41+
return factory.get();
42+
}
43+
44+
virtual void set_state(const Encapsulation<scalar,time>* q0, unsigned int m)
45+
{
46+
throw NotImplementedYet("sweeper");
47+
}
48+
49+
virtual Encapsulation<scalar,time>* get_state(unsigned int m) const
50+
{
51+
throw NotImplementedYet("sweeper");
52+
return NULL;
53+
}
54+
55+
virtual Encapsulation<scalar,time>* get_tau(unsigned int m) const
56+
{
57+
throw NotImplementedYet("sweeper");
58+
return NULL;
59+
}
60+
61+
virtual Encapsulation<scalar,time>* get_saved_state(unsigned int m) const
62+
{
63+
throw NotImplementedYet("sweeper");
64+
return NULL;
65+
}
66+
67+
virtual Encapsulation<scalar,time>* get_end_state()
68+
{
69+
return this->get_state(this->get_nodes().size()-1);
70+
}
71+
72+
virtual void evaluate(int m)
73+
{
74+
throw NotImplementedYet("sweeper");
75+
}
76+
77+
virtual void advance()
78+
{
79+
throw NotImplementedYet("sweeper");
80+
}
81+
82+
virtual void integrate(time dt, vector<Encapsulation<scalar,time>*> dst) const
83+
{
84+
throw NotImplementedYet("sweeper");
85+
}
86+
};
87+
88+
}
89+
90+
}
91+
92+
#endif

include/pfasst/encap/encapsulation.hpp

Lines changed: 0 additions & 210 deletions
Original file line numberDiff line numberDiff line change
@@ -55,216 +55,6 @@ namespace pfasst {
5555
virtual Encapsulation<scalar,time>* create(const EncapType) = 0;
5656
};
5757

58-
template<typename scalar, typename time>
59-
class EncapSweeper : public ISweeper {
60-
vector<scalar> nodes;
61-
shared_ptr<EncapFactory<scalar,time>> factory;
62-
63-
public:
64-
65-
void set_nodes(vector<time> nodes)
66-
{
67-
this->nodes = nodes;
68-
}
69-
70-
const vector<time> get_nodes() const
71-
{
72-
return nodes;
73-
}
74-
75-
void set_factory(EncapFactory<scalar,time>* factory)
76-
{
77-
this->factory = shared_ptr<EncapFactory<scalar,time>>(factory);
78-
}
79-
80-
EncapFactory<scalar,time>* get_factory() const
81-
{
82-
return factory.get();
83-
}
84-
85-
virtual void set_state(const Encapsulation<scalar,time>* q0, unsigned int m)
86-
{
87-
throw NotImplementedYet("sweeper");
88-
}
89-
90-
virtual Encapsulation<scalar,time>* get_state(unsigned int m) const
91-
{
92-
throw NotImplementedYet("sweeper");
93-
return NULL;
94-
}
95-
96-
virtual Encapsulation<scalar,time>* get_tau(unsigned int m) const
97-
{
98-
throw NotImplementedYet("sweeper");
99-
return NULL;
100-
}
101-
102-
virtual Encapsulation<scalar,time>* get_saved_state(unsigned int m) const
103-
{
104-
throw NotImplementedYet("sweeper");
105-
return NULL;
106-
}
107-
108-
virtual Encapsulation<scalar,time>* get_end_state()
109-
{
110-
return this->get_state(this->get_nodes().size()-1);
111-
}
112-
113-
virtual void evaluate(int m)
114-
{
115-
throw NotImplementedYet("sweeper");
116-
}
117-
118-
virtual void advance()
119-
{
120-
throw NotImplementedYet("sweeper");
121-
}
122-
123-
virtual void integrate(time dt, vector<Encapsulation<scalar,time>*> dst) const
124-
{
125-
throw NotImplementedYet("sweeper");
126-
}
127-
};
128-
129-
template<typename scalar, typename time>
130-
class PolyInterpMixin : public pfasst::ITransfer {
131-
matrix<time> tmat, fmat;
132-
133-
public:
134-
135-
virtual ~PolyInterpMixin() { }
136-
137-
virtual void interpolate(ISweeper *dst, const ISweeper *src,
138-
bool interp_delta_from_initial,
139-
bool interp_initial)
140-
{
141-
auto* fine = dynamic_cast<EncapSweeper<scalar,time>*>(dst);
142-
auto* crse = dynamic_cast<const EncapSweeper<scalar,time>*>(src);
143-
144-
if (tmat.size1() == 0)
145-
tmat = pfasst::compute_interp<time>(fine->get_nodes(), crse->get_nodes());
146-
147-
int nfine = fine->get_nodes().size();
148-
int ncrse = crse->get_nodes().size();
149-
150-
auto* crse_factory = crse->get_factory();
151-
auto* fine_factory = fine->get_factory();
152-
153-
vector<Encapsulation<scalar,time>*> fine_state(nfine), fine_delta(ncrse);
154-
155-
for (int m=0; m<nfine; m++) fine_state[m] = fine->get_state(m);
156-
for (int m=0; m<ncrse; m++) fine_delta[m] = fine_factory->create(solution);
157-
158-
if (interp_delta_from_initial)
159-
for (int m=1; m<nfine; m++)
160-
fine_state[m]->copy(fine_state[0]);
161-
162-
auto* crse_delta = crse_factory->create(solution);
163-
int m0 = interp_initial ? 0 : 1;
164-
for (int m=m0; m<ncrse; m++) {
165-
crse_delta->copy(crse->get_state(m));
166-
if (interp_delta_from_initial)
167-
crse_delta->saxpy(-1.0, crse->get_state(0));
168-
else
169-
crse_delta->saxpy(-1.0, crse->get_saved_state(m));
170-
interpolate(fine_delta[m], crse_delta);
171-
}
172-
delete crse_delta;
173-
174-
if (! interp_initial)
175-
fine_delta[0]->setval(0.0);
176-
177-
fine->get_state(0)->mat_apply(fine_state, 1.0, tmat, fine_delta, false);
178-
179-
for (int m=0; m<ncrse; m++) delete fine_delta[m];
180-
for (int m=m0; m<nfine; m++) fine->evaluate(m);
181-
}
182-
183-
virtual void restrict(ISweeper *dst, const ISweeper *src, bool restrict_initial)
184-
{
185-
auto* crse = dynamic_cast<EncapSweeper<scalar,time>*>(dst);
186-
auto* fine = dynamic_cast<const EncapSweeper<scalar,time>*>(src);
187-
188-
auto dnodes = crse->get_nodes();
189-
auto snodes = fine->get_nodes();
190-
191-
int ncrse = crse->get_nodes().size();
192-
int nfine = fine->get_nodes().size();
193-
194-
int trat = (nfine - 1) / (ncrse - 1);
195-
196-
int m0 = restrict_initial ? 0 : 1;
197-
for (int m=m0; m<ncrse; m++) {
198-
if (dnodes[m] != snodes[m*trat])
199-
throw NotImplementedYet("coarse nodes must be nested");
200-
this->restrict(crse->get_state(m), fine->get_state(m*trat));
201-
}
202-
203-
for (int m=m0; m<ncrse; m++) crse->evaluate(m);
204-
}
205-
206-
virtual void fas(time dt, ISweeper *dst, const ISweeper *src)
207-
{
208-
auto* crse = dynamic_cast<EncapSweeper<scalar,time>*>(dst);
209-
auto* fine = dynamic_cast<const EncapSweeper<scalar,time>*>(src);
210-
211-
int ncrse = crse->get_nodes().size();
212-
int nfine = fine->get_nodes().size();
213-
214-
auto* crse_factory = crse->get_factory();
215-
auto* fine_factory = fine->get_factory();
216-
217-
vector<Encapsulation<scalar,time>*> crse_z2n(ncrse-1), fine_z2n(nfine-1), rstr_z2n(ncrse-1);
218-
for (int m=0; m<ncrse-1; m++) crse_z2n[m] = crse_factory->create(solution);
219-
for (int m=0; m<ncrse-1; m++) rstr_z2n[m] = crse_factory->create(solution);
220-
for (int m=0; m<nfine-1; m++) fine_z2n[m] = fine_factory->create(solution);
221-
222-
// compute '0 to node' integral on the coarse level
223-
crse->integrate(dt, crse_z2n);
224-
for (int m=1; m<ncrse-1; m++)
225-
crse_z2n[m]->saxpy(1.0, crse_z2n[m-1]);
226-
227-
// compute '0 to node' integral on the fine level
228-
fine->integrate(dt, fine_z2n);
229-
for (int m=1; m<nfine-1; m++)
230-
fine_z2n[m]->saxpy(1.0, fine_z2n[m-1]);
231-
232-
// restrict '0 to node' fine integral
233-
int trat = (nfine - 1) / (ncrse - 1);
234-
for (int m=1; m<ncrse; m++)
235-
this->restrict(rstr_z2n[m-1], fine_z2n[m*trat-1]);
236-
237-
// compute 'node to node' tau correction
238-
vector<Encapsulation<scalar,time>*> tau(ncrse-1);
239-
for (int m=0; m<ncrse-1; m++) tau[m] = crse->get_tau(m);
240-
241-
tau[0]->copy(rstr_z2n[0]);
242-
tau[0]->saxpy(-1.0, crse_z2n[0]);
243-
244-
for (int m=1; m<ncrse-1; m++) {
245-
tau[m]->copy(rstr_z2n[m]);
246-
tau[m]->saxpy(-1.0, rstr_z2n[m-1]);
247-
248-
tau[m]->saxpy(-1.0, crse_z2n[m]);
249-
tau[m]->saxpy(1.0, crse_z2n[m-1]);
250-
}
251-
252-
for (int m=0; m<ncrse-1; m++) delete crse_z2n[m];
253-
for (int m=0; m<ncrse-1; m++) delete rstr_z2n[m];
254-
for (int m=0; m<nfine-1; m++) delete fine_z2n[m];
255-
}
256-
257-
// required for interp/restrict helpers
258-
virtual void interpolate(Encapsulation<scalar,time> *dst, const Encapsulation<scalar,time> *src) {
259-
throw NotImplementedYet("mlsdc/pfasst");
260-
}
261-
262-
virtual void restrict(Encapsulation<scalar,time> *dst, const Encapsulation<scalar,time> *src) {
263-
throw NotImplementedYet("mlsdc/pfasst");
264-
}
265-
266-
};
267-
26858
}
26959
}
27060

0 commit comments

Comments
 (0)