Skip to content

Commit 1eaccb8

Browse files
committed
Add new gather function for std::chrono::milliseconds.
1 parent 4d10527 commit 1eaccb8

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

include/plssvm/mpi/communicator.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mpi.h" // MPI_Comm, MPI_COMM_WORLD, MPI_Gather
2222
#endif
2323

24+
#include <chrono> // std::chrono::milliseconds
2425
#include <cstddef> // std::size_t
2526
#include <functional> // std::invoke
2627
#include <string> // std::string
@@ -126,6 +127,14 @@ class communicator {
126127
*/
127128
[[nodiscard]] std::vector<std::string> gather(const std::string &str) const;
128129

130+
/**
131+
* @brief Gather the `std::chrono::milliseconds` @p duration from each MPI rank on the `communicator::main_rank()`.
132+
* @details If `PLSSVM_HAS_MPI_ENABLED` is undefined, returns the provided @p duration wrapped in a `std::vector`.
133+
* @param[in] duration the duration to gather at the main MPI rank
134+
* @return a `std::vector` containing all gathered durations (`[[nodiscard]]`)
135+
*/
136+
[[nodiscard]] std::vector<std::chrono::milliseconds> gather(const std::chrono::milliseconds &duration) const;
137+
129138
#if defined(PLSSVM_HAS_MPI_ENABLED)
130139
/**
131140
* @brief Add implicit conversion operator back to a native MPI communicator.

src/plssvm/mpi/communicator.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
#include "mpi.h"
1515
#endif
1616

17-
#include <cstddef> // std::size_t
18-
#include <string> // std::string
19-
#include <vector> // std::vector
17+
#include <algorithm> // std::transform
18+
#include <chrono> // std::chrono::milliseconds
19+
#include <cstddef> // std::size_t
20+
#include <cstdint> // std::int64_t
21+
#include <string> // std::string
22+
#include <vector> // std::vector
2023

2124
namespace plssvm::mpi {
2225

@@ -95,4 +98,20 @@ std::vector<std::string> communicator::gather(const std::string &str) const {
9598
#endif
9699
}
97100

101+
std::vector<std::chrono::milliseconds> communicator::gather(const std::chrono::milliseconds &duration) const {
102+
#if defined(PLSSVM_HAS_MPI_ENABLED)
103+
// convert the duration to an integer
104+
const std::int64_t intermediate_dur = duration.count();
105+
std::vector<std::int64_t> intermediate_result(this->size());
106+
// gather the integer values from each MPI rank
107+
PLSSVM_MPI_ERROR_CHECK(MPI_Gather(&intermediate_dur, 1, detail::mpi_datatype<std::int64_t>(), intermediate_result.data(), 1, detail::mpi_datatype<std::int64_t>(), communicator::main_rank(), comm_));
108+
// cast integers back to durations
109+
std::vector<std::chrono::milliseconds> result(this->size());
110+
std::transform(intermediate_result.cbegin(), intermediate_result.cend(), result.begin(), [](const std::int64_t dur) { return static_cast<std::chrono::milliseconds>(dur); });
111+
return result;
112+
#else
113+
return { duration };
114+
#endif
115+
}
116+
98117
} // namespace plssvm::mpi

0 commit comments

Comments
 (0)