55#ifndef _PFASST_ENCAP_ENCAP_SWEEPER_HPP_
66#define _PFASST_ENCAP_ENCAP_SWEEPER_HPP_
77
8- #include < algorithm>
9- #include < cstdlib>
108#include < memory>
119#include < vector>
1210using namespace std ;
1311
14- #include " ../globals.hpp"
15- #include " ../config.hpp"
1612#include " ../interfaces.hpp"
1713#include " ../quadrature.hpp"
1814#include " encapsulation.hpp"
1915
16+
2017namespace pfasst
2118{
2219 namespace encap
@@ -57,132 +54,53 @@ namespace pfasst
5754 time abs_residual_tol, rel_residual_tol;
5855
5956 public:
60-
61- EncapSweeper ()
62- : quadrature(nullptr ), abs_residual_tol(0.0 ), rel_residual_tol(0.0 )
63- {}
57+ EncapSweeper ();
6458
6559 // ! @{
6660 /* *
6761 * Retrieve solution values of current iteration at time node index `m`.
6862 *
6963 * @param[in] m 0-based index of time node
7064 */
71- virtual shared_ptr<Encapsulation<time>> get_state (size_t m) const
72- {
73- return this ->state [m];
74- }
65+ virtual shared_ptr<Encapsulation<time>> get_state (size_t m) const ;
7566
7667 /* *
7768 * Retrieve FAS correction of current iteration at time node index `m`.
7869 *
7970 * @param[in] m 0-based index of time node
8071 */
81- virtual shared_ptr<Encapsulation<time>> get_tau (size_t m) const
82- {
83- return this ->fas_corrections [m];
84- }
72+ virtual shared_ptr<Encapsulation<time>> get_tau (size_t m) const ;
8573
8674 /* *
8775 * Retrieve solution values of previous iteration at time node index `m`.
8876 *
8977 * @param[in] m 0-based index of time node
9078 */
91- virtual shared_ptr<Encapsulation<time>> get_saved_state (size_t m) const
92- {
93- return this ->saved_state [m];
94- }
79+ virtual shared_ptr<Encapsulation<time>> get_saved_state (size_t m) const ;
9580 // ! @}
9681
9782 // ! @{
98- virtual void set_options () override
99- {
100- this ->abs_residual_tol = time (config::get_value<double >(" abs_res_tol" , this ->abs_residual_tol ));
101- this ->rel_residual_tol = time (config::get_value<double >(" rel_res_tol" , this ->rel_residual_tol ));
102- }
103-
104- virtual void setup (bool coarse) override
105- {
106- auto const nodes = this ->quadrature ->get_nodes ();
107- auto const num_nodes = this ->quadrature ->get_num_nodes ();
108-
109- this ->start_state = this ->get_factory ()->create (pfasst::encap::solution);
110- this ->end_state = this ->get_factory ()->create (pfasst::encap::solution);
111-
112- for (size_t m = 0 ; m < num_nodes; m++) {
113- this ->state .push_back (this ->get_factory ()->create (pfasst::encap::solution));
114- if (coarse) {
115- this ->saved_state .push_back (this ->get_factory ()->create (pfasst::encap::solution));
116- }
117- }
118-
119- if (coarse) {
120- for (size_t m = 0 ; m < num_nodes; m++) {
121- this ->fas_corrections .push_back (this ->get_factory ()->create (pfasst::encap::solution));
122- }
123- }
124- }
83+ virtual void set_options () override ;
84+ virtual void setup (bool coarse) override ;
12585 // ! @}
12686
12787 // ! @{
128- virtual void spread () override
129- {
130- for (size_t m = 1 ; m < this ->quadrature ->get_num_nodes (); m++) {
131- this ->state [m]->copy (this ->state [0 ]);
132- }
133- }
88+ virtual void spread () override ;
13489
13590 /* *
13691 * Save current solution states.
13792 */
138- virtual void save (bool initial_only) override
139- {
140- // XXX: if !left_is_node, this is a problem...
141- if (initial_only) {
142- this ->saved_state [0 ]->copy (state[0 ]);
143- } else {
144- for (size_t m = 0 ; m < this ->saved_state .size (); m++) {
145- this ->saved_state [m]->copy (state[m]);
146- }
147- }
148- }
93+ virtual void save (bool initial_only) override ;
14994 // ! @}
15095
15196 // ! @{
152- void set_quadrature (shared_ptr<IQuadrature<time>> quadrature)
153- {
154- this ->quadrature = quadrature;
155- }
156-
157- shared_ptr<const IQuadrature<time>> get_quadrature () const
158- {
159- return this ->quadrature ;
160- }
161-
162- shared_ptr<Encapsulation<time>> get_start_state () const
163- {
164- return this ->start_state ;
165- }
166-
167- const vector<time> get_nodes () const
168- {
169- return this ->quadrature ->get_nodes ();
170- }
171-
172- void set_factory (shared_ptr<EncapFactory<time>> factory)
173- {
174- this ->factory = factory;
175- }
176-
177- virtual shared_ptr<EncapFactory<time>> get_factory () const
178- {
179- return factory;
180- }
181-
182- virtual shared_ptr<Encapsulation<time>> get_end_state ()
183- {
184- return this ->end_state ;
185- }
97+ virtual void set_quadrature (shared_ptr<IQuadrature<time>> quadrature);
98+ virtual shared_ptr<const IQuadrature<time>> get_quadrature () const ;
99+ virtual shared_ptr<Encapsulation<time>> get_start_state () const ;
100+ virtual const vector<time> get_nodes () const ;
101+ virtual void set_factory (shared_ptr<EncapFactory<time>> factory);
102+ virtual shared_ptr<EncapFactory<time>> get_factory () const ;
103+ virtual shared_ptr<Encapsulation<time>> get_end_state ();
186104 // ! @}
187105
188106 // ! @{
@@ -191,21 +109,14 @@ namespace pfasst
191109 *
192110 * @note This method must be implemented in derived sweepers.
193111 */
194- virtual void advance () override
195- {
196- throw NotImplementedYet (" sweeper" );
197- }
112+ virtual void advance () override ;
198113
199114 /* *
200115 * Re-evaluate function values.
201116 *
202117 * @note This method must be implemented in derived sweepers.
203118 */
204- virtual void reevaluate (bool initial_only=false )
205- {
206- UNUSED (initial_only);
207- throw NotImplementedYet (" sweeper" );
208- }
119+ virtual void reevaluate (bool initial_only = false );
209120
210121 /* *
211122 * integrates values of right hand side at all time nodes \\( t \\in [0,M-1] \\)
@@ -216,110 +127,45 @@ namespace pfasst
216127 *
217128 * @note This method must be implemented in derived sweepers.
218129 */
219- virtual void integrate (time dt, vector<shared_ptr<Encapsulation<time>>> dst) const
220- {
221- UNUSED (dt); UNUSED (dst);
222- throw NotImplementedYet (" sweeper" );
223- }
130+ virtual void integrate (time dt, vector<shared_ptr<Encapsulation<time>>> dst) const ;
224131 // ! @}
225132
226133 // ! @{
227134 /* *
228135 * Set residual tolerances for convergence checking.
229136 */
230- void set_residual_tolerances (time abs_residual_tol, time rel_residual_tol, int order=0 )
231- {
232- this ->abs_residual_tol = abs_residual_tol;
233- this ->rel_residual_tol = rel_residual_tol;
234- this ->residual_norm_order = order;
235- }
137+ void set_residual_tolerances (time abs_residual_tol, time rel_residual_tol, int order = 0 );
236138
237139 /* *
238140 * Compute residual at each SDC node (including FAS corrections).
239141 */
240- virtual void residual (time dt, vector<shared_ptr<Encapsulation<time>>> dst) const
241- {
242- UNUSED (dt); UNUSED (dst);
243- throw NotImplementedYet (" residual" );
244- }
142+ virtual void residual (time dt, vector<shared_ptr<Encapsulation<time>>> dst) const ;
245143
246144 /* *
247145 * Return convergence status.
248146 *
249147 * This is used by controllers to shortcircuit iterations.
250148 */
251- virtual bool converged () override
252- {
253- if (this ->abs_residual_tol > 0.0 || this ->rel_residual_tol > 0.0 ) {
254- if (this ->residuals .size () == 0 ) {
255- for (size_t m = 0 ; m < this ->get_nodes ().size (); m++) {
256- this ->residuals .push_back (this ->get_factory ()->create (pfasst::encap::solution));
257- }
258- }
259- this ->residual (this ->get_controller ()->get_time_step (), this ->residuals );
260- vector<time> anorms, rnorms;
261- for (size_t m = 0 ; m < this ->get_nodes ().size (); m++) {
262- anorms.push_back (this ->residuals [m]->norm0 ());
263- rnorms.push_back (anorms.back () / this ->get_state (m)->norm0 ());
264- }
265- auto amax = *std::max_element (anorms.begin (), anorms.end ());
266- auto rmax = *std::max_element (rnorms.begin (), rnorms.end ());
267- if (amax < this ->abs_residual_tol || rmax < this ->rel_residual_tol ) {
268- return true ;
269- }
270- }
271- return false ;
272- }
149+ virtual bool converged () override ;
273150 // ! @}
274151
275152 // ! @{
276- virtual void post (ICommunicator* comm, int tag) override
277- {
278- this ->start_state ->post (comm, tag);
279- }
280-
281- virtual void send (ICommunicator* comm, int tag, bool blocking) override
282- {
283- this ->end_state ->send (comm, tag, blocking);
284- }
285-
286- virtual void recv (ICommunicator* comm, int tag, bool blocking) override
287- {
288- this ->start_state ->recv (comm, tag, blocking);
289- if (this ->quadrature ->left_is_node ()) {
290- this ->state [0 ]->copy (this ->start_state );
291- }
292- }
293-
294- virtual void broadcast (ICommunicator* comm) override
295- {
296- if (comm->rank () == comm->size () - 1 ) {
297- this ->start_state ->copy (this ->end_state );
298- }
299- this ->start_state ->broadcast (comm);
300- }
153+ virtual void post (ICommunicator* comm, int tag) override ;
154+ virtual void send (ICommunicator* comm, int tag, bool blocking) override ;
155+ virtual void recv (ICommunicator* comm, int tag, bool blocking) override ;
156+ virtual void broadcast (ICommunicator* comm) override ;
301157 // ! @}
302158 };
303159
304160
305161 template <typename time>
306- EncapSweeper<time>& as_encap_sweeper (shared_ptr<ISweeper<time>> x)
307- {
308- shared_ptr<EncapSweeper<time>> y = dynamic_pointer_cast<EncapSweeper<time>>(x);
309- assert (y);
310- return *y.get ();
311- }
312-
162+ EncapSweeper<time>& as_encap_sweeper (shared_ptr<ISweeper<time>> x);
313163
314164 template <typename time>
315- const EncapSweeper<time>& as_encap_sweeper (shared_ptr<const ISweeper<time>> x)
316- {
317- shared_ptr<const EncapSweeper<time>> y = dynamic_pointer_cast<const EncapSweeper<time>>(x);
318- assert (y);
319- return *y.get ();
320- }
321-
165+ const EncapSweeper<time>& as_encap_sweeper (shared_ptr<const ISweeper<time>> x);
322166 } // ::pfasst::encap
323167} // ::pfasst
324168
169+ #include " encap_sweeper_impl.hpp"
170+
325171#endif
0 commit comments