1-
21#ifndef _PFASST_ENCAP_IMPLICIT_SWEEPER_HPP_
32#define _PFASST_ENCAP_IMPLICIT_SWEEPER_HPP_
43
5- #include < cstdlib>
6- #include < cassert>
7- #include < vector>
84#include < memory>
5+ #include < vector>
96
10- #include " ../globals.hpp"
11- #include " ../quadrature.hpp"
12- #include " encapsulation.hpp"
13- #include " encap_sweeper.hpp"
14- #include " vector.hpp"
7+ #include " pfasst/encap/encapsulation.hpp"
8+ #include " pfasst/encap/encap_sweeper.hpp"
159
1610using namespace std ;
1711
1812namespace pfasst
1913{
20-
21- template <typename scalar>
22- using lu_pair = pair< Matrix<scalar>, Matrix<scalar> >;
23-
24- template <typename scalar>
25- static lu_pair<scalar> lu_decomposition (const Matrix<scalar>& A)
26- {
27- assert (A.rows () == A.cols ());
28-
29- auto n = A.rows ();
30-
31- Matrix<scalar> L = Matrix<scalar>::Zero (n, n);
32- Matrix<scalar> U = Matrix<scalar>::Zero (n, n);
33-
34- if (A.rows () == 1 ) {
35-
36- L (0 , 0 ) = 1.0 ;
37- U (0 , 0 ) = A (0 ,0 );
38-
39- } else {
40-
41- // first row of U is first row of A
42- auto U12 = A.block (0 , 1 , 1 , n-1 );
43-
44- // first column of L is first column of A / a11
45- auto L21 = A.block (1 , 0 , n-1 , 1 ) / A (0 , 0 );
46-
47- // remove first row and column and recurse
48- auto A22 = A.block (1 , 1 , n-1 , n-1 );
49- Matrix<scalar> tmp = A22 - L21 * U12;
50- auto LU22 = lu_decomposition (tmp);
51-
52- L (0 , 0 ) = 1.0 ;
53- U (0 , 0 ) = A (0 , 0 );
54- L.block (1 , 0 , n-1 , 1 ) = L21;
55- U.block (0 , 1 , 1 , n-1 ) = U12;
56- L.block (1 , 1 , n-1 , n-1 ) = get<0 >(LU22);
57- U.block (1 , 1 , n-1 , n-1 ) = get<1 >(LU22);
58-
59- }
60-
61- return lu_pair<scalar>(L, U);
62- }
63-
6414 namespace encap
6515 {
6616 using pfasst::encap::Encapsulation;
@@ -90,32 +40,7 @@ namespace pfasst
9040
9141 Matrix<time> q_tilde;
9242
93- /* *
94- * Set end state to \\( U_0 + \\int F_{expl} + F_{expl} \\).
95- */
96- void set_end_state ()
97- {
98- if (this ->quadrature ->right_is_node ()) {
99- this ->end_state ->copy (this ->state .back ());
100- } else {
101- vector<shared_ptr<Encapsulation<time>>> dst = { this ->end_state };
102- dst[0 ]->copy (this ->start_state );
103- dst[0 ]->mat_apply (dst, this ->get_controller ()->get_time_step (), this ->quadrature ->get_b_mat (), this ->fs_impl , false );
104- }
105- }
106-
107- /* *
108- * Augment nodes: nodes <- [t0] + dt * nodes
109- */
110- vector<time> augment (time t0, time dt, vector<time> const & nodes)
111- {
112- vector<time> t (1 + nodes.size ());
113- t[0 ] = t0;
114- for (size_t m = 0 ; m < nodes.size (); m++) {
115- t[m+1 ] = t0 + dt * nodes[m];
116- }
117- return t;
118- }
43+ void set_end_state ();
11944
12045 public:
12146 // ! @{
@@ -126,37 +51,8 @@ namespace pfasst
12651 // ! @{
12752 /* *
12853 * @copydoc ISweeper::setup(bool)
129- *
13054 */
131- virtual void setup (bool coarse) override
132- {
133- pfasst::encap::EncapSweeper<time>::setup (coarse);
134-
135- auto const nodes = this ->quadrature ->get_nodes ();
136- auto const num_nodes = this ->quadrature ->get_num_nodes ();
137-
138- if (this ->quadrature ->left_is_node ()) {
139- CLOG (INFO, " Sweeper" ) << " implicit sweeper shouldn't include left endpoint" ;
140- throw ValueError (" implicit sweeper shouldn't include left endpoint" );
141- }
142-
143- for (size_t m = 0 ; m < num_nodes; m++) {
144- this ->s_integrals .push_back (this ->get_factory ()->create (pfasst::encap::solution));
145- this ->fs_impl .push_back (this ->get_factory ()->create (pfasst::encap::function));
146- }
147-
148- Matrix<time> QT = this ->quadrature ->get_q_mat ().transpose ();
149- auto lu = lu_decomposition (QT);
150- auto L = get<0 >(lu);
151- auto U = get<1 >(lu);
152- this ->q_tilde = U.transpose ();
153-
154- CLOG (DEBUG, " Sweeper" ) << " Q':" << endl << QT;
155- CLOG (DEBUG, " Sweeper" ) << " L:" << endl << L;
156- CLOG (DEBUG, " Sweeper" ) << " U:" << endl << U;
157- CLOG (DEBUG, " Sweeper" ) << " LU:" << endl << L * U;
158- CLOG (DEBUG, " Sweeper" ) << " q_tilde:" << endl << this ->q_tilde ;
159- }
55+ virtual void setup (bool coarse) override ;
16056
16157 /* *
16258 * Compute low-order provisional solution.
@@ -166,103 +62,33 @@ namespace pfasst
16662 * @param[in] initial if `true` the explicit and implicit part of the right hand side of the
16763 * ODE get evaluated with the initial value
16864 */
169- virtual void predict (bool initial) override
170- {
171- UNUSED (initial);
172-
173- auto const dt = this ->get_controller ()->get_time_step ();
174- auto const t = this ->get_controller ()->get_time ();
175-
176- CLOG (INFO, " Sweeper" ) << " predicting step " << this ->get_controller ()->get_step () + 1
177- << " (t=" << t << " , dt=" << dt << " )" ;
178-
179- auto const anodes = augment (t, dt, this ->quadrature ->get_nodes ());
180- for (size_t m = 0 ; m < anodes.size () - 1 ; ++m) {
181- this ->impl_solve (this ->fs_impl [m], this ->state [m], anodes[m], anodes[m+1 ] - anodes[m],
182- m == 0 ? this ->get_start_state () : this ->state [m-1 ]);
183- }
184-
185- this ->set_end_state ();
186- }
65+ virtual void predict (bool initial) override ;
18766
18867 /* *
18968 * Perform one SDC sweep/iteration.
19069 *
19170 * This computes a high-order solution from the previous iteration's function values and
19271 * corrects it using forward/backward Euler steps across the nodes.
19372 */
194- virtual void sweep () override
195- {
196- auto const dt = this ->get_controller ()->get_time_step ();
197- auto const t = this ->get_controller ()->get_time ();
198-
199- CLOG (INFO, " Sweeper" ) << " sweeping on step " << this ->get_controller ()->get_step () + 1
200- << " in iteration " << this ->get_controller ()->get_iteration ()
201- << " (dt=" << dt << " )" ;
202-
203- this ->s_integrals [0 ]->mat_apply (this ->s_integrals , dt, this ->quadrature ->get_s_mat (), this ->fs_impl , true );
204- if (this ->fas_corrections .size () > 0 ) {
205- for (size_t m = 0 ; m < this ->s_integrals .size (); m++) {
206- this ->s_integrals [m]->saxpy (1.0 , this ->fas_corrections [m]);
207- }
208- }
209-
210- for (size_t m = 0 ; m < this ->s_integrals .size (); m++) {
211- for (size_t n = 0 ; n < m; n++) {
212- this ->s_integrals [m]->saxpy (-dt*this ->q_tilde (m, n), this ->fs_impl [n]);
213- }
214- }
215-
216- shared_ptr<Encapsulation<time>> rhs = this ->get_factory ()->create (pfasst::encap::solution);
217-
218- auto const anodes = augment (t, dt, this ->quadrature ->get_nodes ());
219- for (size_t m = 0 ; m < anodes.size () - 1 ; ++m) {
220- auto const ds = anodes[m+1 ] - anodes[m];
221- rhs->copy (m == 0 ? this ->get_start_state () : this ->state [m-1 ]);
222- rhs->saxpy (1.0 , this ->s_integrals [m]);
223- rhs->saxpy (-ds, this ->fs_impl [m]);
224- for (size_t n = 0 ; n < m; n++) {
225- rhs->saxpy (dt*this ->q_tilde (m, n), this ->fs_impl [n]);
226- }
227- this ->impl_solve (this ->fs_impl [m], this ->state [m], anodes[m], ds, rhs);
228- }
229- this ->set_end_state ();
230- }
73+ virtual void sweep () override ;
23174
23275 /* *
23376 * Advance the end solution to start solution.
23477 */
235- virtual void advance () override
236- {
237- this ->start_state ->copy (this ->end_state );
238- }
78+ virtual void advance () override ;
23979
24080 /* *
24181 * @copybrief EncapSweeper::evaluate()
24282 */
243- virtual void reevaluate (bool initial_only) override
244- {
245- if (initial_only) {
246- return ;
247- }
248- auto const dt = this ->get_controller ()->get_time_step ();
249- auto const t0 = this ->get_controller ()->get_time ();
250- auto const nodes = this ->quadrature ->get_nodes ();
251- for (size_t m = 0 ; m < nodes.size (); m++) {
252- this ->f_impl_eval (this ->fs_impl [m], this ->state [m], t0 + dt * nodes[m]);
253- }
254- }
83+ virtual void reevaluate (bool initial_only) override ;
25584
25685 /* *
25786 * @copybrief EncapSweeper::integrate()
25887 *
25988 * @param[in] dt width of time interval to integrate over
26089 * @param[in,out] dst integrated values; will get zeroed out beforehand
26190 */
262- virtual void integrate (time dt, vector<shared_ptr<Encapsulation<time>>> dst) const override
263- {
264- dst[0 ]->mat_apply (dst, dt, this ->quadrature ->get_q_mat (), this ->fs_impl , true );
265- }
91+ virtual void integrate (time dt, vector<shared_ptr<Encapsulation<time>>> dst) const override ;
26692 // ! @}
26793
26894 // ! @{
@@ -316,4 +142,6 @@ namespace pfasst
316142 } // ::pfasst::encap
317143} // ::pfasst
318144
145+ #include " pfasst/encap/implicit_sweeper_impl.hpp"
146+
319147#endif
0 commit comments