Skip to content

Commit b21d135

Browse files
committed
wip: opti mpi
1 parent 3c1cc87 commit b21d135

File tree

9 files changed

+159
-73
lines changed

9 files changed

+159
-73
lines changed

apps/core/includes/sync.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace Simulation
88
class SimulationUnit;
99
}
1010

11+
#include <mpi_w/wrap_mpi.hpp>
12+
1113
/**
1214
* @brief Synchronization after particle processing.
1315
*
@@ -34,7 +36,8 @@ void sync_step(const ExecInfo& exec, Simulation::SimulationUnit& simulation);
3436
* @param simulation The `Simulation::SimulationUnit` object representing the
3537
* simulation being synchronized.
3638
*/
37-
void sync_prepare_next(Simulation::SimulationUnit& simulation);
39+
void sync_prepare_next(Simulation::SimulationUnit& simulation,
40+
MPI_Request* request);
3841

3942
/**
4043
* @brief Final synchronization before exporting results.

apps/core/src/host_specific.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
# define SEND_MPI_SIG_STOP
6666
# define INIT_PAYLOAD
6767
# define SEND_MPI_SIG_RUN
68+
6869
#endif
6970

7071
#ifdef DEBUG
@@ -190,6 +191,8 @@ namespace
190191
exporter_handler.pre_post_export(current_time, simulation, d_transionner);
191192
}
192193

194+
MPI_Request req{};
195+
193196
auto loop_functor = [&](auto&& local_container)
194197
{
195198
Core::SignalHandler sig;
@@ -232,6 +235,8 @@ namespace
232235
}
233236
}
234237

238+
mpi_payload.wait();
239+
235240
sync_step(exec, simulation);
236241
{
237242
PROFILE_SECTION("host:sync_update")
@@ -240,8 +245,9 @@ namespace
240245
// From here, contributions can be overwritten
241246
current_time += d_t;
242247
}
243-
sync_prepare_next(simulation);
244-
248+
#ifndef NO_MPI
249+
sync_prepare_next(simulation, &req);
250+
#endif
245251
simulation.cycleProcess(local_container, d_t, functors);
246252

247253
if (Core::SignalHandler::is_usr1_raised()) [[unlikely]]
@@ -257,6 +263,8 @@ namespace
257263
}
258264
break;
259265
}
266+
WrapMPI::Async::wait(req);
267+
260268
} // end for
261269

262270
local_container.force_remove_dead();

apps/core/src/sync.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ void sync_step(const ExecInfo& exec, Simulation::SimulationUnit& simulation)
9191
}
9292
}
9393

94-
void sync_prepare_next(Simulation::SimulationUnit& simulation)
94+
void sync_prepare_next(Simulation::SimulationUnit& simulation,
95+
MPI_Request* request)
9596
{
9697
PROFILE_SECTION("sync_prepare_next")
9798
simulation.clearContribution();
@@ -104,8 +105,10 @@ void sync_prepare_next(Simulation::SimulationUnit& simulation)
104105
simulation.getCliqData(); // Get concentration ptr wrapped into span
105106

106107
WrapMPI::barrier();
107-
// We can use span here because we broadcast without changing size
108-
WrapMPI::broadcast_span(data, 0);
108+
// We can use span here because we broadcast without changing size
109+
#ifndef NO_MPI
110+
WrapMPI::Async::broadcast_span(data, 0, *request);
111+
#endif
109112
}
110113
}
111114

apps/core/src/worker_specific.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ void workers_process([[maybe_unused]] std::shared_ptr<IO::Logger> logger,
1818
double d_t = params.d_t;
1919
size_t n_compartments = simulation.mc_unit->domain.getNumberCompartments();
2020
MPI_Status status;
21-
21+
MPI_Request req;
2222
const bool do_export = true; // TODO
2323

2424
WrapMPI::IterationPayload payload(n_compartments);
@@ -75,19 +75,21 @@ void workers_process([[maybe_unused]] std::shared_ptr<IO::Logger> logger,
7575
const auto cycle_callback =
7676
[&](double& current_time, auto& container, auto& functors)
7777
{
78+
sync_step(exec, simulation);
79+
sync_prepare_next(simulation, &req);
7880
simulation.update_feed(current_time, d_t, false);
81+
WrapMPI::Async::wait(req);
7982
simulation.cycleProcess(container, d_t, functors);
8083

8184
current_time += d_t;
82-
sync_step(exec, simulation);
83-
sync_prepare_next(simulation);
8485
};
8586

8687
const auto loop_functor = [&](auto&& container)
8788
{
8889
auto functors = simulation.init_functors<ComputeSpace>(container);
8990
// bool stop = false;
9091
WrapMPI::SIGNALS signal{};
92+
9193
double current_time = 0;
9294
while (true)
9395
{

apps/libs/mpi_w/public/mpi_w/impl_async.hpp

Lines changed: 74 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,42 @@
99
#include <mpi_w/message_t.hpp>
1010
#include <mpi_w/mpi_types.hpp>
1111
#include <span>
12-
// #include <optional>
13-
// #include <limits>
14-
// #include <vector>
15-
// #include <stdexcept>
1612

1713
namespace WrapMPI::Async
1814
{
15+
namespace
16+
{
17+
template <POD_t DataType>
18+
int _send_unsafe(MPI_Request& request,
19+
DataType* buf,
20+
size_t buf_size,
21+
size_t dest,
22+
size_t tag) noexcept
23+
{
24+
return MPI_Isend(buf,
25+
buf_size,
26+
get_type<DataType>(),
27+
dest,
28+
tag,
29+
MPI_COMM_WORLD,
30+
&request);
31+
}
32+
33+
} // namespace
1934

2035
inline MPI_Status wait(MPI_Request& request)
2136
{
37+
PROFILE_SECTION("WrapMPI::wait");
2238
MPI_Status status;
2339
MPI_Wait(&request, &status);
2440
return status;
2541
}
2642

27-
template <POD_t DataType>
28-
static int _send_unsafe(MPI_Request& request,
29-
DataType* buf,
30-
size_t buf_size,
31-
size_t dest,
32-
size_t tag) noexcept
43+
inline void wait(MPI_Request& request, MPI_Status* status)
3344
{
34-
return MPI_Isend(buf,
35-
buf_size,
36-
get_type<DataType>(),
37-
dest,
38-
tag,
39-
MPI_COMM_WORLD,
40-
&request);
45+
MPI_Wait(&request, status);
4146
}
47+
4248
template <POD_t DataType>
4349
int send(MPI_Request& request, DataType data, size_t dest, size_t tag)
4450
{
@@ -82,6 +88,57 @@ namespace WrapMPI::Async
8288
&request);
8389
}
8490

91+
template <POD_t DataType>
92+
std::optional<DataType>
93+
recv(size_t src, MPI_Request& request, size_t tag) noexcept
94+
{
95+
DataType buf;
96+
97+
int recv_status = MPI_Irecv(
98+
&buf, sizeof(DataType), MPI_BYTE, src, tag, MPI_COMM_WORLD, &request);
99+
if (recv_status != MPI_SUCCESS)
100+
{
101+
return std::nullopt;
102+
}
103+
return buf;
104+
}
105+
106+
template <POD_t T>
107+
int _broadcast_unsafe(T* data,
108+
size_t _size,
109+
size_t root,
110+
MPI_Request& request)
111+
{
112+
if (data == nullptr)
113+
{
114+
throw std::invalid_argument("Data pointer is null");
115+
}
116+
if (_size == 0 || _size > std::numeric_limits<size_t>::max())
117+
{
118+
throw std::invalid_argument("Error size");
119+
}
120+
121+
int comm_size = 0;
122+
MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
123+
if (root >= static_cast<size_t>(comm_size))
124+
{
125+
throw std::invalid_argument("Root process rank is out of range");
126+
}
127+
128+
return MPI_Ibcast(data,
129+
_size,
130+
get_type<T>(),
131+
static_cast<int>(root),
132+
MPI_COMM_WORLD,
133+
&request);
134+
}
135+
136+
template <POD_t T>
137+
int broadcast_span(std::span<T> data, size_t root, MPI_Request& request)
138+
{
139+
return _broadcast_unsafe(data.data(), data.size(), root, request);
140+
}
141+
85142
} // namespace WrapMPI::Async
86143

87144
#endif

apps/libs/mpi_w/public/mpi_w/impl_op.hpp

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#ifndef __IMPL_MPI_OP_HPP__
22
#define __IMPL_MPI_OP_HPP__
33

4-
#include <common/traits.hpp>
5-
#include <mpi_w/message_t.hpp>
6-
#include <mpi_w/mpi_types.hpp>
7-
84
#include <common/execinfo.hpp>
5+
#include <common/traits.hpp>
96
#include <cstddef>
107
#include <limits>
118
#include <math.h>
129
#include <mpi.h>
10+
#include <mpi_w/impl_async.hpp>
11+
#include <mpi_w/message_t.hpp>
12+
#include <mpi_w/mpi_types.hpp>
1313
#include <optional>
1414
#include <span>
1515
#include <stdexcept>
@@ -23,27 +23,27 @@ namespace WrapMPI
2323

2424
// SENDING
2525

26-
/**
27-
* @brief Sends raw data to a destination in an unsafe manner.
28-
*
29-
* This function sends raw data of type `DataType` to the specified
30-
* destination. It assumes the provided buffer and its size are valid
31-
*
32-
* @tparam DataType A type satisfying the `POD` concept.
33-
* @param buf Pointer to the buffer containing data to send.
34-
* @param buf_size The size of the buffer in bytes.
35-
* @param dest The destination identifier for the data.
36-
* @param tag Optional tag to identify the message (default is 0).
37-
* @return An integer indicating success or failure of the operation.
38-
*
39-
* @note Use this function with caution as it performs no validation on the
40-
* input.
41-
*/
42-
template <POD_t DataType>
43-
[[nodiscard]] static int _send_unsafe(DataType* buf,
44-
size_t buf_size,
45-
size_t dest,
46-
size_t tag = 0) noexcept;
26+
// /**
27+
// * @brief Sends raw data to a destination in an unsafe manner.
28+
// *
29+
// * This function sends raw data of type `DataType` to the specified
30+
// * destination. It assumes the provided buffer and its size are valid
31+
// *
32+
// * @tparam DataType A type satisfying the `POD` concept.
33+
// * @param buf Pointer to the buffer containing data to send.
34+
// * @param buf_size The size of the buffer in bytes.
35+
// * @param dest The destination identifier for the data.
36+
// * @param tag Optional tag to identify the message (default is 0).
37+
// * @return An integer indicating success or failure of the operation.
38+
// *
39+
// * @note Use this function with caution as it performs no validation on the
40+
// * input.
41+
// */
42+
// template <POD_t DataType>
43+
// [[nodiscard]] static int _send_unsafe(DataType* buf,
44+
// size_t buf_size,
45+
// size_t dest,
46+
// size_t tag = 0) noexcept;
4747

4848
/**
4949
* @brief Sends a single instance of data to a destination.
@@ -335,13 +335,6 @@ namespace WrapMPI
335335
//**
336336
// IMPL
337337
//**
338-
template <POD_t DataType>
339-
static int
340-
_send_unsafe(DataType* buf, size_t buf_size, size_t dest, size_t tag) noexcept
341-
{
342-
return MPI_Send(
343-
buf, buf_size, get_type<DataType>(), dest, tag, MPI_COMM_WORLD);
344-
}
345338

346339
template <POD_t DataType>
347340
DataType try_recv(size_t src, MPI_Status* status, size_t tag)
@@ -621,7 +614,10 @@ namespace WrapMPI
621614

622615
template <POD_t DataType> int send(DataType data, size_t dest, size_t tag)
623616
{
624-
return _send_unsafe<DataType>(&data, 1, dest, tag);
617+
MPI_Request req;
618+
auto res = WrapMPI::Async::_send_unsafe<DataType>(req, &data, 1, dest, tag);
619+
WrapMPI::Async::wait(req);
620+
return res;
625621
}
626622

627623
template <POD_t DataType>
@@ -639,7 +635,12 @@ namespace WrapMPI
639635

640636
if (send_status == MPI_SUCCESS)
641637
{
642-
send_status = _send_unsafe(data.data(), data.size(), dest, tag);
638+
MPI_Request req{};
639+
auto res = WrapMPI::Async::_send_unsafe<DataType>(
640+
req, data.data(), data.size(), dest, tag);
641+
WrapMPI::Async::wait(req);
642+
643+
send_status = res;
643644
}
644645

645646
return send_status;

apps/libs/mpi_w/public/mpi_w/iteration_payload.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,16 @@ namespace WrapMPI
6464

6565
[[nodiscard]] bool sendAll(std::size_t n_rank) noexcept;
6666

67+
void wait() noexcept;
68+
6769
private:
6870
std::span<const double> liquid_volumes;
6971
std::span<const std::size_t> liquid_neighbors_flat;
7072
std::span<const double> proba_leaving_flat;
7173
std::span<const double> liquid_out_flows;
7274

75+
bool to_wait;
76+
7377
/**
7478
* @brief Sends this payload to a specified MPI rank.
7579
*

0 commit comments

Comments
 (0)