Skip to content

Commit 024b40f

Browse files
committed
Merge pull request #60 from memmett/feature/flense
Tidy and flense various.
2 parents cd7a1b5 + 27e42eb commit 024b40f

File tree

4 files changed

+94
-107
lines changed

4 files changed

+94
-107
lines changed

CMakeLists.txt

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ list(APPEND CMAKE_MODULE_PATH ${pfasst_SOURCE_DIR}/cmake)
55
include(cmake/utility_functions.cmake)
66
include(CheckCXXCompilerFlag)
77
include(ExternalProject)
8+
89
# Set default ExternalProject root directory
910
set_directory_properties(PROPERTIES EP_PREFIX ${CMAKE_BINARY_DIR}/3rdparty)
1011

@@ -13,19 +14,17 @@ option(pfasst_BUILD_EXAMPLES "Build example programs."
1314
option(pfasst_BUILD_TESTS "Build test suite for PFASST." ON )
1415
option(pfasst_WITH_MPI "Build with MPI enabled." OFF)
1516

16-
# output directories
17+
# Set output directories
1718
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${pfasst_SOURCE_DIR}/dist/bin")
1819
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${pfasst_SOURCE_DIR}/dist/lib")
1920
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${pfasst_SOURCE_DIR}/dist/lib")
2021

2122
if(${pfasst_WITH_MPI})
2223
find_package(MPI REQUIRED)
23-
# set(CMAKE_C_COMPILER ${MPI_C_COMPILER})
24-
# set(CMAKE_CXX_COMPILER ${MPI_CXX_COMPILER})
2524
message(STATUS "Using MPI C++ Compiler: ${MPI_CXX_COMPILER}")
2625
endif()
2726

28-
# check for C++11 support
27+
# Check for C++11 support
2928
if(${CMAKE_CXX_COMPILER_ID} MATCHES GNU)
3029
check_cxx_compiler_flag(-std=c++11 HAVE_STD11)
3130
if(HAVE_STD11)
@@ -63,7 +62,7 @@ else()
6362
endif()
6463
message(STATUS "Your compiler has C++11 support. Hurray!")
6564

66-
# enable all compiler warnings
65+
# Enable all compiler warnings
6766
add_to_string_list("${CMAKE_CXX_FLAGS}" CMAKE_CXX_FLAGS "-Wall -Wextra -Wpedantic")
6867

6968
set(3rdparty_INCLUDES)
@@ -75,13 +74,14 @@ if(pfasst_BUILD_TESTS)
7574
enable_testing()
7675
endif(pfasst_BUILD_TESTS)
7776

78-
# adding / including 3rd-party libraries
77+
# Add / include 3rd-party libraries
7978
message(STATUS "********************************************************************************")
8079
message(STATUS "Configuring 3rd party libraries")
8180
# makes available:
82-
# - fftw3_INCLUDES (if pfasst_BUILD_EXAMPLES)
83-
# - fftw3_LIBS (if pfasst_BUILD_EXAMPLES)
84-
# and Boost headers in 3rdparty_INCLUDES
81+
# - Boost headers in 3rdparty_INCLUDES
82+
# - Google test and mock headers in 3rdparty_INCLUDES (if pfasst_BUILD_TESTS)
83+
# - FFTW_INCLUDE_PATH (if pfasst_BUILD_EXAMPLES)
84+
# - FFTW_LIBRARIES (if pfasst_BUILD_EXAMPLES)
8585
add_subdirectory(3rdparty)
8686

8787
message(STATUS "********************************************************************************")

include/pfasst/encap/poly_interp.hpp

Lines changed: 38 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ namespace pfasst
5555
fine.evaluate(0);
5656
}
5757

58+
5859
virtual void interpolate(shared_ptr<ISweeper<time>> dst,
5960
shared_ptr<const ISweeper<time>> src,
6061
bool interp_initial) override
@@ -77,22 +78,11 @@ namespace pfasst
7778
for (size_t m = 0; m < nfine; m++) { fine_state[m] = fine.get_state(m); }
7879
for (size_t m = 0; m < ncrse; m++) { fine_delta[m] = fine_factory->create(solution); }
7980

80-
// if (interp_delta_from_initial) {
81-
// for (size_t m = 1; m < nfine; m++) {
82-
// fine_state[m]->copy(fine_state[0]);
83-
// }
84-
// }
85-
8681
auto crse_delta = crse_factory->create(solution);
8782
size_t m0 = interp_initial ? 0 : 1;
8883
for (size_t m = m0; m < ncrse; m++) {
8984
crse_delta->copy(crse.get_state(m));
90-
// if (interp_delta_from_initial) {
91-
// crse_delta->saxpy(-1.0, crse->get_saved_state(0));
92-
// // crse_delta->saxpy(-1.0, crse->get_state(0));
93-
// } else {
94-
crse_delta->saxpy(-1.0, crse.get_saved_state(m));
95-
// }
85+
crse_delta->saxpy(-1.0, crse.get_saved_state(m));
9686
interpolate(fine_delta[m], crse_delta);
9787
}
9888

@@ -105,43 +95,39 @@ namespace pfasst
10595
for (size_t m = m0; m < nfine; m++) { fine.evaluate(m); }
10696
}
10797

108-
// required for interp/restrict helpers
109-
virtual void interpolate(shared_ptr<Encapsulation<time>> /*dst*/,
110-
shared_ptr<const Encapsulation<time>> /*src*/)
98+
99+
virtual void interpolate(shared_ptr<Encapsulation<time>> dst,
100+
shared_ptr<const Encapsulation<time>> src)
111101
{
102+
UNUSED(dst); UNUSED(src);
112103
throw NotImplementedYet("mlsdc/pfasst");
113104
}
105+
//! @}
114106

115-
virtual void restrict(shared_ptr<ISweeper<time>> dst,
116-
shared_ptr<const ISweeper<time>> src,
117-
bool restrict_initial,
118-
bool restrict_initial_only) override
119-
{
120-
shared_ptr<EncapSweeper<time>> crse = dynamic_pointer_cast<EncapSweeper<time>>(dst);
121-
assert(crse);
122-
shared_ptr<const EncapSweeper<time>> fine = \
123-
dynamic_pointer_cast<const EncapSweeper<time>>(src);
124-
assert(fine);
125107

126-
this->restrict(crse, fine, restrict_initial, restrict_initial_only);
108+
//! @{
109+
virtual void restrict_initial(shared_ptr<ISweeper<time>> dst,
110+
shared_ptr<const ISweeper<time>> src) override
111+
{
112+
auto& crse = as_encap_sweeper(dst);
113+
auto& fine = as_encap_sweeper(src);
114+
this->restrict(crse.get_state(0), fine.get_state(0));
127115
}
128116

129-
virtual void restrict(shared_ptr<EncapSweeper<time>> crse,
130-
shared_ptr<const EncapSweeper<time>> fine,
131-
bool restrict_initial,
132-
bool restrict_initial_only)
117+
118+
virtual void restrict(shared_ptr<ISweeper<time>> dst,
119+
shared_ptr<const ISweeper<time>> src,
120+
bool restrict_initial) override
133121
{
134-
if (restrict_initial_only) {
135-
this->restrict(crse->get_state(0), fine->get_state(0));
136-
return;
137-
}
122+
auto& crse = as_encap_sweeper(dst);
123+
auto& fine = as_encap_sweeper(src);
138124

139-
auto dnodes = crse->get_nodes();
140-
auto snodes = fine->get_nodes();
125+
auto dnodes = crse.get_nodes();
126+
auto snodes = fine.get_nodes();
141127

142-
size_t ncrse = crse->get_nodes().size();
128+
size_t ncrse = crse.get_nodes().size();
143129
assert(ncrse > 1);
144-
size_t nfine = fine->get_nodes().size();
130+
size_t nfine = fine.get_nodes().size();
145131

146132
int trat = (int(nfine) - 1) / (int(ncrse) - 1);
147133

@@ -150,39 +136,32 @@ namespace pfasst
150136
if (dnodes[m] != snodes[m * trat]) {
151137
throw NotImplementedYet("coarse nodes must be nested");
152138
}
153-
this->restrict(crse->get_state(m), fine->get_state(m * trat));
139+
this->restrict(crse.get_state(m), fine.get_state(m * trat));
154140
}
155141

156-
for (size_t m = m0; m < ncrse; m++) { crse->evaluate(m); }
142+
for (size_t m = m0; m < ncrse; m++) { crse.evaluate(m); }
157143
}
158144

145+
159146
virtual void restrict(shared_ptr<Encapsulation<time>> dst,
160147
shared_ptr<const Encapsulation<time>> src)
161148
{
162149
UNUSED(dst); UNUSED(src);
163150
throw NotImplementedYet("mlsdc/pfasst");
164151
}
152+
//! @}
165153

166154
virtual void fas(time dt, shared_ptr<ISweeper<time>> dst,
167155
shared_ptr<const ISweeper<time>> src) override
168156
{
169-
shared_ptr<EncapSweeper<time>> crse = dynamic_pointer_cast<EncapSweeper<time>>(dst);
170-
assert(crse);
171-
shared_ptr<const EncapSweeper<time>> fine = \
172-
dynamic_pointer_cast<const EncapSweeper<time>>(src);
173-
assert(fine);
174-
175-
this->fas(dt, crse, fine);
176-
}
157+
auto& crse = pfasst::encap::as_encap_sweeper(dst);
158+
auto& fine = pfasst::encap::as_encap_sweeper(src);
177159

178-
virtual void fas(time dt, shared_ptr<EncapSweeper<time>> crse,
179-
shared_ptr<const EncapSweeper<time>> fine)
180-
{
181-
size_t ncrse = crse->get_nodes().size(); assert(ncrse >= 1);
182-
size_t nfine = fine->get_nodes().size(); assert(nfine >= 1);
160+
size_t ncrse = crse.get_nodes().size(); assert(ncrse >= 1);
161+
size_t nfine = fine.get_nodes().size(); assert(nfine >= 1);
183162

184-
auto crse_factory = crse->get_factory();
185-
auto fine_factory = fine->get_factory();
163+
auto crse_factory = crse.get_factory();
164+
auto fine_factory = fine.get_factory();
186165

187166
EncapVecT crse_z2n(ncrse - 1), fine_z2n(nfine - 1), rstr_z2n(ncrse - 1);
188167

@@ -191,13 +170,13 @@ namespace pfasst
191170
for (size_t m = 0; m < nfine - 1; m++) { fine_z2n[m] = fine_factory->create(solution); }
192171

193172
// compute '0 to node' integral on the coarse level
194-
crse->integrate(dt, crse_z2n);
173+
crse.integrate(dt, crse_z2n);
195174
for (size_t m = 1; m < ncrse - 1; m++) {
196175
crse_z2n[m]->saxpy(1.0, crse_z2n[m - 1]);
197176
}
198177

199178
// compute '0 to node' integral on the fine level
200-
fine->integrate(dt, fine_z2n);
179+
fine.integrate(dt, fine_z2n);
201180
for (size_t m = 1; m < nfine - 1; m++) {
202181
fine_z2n[m]->saxpy(1.0, fine_z2n[m - 1]);
203182
}
@@ -210,7 +189,7 @@ namespace pfasst
210189

211190
// compute 'node to node' tau correction
212191
EncapVecT tau(ncrse - 1), rstr_and_crse(2 * (ncrse - 1));
213-
for (size_t m = 0; m < ncrse - 1; m++) { tau[m] = crse->get_tau(m); }
192+
for (size_t m = 0; m < ncrse - 1; m++) { tau[m] = crse.get_tau(m); }
214193
for (size_t m = 0; m < ncrse - 1; m++) { rstr_and_crse[m] = rstr_z2n[m]; }
215194
for (size_t m = 0; m < ncrse - 1; m++) { rstr_and_crse[ncrse - 1 + m] = crse_z2n[m]; }
216195

@@ -231,7 +210,7 @@ namespace pfasst
231210

232211
tau[0]->mat_apply(tau, 1.0, fmat, rstr_and_crse, true);
233212
}
234-
//! @}
213+
235214
};
236215

237216
} // ::pfasst::encap

include/pfasst/interfaces.hpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ namespace pfasst
211211

212212
//! @{
213213
/**
214-
* Interpolate initial condition (in space) from the coarse sweeper to the fine sweeper.
214+
* Interpolate initial condition from the coarse sweeper to the fine sweeper.
215215
*/
216216
virtual void interpolate_initial(shared_ptr<ISweeper<time>> dst,
217217
shared_ptr<const ISweeper<time>> src)
@@ -229,15 +229,28 @@ namespace pfasst
229229
shared_ptr<const ISweeper<time>> src,
230230
bool interp_initial = false) = 0;
231231

232+
/**
233+
* Restrict initial condition from the fine sweeper to the coarse sweeper.
234+
* @param[in] restrict_initial
235+
* `true` if the initial condition should also be restricted.
236+
*/
237+
virtual void restrict_initial(shared_ptr<ISweeper<time>> dst,
238+
shared_ptr<const ISweeper<time>> src)
239+
{
240+
UNUSED(dst); UNUSED(src);
241+
NotImplementedYet("pfasst");
242+
}
243+
244+
232245
/**
233246
* Restrict, in time and space, from the fine sweeper to the coarse sweeper.
234247
* @param[in] restrict_initial
235248
* `true` if the initial condition should also be restricted.
236249
*/
237250
virtual void restrict(shared_ptr<ISweeper<time>> dst,
238251
shared_ptr<const ISweeper<time>> src,
239-
bool restrict_initial = false,
240-
bool restrict_initial_only = false) = 0;
252+
bool restrict_initial = false) = 0;
253+
241254

242255
/**
243256
* Compute FAS correction between the coarse and fine sweepers.

include/pfasst/pfasst.hpp

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ namespace pfasst
5454
auto crse = l.current();
5555
auto fine = l.fine();
5656
auto trns = l.transfer();
57-
trns->restrict(crse, fine, true, true);
57+
trns->restrict_initial(crse, fine);
5858
crse->spread();
5959
crse->save();
6060
}
@@ -63,9 +63,7 @@ namespace pfasst
6363
predict = true;
6464
auto crse = this->coarsest().current();
6565
for (int nstep = 0; nstep < comm->rank() + 1; nstep++) {
66-
// this->set_step(comm->rank());
67-
// XXX: set iteration?
68-
66+
// XXX: set iteration and step?
6967
perform_sweeps(0);
7068
if (nstep < comm->rank()) {
7169
crse->advance();
@@ -91,6 +89,24 @@ namespace pfasst
9189
initial = true;
9290
}
9391

92+
int tag (int level)
93+
{
94+
return level * 10000 + this->get_iteration() + 10;
95+
}
96+
97+
int tag(LevelIter l)
98+
{
99+
return tag(l.level);
100+
}
101+
102+
void post()
103+
{
104+
for (auto l = this->coarsest() + 1; l <= this->finest(); ++l) {
105+
l.current()->post(comm, tag(l));
106+
}
107+
}
108+
109+
94110
/**
95111
* Evolve ODE using PFASST.
96112
*
@@ -115,29 +131,8 @@ namespace pfasst
115131

116132
for (this->set_iteration(0); this->get_iteration() < this->get_max_iterations();
117133
this->advance_iteration()) {
118-
for (auto l = this->coarsest() + 1; l <= this->finest(); ++l) {
119-
int tag = l.level * 10000 + this->get_iteration() + 10;
120-
l.current()->post(comm, tag);
121-
}
122-
123-
perform_sweeps(this->nlevels() - 1);
124-
// XXX check convergence
125-
auto fine = this->get_level(this->nlevels() - 1);
126-
auto crse = this->get_level(this->nlevels() - 2);
127-
auto trns = this->get_transfer(this->nlevels() - 1);
128-
129-
int tag = (this->nlevels() - 1) * 10000 + this->get_iteration() + 10;
130-
fine->send(comm, tag, false);
131-
trns->restrict(crse, fine, true, false);
132-
trns->fas(this->get_time_step(), crse, fine);
133-
crse->save();
134-
135-
cycle_v(this->finest() - 1);
136-
137-
trns->interpolate(fine, crse, true);
138-
fine->recv(comm, tag, false);
139-
trns->interpolate_initial(fine, crse);
140-
// XXX: call interpolate_q0(pf,F, G)
134+
post();
135+
cycle_v(this->finest());
141136
}
142137

143138
if (nblock < nblocks - 1) {
@@ -157,11 +152,14 @@ namespace pfasst
157152

158153
perform_sweeps(l.level);
159154

160-
int tag = l.level * 10000 + this->get_iteration() + 10;
161-
fine->send(comm, tag, false);
155+
if (l == this->finest()) {
156+
// note: convergence tests belong here
157+
}
158+
159+
fine->send(comm, tag(l), false);
162160

163161
auto dt = this->get_time_step();
164-
trns->restrict(crse, fine, true, false);
162+
trns->restrict(crse, fine, true);
165163
trns->fas(dt, crse, fine);
166164
crse->save();
167165

@@ -184,9 +182,7 @@ namespace pfasst
184182

185183
trns->interpolate(fine, crse, true);
186184

187-
int tag = l.level * 10000 + this->get_iteration() + 10;
188-
fine->recv(comm, tag, false);
189-
// XXX call interpolate_q0(pf,F, G)
185+
fine->recv(comm, tag(l), false);
190186
trns->interpolate_initial(fine, crse);
191187

192188
if (l < this->finest()) {
@@ -203,10 +199,9 @@ namespace pfasst
203199
{
204200
auto crse = l.current();
205201

206-
int tag = l.level * 10000 + this->get_iteration() + 10;
207-
crse->recv(comm, tag, true);
202+
crse->recv(comm, tag(l), true);
208203
this->perform_sweeps(l.level);
209-
crse->send(comm, tag, true);
204+
crse->send(comm, tag(l), true);
210205
return l + 1;
211206
}
212207

0 commit comments

Comments
 (0)