Skip to content

Commit e46136b

Browse files
committed
implicit: Split decl/impl.
1 parent fc47125 commit e46136b

File tree

2 files changed

+219
-184
lines changed

2 files changed

+219
-184
lines changed
Lines changed: 12 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,16 @@
1-
21
#ifndef _PFASST_ENCAP_IMPLICIT_SWEEPER_HPP_
32
#define _PFASST_ENCAP_IMPLICIT_SWEEPER_HPP_
43

5-
#include <cstdlib>
6-
#include <cassert>
7-
#include <vector>
84
#include <memory>
5+
#include <vector>
96

10-
#include "../globals.hpp"
11-
#include "../quadrature.hpp"
12-
#include "encapsulation.hpp"
13-
#include "encap_sweeper.hpp"
14-
#include "vector.hpp"
7+
#include "pfasst/encap/encapsulation.hpp"
8+
#include "pfasst/encap/encap_sweeper.hpp"
159

1610
using namespace std;
1711

1812
namespace pfasst
1913
{
20-
21-
template<typename scalar>
22-
using lu_pair = pair< Matrix<scalar>, Matrix<scalar> >;
23-
24-
template<typename scalar>
25-
static lu_pair<scalar> lu_decomposition(const Matrix<scalar>& A)
26-
{
27-
assert(A.rows() == A.cols());
28-
29-
auto n = A.rows();
30-
31-
Matrix<scalar> L = Matrix<scalar>::Zero(n, n);
32-
Matrix<scalar> U = Matrix<scalar>::Zero(n, n);
33-
34-
if (A.rows() == 1) {
35-
36-
L(0, 0) = 1.0;
37-
U(0, 0) = A(0,0);
38-
39-
} else {
40-
41-
// first row of U is first row of A
42-
auto U12 = A.block(0, 1, 1, n-1);
43-
44-
// first column of L is first column of A / a11
45-
auto L21 = A.block(1, 0, n-1, 1) / A(0, 0);
46-
47-
// remove first row and column and recurse
48-
auto A22 = A.block(1, 1, n-1, n-1);
49-
Matrix<scalar> tmp = A22 - L21 * U12;
50-
auto LU22 = lu_decomposition(tmp);
51-
52-
L(0, 0) = 1.0;
53-
U(0, 0) = A(0, 0);
54-
L.block(1, 0, n-1, 1) = L21;
55-
U.block(0, 1, 1, n-1) = U12;
56-
L.block(1, 1, n-1, n-1) = get<0>(LU22);
57-
U.block(1, 1, n-1, n-1) = get<1>(LU22);
58-
59-
}
60-
61-
return lu_pair<scalar>(L, U);
62-
}
63-
6414
namespace encap
6515
{
6616
using pfasst::encap::Encapsulation;
@@ -90,32 +40,7 @@ namespace pfasst
9040

9141
Matrix<time> q_tilde;
9242

93-
/**
94-
* Set end state to \\( U_0 + \\int F_{expl} + F_{expl} \\).
95-
*/
96-
void set_end_state()
97-
{
98-
if (this->quadrature->right_is_node()) {
99-
this->end_state->copy(this->state.back());
100-
} else {
101-
vector<shared_ptr<Encapsulation<time>>> dst = { this->end_state };
102-
dst[0]->copy(this->start_state);
103-
dst[0]->mat_apply(dst, this->get_controller()->get_time_step(), this->quadrature->get_b_mat(), this->fs_impl, false);
104-
}
105-
}
106-
107-
/**
108-
* Augment nodes: nodes <- [t0] + dt * nodes
109-
*/
110-
vector<time> augment(time t0, time dt, vector<time> const & nodes)
111-
{
112-
vector<time> t(1 + nodes.size());
113-
t[0] = t0;
114-
for (size_t m = 0; m < nodes.size(); m++) {
115-
t[m+1] = t0 + dt * nodes[m];
116-
}
117-
return t;
118-
}
43+
void set_end_state();
11944

12045
public:
12146
//! @{
@@ -126,37 +51,8 @@ namespace pfasst
12651
//! @{
12752
/**
12853
* @copydoc ISweeper::setup(bool)
129-
*
13054
*/
131-
virtual void setup(bool coarse) override
132-
{
133-
pfasst::encap::EncapSweeper<time>::setup(coarse);
134-
135-
auto const nodes = this->quadrature->get_nodes();
136-
auto const num_nodes = this->quadrature->get_num_nodes();
137-
138-
if (this->quadrature->left_is_node()) {
139-
CLOG(INFO, "Sweeper") << "implicit sweeper shouldn't include left endpoint";
140-
throw ValueError("implicit sweeper shouldn't include left endpoint");
141-
}
142-
143-
for (size_t m = 0; m < num_nodes; m++) {
144-
this->s_integrals.push_back(this->get_factory()->create(pfasst::encap::solution));
145-
this->fs_impl.push_back(this->get_factory()->create(pfasst::encap::function));
146-
}
147-
148-
Matrix<time> QT = this->quadrature->get_q_mat().transpose();
149-
auto lu = lu_decomposition(QT);
150-
auto L = get<0>(lu);
151-
auto U = get<1>(lu);
152-
this->q_tilde = U.transpose();
153-
154-
CLOG(DEBUG, "Sweeper") << "Q':" << endl << QT;
155-
CLOG(DEBUG, "Sweeper") << "L:" << endl << L;
156-
CLOG(DEBUG, "Sweeper") << "U:" << endl << U;
157-
CLOG(DEBUG, "Sweeper") << "LU:" << endl << L * U;
158-
CLOG(DEBUG, "Sweeper") << "q_tilde:" << endl << this->q_tilde;
159-
}
55+
virtual void setup(bool coarse) override;
16056

16157
/**
16258
* Compute low-order provisional solution.
@@ -166,103 +62,33 @@ namespace pfasst
16662
* @param[in] initial if `true` the explicit and implicit part of the right hand side of the
16763
* ODE get evaluated with the initial value
16864
*/
169-
virtual void predict(bool initial) override
170-
{
171-
UNUSED(initial);
172-
173-
auto const dt = this->get_controller()->get_time_step();
174-
auto const t = this->get_controller()->get_time();
175-
176-
CLOG(INFO, "Sweeper") << "predicting step " << this->get_controller()->get_step() + 1
177-
<< " (t=" << t << ", dt=" << dt << ")";
178-
179-
auto const anodes = augment(t, dt, this->quadrature->get_nodes());
180-
for (size_t m = 0; m < anodes.size() - 1; ++m) {
181-
this->impl_solve(this->fs_impl[m], this->state[m], anodes[m], anodes[m+1] - anodes[m],
182-
m == 0 ? this->get_start_state() : this->state[m-1]);
183-
}
184-
185-
this->set_end_state();
186-
}
65+
virtual void predict(bool initial) override;
18766

18867
/**
18968
* Perform one SDC sweep/iteration.
19069
*
19170
* This computes a high-order solution from the previous iteration's function values and
19271
* corrects it using forward/backward Euler steps across the nodes.
19372
*/
194-
virtual void sweep() override
195-
{
196-
auto const dt = this->get_controller()->get_time_step();
197-
auto const t = this->get_controller()->get_time();
198-
199-
CLOG(INFO, "Sweeper") << "sweeping on step " << this->get_controller()->get_step() + 1
200-
<< " in iteration " << this->get_controller()->get_iteration()
201-
<< " (dt=" << dt << ")";
202-
203-
this->s_integrals[0]->mat_apply(this->s_integrals, dt, this->quadrature->get_s_mat(), this->fs_impl, true);
204-
if (this->fas_corrections.size() > 0) {
205-
for (size_t m = 0; m < this->s_integrals.size(); m++) {
206-
this->s_integrals[m]->saxpy(1.0, this->fas_corrections[m]);
207-
}
208-
}
209-
210-
for (size_t m = 0; m < this->s_integrals.size(); m++) {
211-
for (size_t n = 0; n < m; n++) {
212-
this->s_integrals[m]->saxpy(-dt*this->q_tilde(m, n), this->fs_impl[n]);
213-
}
214-
}
215-
216-
shared_ptr<Encapsulation<time>> rhs = this->get_factory()->create(pfasst::encap::solution);
217-
218-
auto const anodes = augment(t, dt, this->quadrature->get_nodes());
219-
for (size_t m = 0; m < anodes.size() - 1; ++m) {
220-
auto const ds = anodes[m+1] - anodes[m];
221-
rhs->copy(m == 0 ? this->get_start_state() : this->state[m-1]);
222-
rhs->saxpy(1.0, this->s_integrals[m]);
223-
rhs->saxpy(-ds, this->fs_impl[m]);
224-
for (size_t n = 0; n < m; n++) {
225-
rhs->saxpy(dt*this->q_tilde(m, n), this->fs_impl[n]);
226-
}
227-
this->impl_solve(this->fs_impl[m], this->state[m], anodes[m], ds, rhs);
228-
}
229-
this->set_end_state();
230-
}
73+
virtual void sweep() override;
23174

23275
/**
23376
* Advance the end solution to start solution.
23477
*/
235-
virtual void advance() override
236-
{
237-
this->start_state->copy(this->end_state);
238-
}
78+
virtual void advance() override;
23979

24080
/**
24181
* @copybrief EncapSweeper::evaluate()
24282
*/
243-
virtual void reevaluate(bool initial_only) override
244-
{
245-
if (initial_only) {
246-
return;
247-
}
248-
auto const dt = this->get_controller()->get_time_step();
249-
auto const t0 = this->get_controller()->get_time();
250-
auto const nodes = this->quadrature->get_nodes();
251-
for (size_t m = 0; m < nodes.size(); m++) {
252-
this->f_impl_eval(this->fs_impl[m], this->state[m], t0 + dt * nodes[m]);
253-
}
254-
}
83+
virtual void reevaluate(bool initial_only) override;
25584

25685
/**
25786
* @copybrief EncapSweeper::integrate()
25887
*
25988
* @param[in] dt width of time interval to integrate over
26089
* @param[in,out] dst integrated values; will get zeroed out beforehand
26190
*/
262-
virtual void integrate(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const override
263-
{
264-
dst[0]->mat_apply(dst, dt, this->quadrature->get_q_mat(), this->fs_impl, true);
265-
}
91+
virtual void integrate(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const override;
26692
//! @}
26793

26894
//! @{
@@ -316,4 +142,6 @@ namespace pfasst
316142
} // ::pfasst::encap
317143
} // ::pfasst
318144

145+
#include "pfasst/encap/implicit_sweeper_impl.hpp"
146+
319147
#endif

0 commit comments

Comments
 (0)