@@ -19,53 +19,137 @@ namespace pfasst
1919{
2020 namespace encap
2121 {
22+ using namespace pfasst ::quadrature;
23+
2224 template <typename time = time_precision>
2325 class EncapSweeper
2426 : public ISweeper<time>
2527 {
2628 protected:
2729 // ! @{
28- quadrature:: IQuadrature<time>* quad ;
30+ shared_ptr< IQuadrature<time>> quadrature ;
2931 shared_ptr<EncapFactory<time>> factory;
3032 shared_ptr<Encapsulation<time>> start_state;
3133 shared_ptr<Encapsulation<time>> end_state;
3234 vector<shared_ptr<Encapsulation<time>>> residuals;
3335 // ! @}
3436
37+ // ! @{
38+ /* *
39+ * Solution values \\( U \\) at all time nodes of the current iteration.
40+ */
41+ vector<shared_ptr<Encapsulation<time>>> state;
42+
43+ /* *
44+ * Solution values \\( U \\) at all time nodes of the previous iteration.
45+ */
46+ vector<shared_ptr<Encapsulation<time>>> saved_state;
47+
48+ /* *
49+ * FAS corrections \\( \\tau \\) at all time nodes of the current iteration.
50+ */
51+ vector<shared_ptr<Encapsulation<time>>> fas_corrections;
52+ // ! @}
53+
3554 int residual_norm_order;
3655 time abs_residual_tol, rel_residual_tol;
3756
3857 public:
39- // ! @{
58+
4059 EncapSweeper ()
41- : quad (nullptr ), abs_residual_tol(0.0 ), rel_residual_tol(0.0 )
60+ : quadrature (nullptr ), abs_residual_tol(0.0 ), rel_residual_tol(0.0 )
4261 {}
4362
44- virtual ~EncapSweeper ()
63+ // ! @{
64+ /* *
65+ * Retrieve solution values of current iteration at time node index `m`.
66+ *
67+ * @param[in] m 0-based index of time node
68+ */
69+ virtual shared_ptr<Encapsulation<time>> get_state (size_t m) const
4570 {
46- if (this ->quad ) delete this ->quad ;
71+ return this ->state [m];
72+ }
73+
74+ /* *
75+ * Retrieve FAS correction of current iteration at time node index `m`.
76+ *
77+ * @param[in] m 0-based index of time node
78+ */
79+ virtual shared_ptr<Encapsulation<time>> get_tau (size_t m) const
80+ {
81+ return this ->fas_corrections [m];
82+ }
83+
84+ /* *
85+ * Retrieve solution values of previous iteration at time node index `m`.
86+ *
87+ * @param[in] m 0-based index of time node
88+ */
89+ virtual shared_ptr<Encapsulation<time>> get_saved_state (size_t m) const
90+ {
91+ return this ->saved_state [m];
4792 }
4893 // ! @}
4994
95+ virtual void setup (bool coarse) override
96+ {
97+ auto const nodes = this ->quadrature ->get_nodes ();
98+ auto const num_nodes = this ->quadrature ->get_num_nodes ();
99+
100+ this ->start_state = this ->get_factory ()->create (pfasst::encap::solution);
101+ this ->end_state = this ->get_factory ()->create (pfasst::encap::solution);
102+
103+ for (size_t m = 0 ; m < num_nodes; m++) {
104+ this ->state .push_back (this ->get_factory ()->create (pfasst::encap::solution));
105+ if (coarse) {
106+ this ->saved_state .push_back (this ->get_factory ()->create (pfasst::encap::solution));
107+ }
108+ }
109+
110+ if (coarse) {
111+ size_t num_fas = this ->quadrature ->left_is_node () ? num_nodes -1 : num_nodes;
112+ for (size_t m = 0 ; m < num_fas; m++) {
113+ this ->fas_corrections .push_back (this ->get_factory ()->create (pfasst::encap::solution));
114+ }
115+ }
116+
117+ }
118+
50119 // ! @{
51120 virtual void spread () override
52121 {
53- for (size_t m = 1 ; m < this ->quad ->get_num_nodes (); m++) {
122+ for (size_t m = 1 ; m < this ->quadrature ->get_num_nodes (); m++) {
54123 // this->get_state(m)->copy(this->start_state);
55- this ->get_state (m) ->copy (this ->get_state ( 0 ) );
124+ this ->state [m] ->copy (this ->state [ 0 ] );
56125 }
57126 }
58127 // ! @}
59128
129+ /* *
130+ * Save current solution states.
131+ */
132+ virtual void save (bool initial_only) override
133+ {
134+ // XXX: if !left_is_node, this is a problem...
135+ if (initial_only) {
136+ this ->saved_state [0 ]->copy (state[0 ]);
137+ } else {
138+ for (size_t m = 0 ; m < this ->saved_state .size (); m++) {
139+ this ->saved_state [m]->copy (state[m]);
140+ }
141+ }
142+ }
143+
60144 // ! @{
61- void set_quadrature (quadrature:: IQuadrature<time>* quad )
145+ void set_quadrature (shared_ptr< IQuadrature<time>> quadrature )
62146 {
63- this ->quad = quad ;
147+ this ->quadrature = quadrature ;
64148 }
65149
66- const quadrature:: IQuadrature<time>* get_quadrature () const
150+ shared_ptr< const IQuadrature<time>> get_quadrature () const
67151 {
68- return this ->quad ;
152+ return this ->quadrature ;
69153 }
70154
71155 shared_ptr<Encapsulation<time>> get_start_state () const
@@ -75,7 +159,7 @@ namespace pfasst
75159
76160 const vector<time> get_nodes () const
77161 {
78- return this ->quad ->get_nodes ();
162+ return this ->quadrature ->get_nodes ();
79163 }
80164
81165 void set_factory (shared_ptr<EncapFactory<time>> factory)
@@ -88,48 +172,6 @@ namespace pfasst
88172 return factory;
89173 }
90174
91- /* *
92- * retrieve solution values of current iteration at time node index `m`
93- *
94- * @param[in] m 0-based index of time node
95- *
96- * @note This method must be implemented in derived sweepers.
97- */
98- virtual shared_ptr<Encapsulation<time>> get_state (size_t m) const
99- {
100- UNUSED (m);
101- throw NotImplementedYet (" sweeper" );
102- return NULL ;
103- }
104-
105- /* *
106- * retrieves FAS correction of current iteration at time node index `m`
107- *
108- * @param[in] m 0-based index of time node
109- *
110- * @note This method must be implemented in derived sweepers.
111- */
112- virtual shared_ptr<Encapsulation<time>> get_tau (size_t m) const
113- {
114- UNUSED (m);
115- throw NotImplementedYet (" sweeper" );
116- return NULL ;
117- }
118-
119- /* *
120- * retrieves solution values of previous iteration at time node index `m`
121- *
122- * @param[in] m 0-based index of time node
123- *
124- * @note This method must be implemented in derived sweepers.
125- */
126- virtual shared_ptr<Encapsulation<time>> get_saved_state (size_t m) const
127- {
128- UNUSED (m);
129- throw NotImplementedYet (" sweeper" );
130- return NULL ;
131- }
132-
133175 virtual shared_ptr<Encapsulation<time>> get_end_state ()
134176 {
135177 return this ->end_state ;
@@ -229,25 +271,29 @@ namespace pfasst
229271 // ! @{
230272 virtual void post (ICommunicator* comm, int tag) override
231273 {
232- this ->get_state ( 0 ) ->post (comm, tag);
274+ this ->start_state ->post (comm, tag);
233275 }
234276
235277 virtual void send (ICommunicator* comm, int tag, bool blocking) override
236278 {
237- this ->get_state ( this -> get_nodes (). size () - 1 ) ->send (comm, tag, blocking);
279+ this ->end_state ->send (comm, tag, blocking);
238280 }
239281
240282 virtual void recv (ICommunicator* comm, int tag, bool blocking) override
241283 {
242- this ->get_state (0 )->recv (comm, tag, blocking);
284+ this ->start_state ->recv (comm, tag, blocking);
285+ // XXX
286+ this ->state .front ()->copy (this ->start_state );
243287 }
244288
245289 virtual void broadcast (ICommunicator* comm) override
246290 {
247291 if (comm->rank () == comm->size () - 1 ) {
248- this ->get_state ( 0 ) ->copy (this ->get_state ( this -> get_nodes (). size () - 1 ) );
292+ this ->start_state ->copy (this ->end_state );
249293 }
250- this ->get_state (0 )->broadcast (comm);
294+ this ->start_state ->broadcast (comm);
295+ // XXX
296+ this ->state .front ()->copy (this ->start_state );
251297 }
252298 // ! @}
253299 };
0 commit comments