Skip to content

Commit b5b0422

Browse files
committed
Add support for directly communicating enum types.
1 parent 1404388 commit b5b0422

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

include/plssvm/mpi/detail/mpi_datatype.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
#include "mpi.h" // MPI_Datatype, various MPI datatypes
1919

20-
#include <complex> // std::complex
20+
#include <complex> // std::complex
21+
#include <type_traits> // std::enable_if_t, std::is_enum_v, std::underlying_type_t
2122

2223
/**
2324
* @def PLSSVM_CREATE_MPI_DATATYPE_MAPPING
@@ -33,11 +34,11 @@ namespace plssvm::mpi::detail {
3334

3435
/**
3536
* @brief Tries to convert the given C++ type to its corresponding MPI_Datatype.
36-
* @details The definition is marked as **deleted** if `T` isn't representable as [`MPI_Datatype`](https://www.mpi-forum.org/docs/mpi-2.2/mpi22-report/node44.htm).
37+
* @details The definition is marked as **deleted** if `T` isn't representable as [`MPI_Datatype`](https://www.mpi-forum.org/docs/mpi-2.2/mpi22-report/node44.htm) or an enum.
3738
* @tparam T the type to convert to a MPI_Datatype
3839
* @return the corresponding MPI_Datatype (`[[nodiscard]]`)
3940
*/
40-
template <typename T>
41+
template <typename T, std::enable_if_t<!std::is_enum_v<T>, bool> = true>
4142
[[nodiscard]] inline MPI_Datatype mpi_datatype() = delete;
4243

4344
PLSSVM_CREATE_MPI_DATATYPE_MAPPING(bool, MPI_C_BOOL)
@@ -76,6 +77,16 @@ PLSSVM_CREATE_MPI_DATATYPE_MAPPING(std::complex<float>, MPI_C_COMPLEX)
7677
PLSSVM_CREATE_MPI_DATATYPE_MAPPING(std::complex<double>, MPI_C_DOUBLE_COMPLEX)
7778
PLSSVM_CREATE_MPI_DATATYPE_MAPPING(std::complex<long double>, MPI_C_LONG_DOUBLE_COMPLEX)
7879

80+
/**
81+
* @brief Specialization for enums: for enums, use their underlying type in MPI communications.
82+
* @tparam T the enum type to convert to a MPI_Datatype
83+
* @return the corresponding MPI_Datatype (`[[nodiscard]]`)
84+
*/
85+
template <typename T, std::enable_if_t<std::is_enum_v<T>, bool> = true>
86+
[[nodiscard]] inline MPI_Datatype mpi_datatype() {
87+
return mpi_datatype<std::underlying_type_t<T>>();
88+
}
89+
7990
} // namespace plssvm::mpi::detail
8091

8192
#undef PLSSVM_CREATE_MPI_DATATYPE_MAPPING

0 commit comments

Comments
 (0)