|
26 | 26 | #include <functional> |
27 | 27 | #include <map> |
28 | 28 | #include <memory> |
| 29 | +#include <set> |
29 | 30 | #include <string> |
30 | 31 | #include <tuple> |
31 | 32 | #include <type_traits> |
@@ -306,6 +307,37 @@ class Mesh { |
306 | 307 | comm_buf_map_t boundary_comm_map; |
307 | 308 | TagMap tag_map; |
308 | 309 |
|
| 310 | + // Sets the number of communication buffers that can be in-flight concurrently |
| 311 | + // for a given boundary type. This *must* be called before build boundary buffers |
| 312 | + // is called internally, so use beyond the defaults with care |
| 313 | + void SetNumberOfCommChannels(BoundaryType bound, std::size_t n_channels) { |
| 314 | + if (locked_comm_channel_numbers_.count(bound)) |
| 315 | + PARTHENON_FAIL("Trying to reset the number of comm channels after boundary buffers " |
| 316 | + "have been set up."); |
| 317 | + if (number_of_comm_channels_.count(bound) && |
| 318 | + number_of_comm_channels_[bound] > n_channels) |
| 319 | + PARTHENON_WARN( |
| 320 | + "You are reducing the number of comm channels from a previously set value."); |
| 321 | + number_of_comm_channels_[bound] = n_channels; |
| 322 | + |
| 323 | + // Need to set the complementary channels to the same value |
| 324 | + if (!IsSender(bound)) |
| 325 | + number_of_comm_channels_[GetAssociatedSender(bound)] = n_channels; |
| 326 | + if (!IsReceiver(bound)) |
| 327 | + number_of_comm_channels_[GetAssociatedReceiver(bound)] = n_channels; |
| 328 | + } |
| 329 | + |
| 330 | + void LockCommChannelNumbers(BoundaryType bound) { |
| 331 | + locked_comm_channel_numbers_.insert(GetAssociatedSender(bound)); |
| 332 | + locked_comm_channel_numbers_.insert(GetAssociatedReceiver(bound)); |
| 333 | + } |
| 334 | + |
| 335 | + std::size_t GetNumberOfCommChannels(BoundaryType bound) const { |
| 336 | + if (number_of_comm_channels_.count(bound)) return number_of_comm_channels_.at(bound); |
| 337 | + // We default to only having a single communication channel |
| 338 | + return 1; |
| 339 | + } |
| 340 | + |
309 | 341 | std::shared_ptr<CoalescedComms> pcoalesced_comms; |
310 | 342 |
|
311 | 343 | bool TryReallocCommBufferPools(); |
@@ -365,6 +397,8 @@ class Mesh { |
365 | 397 | // the last 4x should be std::size_t, but are limited to int by MPI |
366 | 398 | // Refinement tags used by MeshData checks |
367 | 399 | ParArray1D<AmrTag> amr_tags; |
| 400 | + std::map<BoundaryType, std::size_t> number_of_comm_channels_; |
| 401 | + std::set<BoundaryType> locked_comm_channel_numbers_; |
368 | 402 |
|
369 | 403 | std::vector<LogicalLocation> loclist; |
370 | 404 |
|
|
0 commit comments