Skip to content

Commit f089690

Browse files
committed
internals: splitting up decls/defs for mpi communicator
towards fixing #84
1 parent e491dab commit f089690

File tree

2 files changed

+144
-106
lines changed

2 files changed

+144
-106
lines changed
Lines changed: 22 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,29 @@
1-
/*
2-
* Interfaces for SDC/MLSDC/PFASST algorithms.
3-
*/
4-
51
#ifndef _PFASST_MPI_COMMUNICATOR_HPP_
62
#define _PFASST_MPI_COMMUNICATOR_HPP_
73

8-
#include <exception>
4+
#include <stdexcept>
95
#include <vector>
6+
using namespace std;
107

118
#include <mpi.h>
129

1310
#include "interfaces.hpp"
14-
#include "logging.hpp"
1511

16-
using namespace std;
1712

1813
namespace pfasst
1914
{
2015
namespace mpi
2116
{
22-
2317
class MPIError
24-
: public exception
18+
: public runtime_error
2519
{
2620
public:
27-
const char* what() const throw()
28-
{
29-
return "mpi error";
30-
}
21+
explicit MPIError(const string& msg="");
22+
virtual const char* what() const throw();
3123
};
3224

25+
26+
// forward declare for MPICommunicator
3327
class MPIStatus;
3428

3529

@@ -47,115 +41,37 @@ namespace pfasst
4741
//! @}
4842

4943
//! @{
50-
MPICommunicator()
51-
{}
52-
53-
MPICommunicator(MPI_Comm comm)
54-
{
55-
set_comm(comm);
56-
}
44+
MPICommunicator();
45+
MPICommunicator(MPI_Comm comm);
5746
//! @}
5847

5948
//! @{
60-
void set_comm(MPI_Comm comm)
61-
{
62-
this->comm = comm;
63-
MPI_Comm_size(this->comm, &(this->_size));
64-
MPI_Comm_rank(this->comm, &(this->_rank));
65-
66-
shared_ptr<MPIStatus> status = make_shared<MPIStatus>();
67-
this->status = status;
68-
this->status->set_comm(this);
69-
}
70-
71-
int size() { return this->_size; }
72-
int rank() { return this->_rank; }
49+
virtual void set_comm(MPI_Comm comm);
50+
virtual int size();
51+
virtual int rank();
7352
//! @}
7453
};
7554

7655

7756
class MPIStatus
7857
: public IStatus
7958
{
59+
protected:
8060
vector<bool> converged;
8161
MPICommunicator* mpi;
8262

8363
public:
84-
85-
virtual void set_comm(ICommunicator* comm)
86-
{
87-
this->comm = comm;
88-
this->converged.resize(comm->size());
89-
90-
this->mpi = dynamic_cast<MPICommunicator*>(comm); assert(this->mpi);
91-
}
92-
93-
virtual void clear() override
94-
{
95-
std::fill(converged.begin(), converged.end(), false);
96-
}
97-
98-
virtual void set_converged(bool converged) override
99-
{
100-
LOG(DEBUG) << "mpi rank " << this->comm->rank() << " set converged to " << converged;
101-
this->converged.at(this->comm->rank()) = converged;
102-
}
103-
104-
virtual bool get_converged(int rank) override
105-
{
106-
return this->converged.at(rank);
107-
}
108-
109-
virtual void post()
110-
{
111-
// noop: send/recv for status info is blocking
112-
}
113-
114-
virtual void send()
115-
{
116-
// don't send forward if: single processor run, or we're the last processor
117-
if (mpi->size() == 1) { return; }
118-
if (mpi->rank() == mpi->size() - 1) { return; }
119-
120-
int iconverged = converged.at(mpi->rank()) ? 1 : 0;
121-
122-
LOG(DEBUG) << "mpi rank " << this->comm->rank() << " status send " << iconverged;
123-
124-
int err = MPI_Send(&iconverged, sizeof(int), MPI_INT,
125-
(mpi->rank() + 1) % mpi->size(), 1, mpi->comm);
126-
127-
if (err != MPI_SUCCESS) {
128-
throw MPIError();
129-
}
130-
}
131-
132-
virtual void recv()
133-
{
134-
// don't recv if: single processor run, or we're the first processor
135-
if (mpi->size() == 1) { return; }
136-
if (mpi->rank() == 0) { return; }
137-
138-
if (get_converged(mpi->rank()-1)) {
139-
LOG(DEBUG) << "mpi rank " << this->comm->rank() << " skipping status recv";
140-
return;
141-
}
142-
143-
MPI_Status stat;
144-
int iconverged;
145-
int err = MPI_Recv(&iconverged, sizeof(iconverged), MPI_INT,
146-
(mpi->rank() - 1) % mpi->size(), 1, mpi->comm, &stat);
147-
148-
if (err != MPI_SUCCESS) {
149-
throw MPIError();
150-
}
151-
152-
converged.at(mpi->rank()-1) = iconverged == 1 ? true : false;
153-
154-
LOG(DEBUG) << "mpi rank " << this->comm->rank() << " status recv " << iconverged;
155-
}
64+
virtual void set_comm(ICommunicator* comm);
65+
virtual void clear() override;
66+
virtual void set_converged(bool converged) override;
67+
virtual bool get_converged(int rank) override;
68+
virtual void post();
69+
virtual void send();
70+
virtual void recv();
15671
};
157-
15872
} // ::pfasst::mpi
15973
} // ::pfasst
16074

75+
#include "mpi_communicator_impl.hpp"
76+
16177
#endif
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#include "mpi_communicator.hpp"
2+
3+
#include "logging.hpp"
4+
5+
6+
namespace pfasst
7+
{
8+
namespace mpi
9+
{
10+
MPIError::MPIError(const string& msg)
11+
: runtime_error(msg)
12+
{}
13+
14+
const char* MPIError::what() const throw()
15+
{
16+
return (string("mpi error: ") + string(runtime_error::what())).c_str();
17+
}
18+
19+
20+
/**
 * Create a communicator that is not yet bound to an MPI communicator.
 *
 * No member is assigned here and set_comm() is not called; call set_comm()
 * before querying size() or rank().
 */
MPICommunicator::MPICommunicator() = default;
22+
23+
/**
 * Create a communicator bound to @p comm.
 *
 * @param comm MPI communicator handed straight to set_comm()
 *
 * The call is written with an explicit class qualifier: inside a constructor
 * the call resolves to this class's set_comm() in either spelling, and the
 * qualifier makes that visible to the reader.
 */
MPICommunicator::MPICommunicator(MPI_Comm comm)
{
  MPICommunicator::set_comm(comm);
}
27+
28+
/**
 * Bind this object to @p comm, cache its size and this process's rank, and
 * install a fresh MPIStatus that points back at this communicator.
 *
 * @param comm MPI communicator to use from now on
 * @throws MPIError if MPI_Comm_size or MPI_Comm_rank does not report
 *         MPI_SUCCESS (same error handling as MPIStatus::send / recv)
 */
void MPICommunicator::set_comm(MPI_Comm comm)
{
  this->comm = comm;

  // Check the return codes instead of silently continuing with whatever
  // happens to be stored in _size / _rank.
  if (MPI_Comm_size(this->comm, &(this->_size)) != MPI_SUCCESS) {
    throw MPIError("MPI_Comm_size failed");
  }
  if (MPI_Comm_rank(this->comm, &(this->_rank)) != MPI_SUCCESS) {
    throw MPIError("MPI_Comm_rank failed");
  }

  // Local is named differently from the member it is assigned to, so the
  // member is no longer shadowed.
  shared_ptr<MPIStatus> fresh_status = make_shared<MPIStatus>();
  this->status = fresh_status;
  // The status object sizes its bookkeeping from this communicator, so it
  // must be attached only after _size has been filled in above.
  this->status->set_comm(this);
}
38+
39+
int MPICommunicator::size()
40+
{
41+
return this->_size;
42+
}
43+
44+
int MPICommunicator::rank()
45+
{
46+
return this->_rank;
47+
}
48+
49+
50+
/**
 * Attach this status object to a communicator.
 *
 * Stores @p comm, resizes the per-rank convergence flags to comm->size()
 * entries, and keeps a second pointer to the same object typed as
 * MPICommunicator.
 *
 * @param comm communicator to attach to; the assertion below requires it to
 *             be an MPICommunicator (checked only when assertions are enabled)
 */
void MPIStatus::set_comm(ICommunicator* comm)
{
  this->comm = comm;
  this->converged.resize(comm->size());

  // send()/recv() go through this typed pointer.
  MPICommunicator* const typed = dynamic_cast<MPICommunicator*>(comm);
  this->mpi = typed;
  assert(this->mpi);
}
57+
58+
void MPIStatus::clear()
59+
{
60+
std::fill(converged.begin(), converged.end(), false);
61+
}
62+
63+
void MPIStatus::set_converged(bool converged)
64+
{
65+
LOG(DEBUG) << "mpi rank " << this->comm->rank() << " set converged to " << converged;
66+
this->converged.at(this->comm->rank()) = converged;
67+
}
68+
69+
bool MPIStatus::get_converged(int rank)
70+
{
71+
return this->converged.at(rank);
72+
}
73+
74+
/**
 * Intentionally empty: status information is exchanged with the blocking
 * MPI_Send / MPI_Recv calls in send() and recv(), so there is nothing to
 * post ahead of time.
 */
void MPIStatus::post()
{
}
78+
79+
void MPIStatus::send()
80+
{
81+
// don't send forward if: single processor run, or we're the last processor
82+
if (mpi->size() == 1) { return; }
83+
if (mpi->rank() == mpi->size() - 1) { return; }
84+
85+
int iconverged = converged.at(mpi->rank()) ? 1 : 0;
86+
87+
LOG(DEBUG) << "mpi rank " << this->comm->rank() << " status send " << iconverged;
88+
89+
int err = MPI_Send(&iconverged, sizeof(int), MPI_INT,
90+
(mpi->rank() + 1) % mpi->size(), 1, mpi->comm);
91+
92+
if (err != MPI_SUCCESS) {
93+
throw MPIError();
94+
}
95+
}
96+
97+
void MPIStatus::recv()
98+
{
99+
// don't recv if: single processor run, or we're the first processor
100+
if (mpi->size() == 1) { return; }
101+
if (mpi->rank() == 0) { return; }
102+
103+
if (get_converged(mpi->rank()-1)) {
104+
LOG(DEBUG) << "mpi rank " << this->comm->rank() << " skipping status recv";
105+
return;
106+
}
107+
108+
MPI_Status stat;
109+
int iconverged;
110+
int err = MPI_Recv(&iconverged, sizeof(iconverged), MPI_INT,
111+
(mpi->rank() - 1) % mpi->size(), 1, mpi->comm, &stat);
112+
113+
if (err != MPI_SUCCESS) {
114+
throw MPIError();
115+
}
116+
117+
converged.at(mpi->rank()-1) = iconverged == 1 ? true : false;
118+
119+
LOG(DEBUG) << "mpi rank " << this->comm->rank() << " status recv " << iconverged;
120+
}
121+
} // ::pfasst::mpi
122+
} // ::pfasst

0 commit comments

Comments
 (0)