Skip to content

Commit 3b1ed9a

Browse files
committed
encap: Use mat_apply to compute tau instead of saxpys.
Signed-off-by: Matthew Emmett <[email protected]>
1 parent 8fe2b6b commit 3b1ed9a

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

include/pfasst/encap/poly_interp.hpp

Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ namespace pfasst
2121
class PolyInterpMixin
2222
: public pfasst::ITransfer<time>
2323
{
24+
using EncapVecT = vector<shared_ptr<Encapsulation<time>>>;
2425
matrix<time> tmat, fmat;
2526

2627
public:
@@ -54,7 +55,7 @@ namespace pfasst
5455
auto crse_factory = crse->get_factory();
5556
auto fine_factory = fine->get_factory();
5657

57-
vector<shared_ptr<Encapsulation<time>>> fine_state(nfine), fine_delta(ncrse);
58+
EncapVecT fine_state(nfine), fine_delta(ncrse);
5859

5960
for (size_t m = 0; m < nfine; m++) { fine_state[m] = fine->get_state(m); }
6061
for (size_t m = 0; m < ncrse; m++) { fine_delta[m] = fine_factory->create(solution); }
@@ -123,7 +124,7 @@ namespace pfasst
123124
}
124125

125126
virtual void fas(time dt, shared_ptr<ISweeper<time>> dst,
126-
shared_ptr<const ISweeper<time>> src)
127+
shared_ptr<const ISweeper<time>> src)
127128
{
128129
shared_ptr<EncapSweeper<time>> crse = dynamic_pointer_cast<EncapSweeper<time>>(dst);
129130
assert(crse);
@@ -136,17 +137,14 @@ namespace pfasst
136137
virtual void fas(time dt, shared_ptr<EncapSweeper<time>> crse,
137138
shared_ptr<const EncapSweeper<time>> fine)
138139
{
139-
size_t ncrse = crse->get_nodes().size();
140-
assert(ncrse > 1);
141-
size_t nfine = fine->get_nodes().size();
142-
assert(nfine >= 1);
140+
size_t ncrse = crse->get_nodes().size(); assert(ncrse >= 1);
141+
size_t nfine = fine->get_nodes().size(); assert(nfine >= 1);
143142

144143
auto crse_factory = crse->get_factory();
145144
auto fine_factory = fine->get_factory();
146145

147-
vector<shared_ptr<Encapsulation<time>>> crse_z2n(ncrse - 1)
148-
, fine_z2n(nfine - 1)
149-
, rstr_z2n(ncrse - 1);
146+
EncapVecT crse_z2n(ncrse - 1), fine_z2n(nfine - 1), rstr_z2n(ncrse - 1);
147+
150148
for (size_t m = 0; m < ncrse - 1; m++) { crse_z2n[m] = crse_factory->create(solution); }
151149
for (size_t m = 0; m < ncrse - 1; m++) { rstr_z2n[m] = crse_factory->create(solution); }
152150
for (size_t m = 0; m < nfine - 1; m++) { fine_z2n[m] = fine_factory->create(solution); }
@@ -170,21 +168,27 @@ namespace pfasst
170168
}
171169

172170
// compute 'node to node' tau correction
173-
vector<shared_ptr<Encapsulation<time>>> tau(ncrse - 1);
174-
for (size_t m = 0; m < ncrse - 1; m++) {
175-
tau[m] = crse->get_tau(m);
171+
EncapVecT tau(ncrse - 1), rstr_and_crse(2 * (ncrse - 1));
172+
for (size_t m = 0; m < ncrse - 1; m++) { tau[m] = crse->get_tau(m); }
173+
for (size_t m = 0; m < ncrse - 1; m++) { rstr_and_crse[m] = rstr_z2n[m]; }
174+
for (size_t m = 0; m < ncrse - 1; m++) { rstr_and_crse[ncrse - 1 + m] = crse_z2n[m]; }
175+
176+
if (fmat.size1() == 0) {
177+
fmat.resize(ncrse - 1, 2 * (ncrse - 1));
178+
fmat.clear();
179+
180+
for (size_t m = 0; m < ncrse - 1; m++) {
181+
fmat(m, m) = 1.0;
182+
fmat(m, ncrse - 1 + m) = -1.0;
183+
184+
for (size_t n = 0; n < m; n++) {
185+
fmat(m, n) = -1.0;
186+
fmat(m, ncrse - 1 + n) = 1.0;
187+
}
188+
}
176189
}
177190

178-
tau[0]->copy(rstr_z2n[0]);
179-
tau[0]->saxpy(-1.0, crse_z2n[0]);
180-
181-
for (size_t m = 1; m < ncrse - 1; m++) {
182-
tau[m]->copy(rstr_z2n[m]);
183-
tau[m]->saxpy(-1.0, rstr_z2n[m - 1]);
184-
185-
tau[m]->saxpy(-1.0, crse_z2n[m]);
186-
tau[m]->saxpy(1.0, crse_z2n[m - 1]);
187-
}
191+
tau[0]->mat_apply(tau, 1.0, fmat, rstr_and_crse, true);
188192
}
189193

190194
// required for interp/restrict helpers

0 commit comments

Comments
 (0)