Skip to content

Commit 864dcd7

Browse files
committed
Work on yurlungur comments
1 parent 2bf2188 commit 864dcd7

File tree

13 files changed

+100
-32
lines changed

13 files changed

+100
-32
lines changed

src/bvals/comms/bnd_id.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace parthenon {
4343
BndId BndId::GetSend(MeshBlock *pmb, const NeighborBlock &nb,
4444
std::shared_ptr<Variable<Real>> v, BoundaryType b_type,
4545
int partition, int start_idx) {
46-
// TODO: This needs to be fixed for unique buffer ids
46+
// TODO(LFR): This needs to be fixed for unique buffer ids
4747
auto [send_gid, recv_gid, vlabel, loc, extra_id] = SendKey(pmb, nb, v, b_type, 0);
4848
BndId out;
4949
out.send_gid() = send_gid;

src/bvals/comms/build_boundary_buffers.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ void BuildBoundaryBufferSubset(std::shared_ptr<MeshData<Real>> &md,
117117

118118
// Build send buffer (unless this is a receiving flux boundary)
119119
if constexpr (IsSender(BTYPE)) {
120-
for (int id = 0; id <= (BTYPE == BoundaryType::gmg_restrict_send); ++id) {
120+
pmesh->LockCommChannelNumbers(BTYPE);
121+
for (int id = 0; id < pmesh->GetNumberOfCommChannels(BTYPE); ++id) {
121122
auto s_key = SendKey(pmb, nb, v, BTYPE, id);
122123
const int tag = pmesh->tag_map.GetTag(pmb, nb, id);
123124
if (buf_map.count(s_key) == 0)
@@ -130,7 +131,8 @@ void BuildBoundaryBufferSubset(std::shared_ptr<MeshData<Real>> &md,
130131
// Also build the non-local receive buffers here
131132
if constexpr (IsReceiver(BTYPE)) {
132133
if (sender_rank != receiver_rank) {
133-
for (int id = 0; id <= (BTYPE == BoundaryType::gmg_restrict_recv); ++id) {
134+
pmesh->LockCommChannelNumbers(BTYPE);
135+
for (int id = 0; id < pmesh->GetNumberOfCommChannels(BTYPE); ++id) {
134136
auto r_key = ReceiveKey(pmb, nb, v, BTYPE, id);
135137
const int tag = pmesh->tag_map.GetTag(pmb, nb, id);
136138
if (buf_map.count(r_key) == 0)

src/bvals/comms/bvals_utils.hpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,13 @@ inline Mesh::channel_key_t SendKey(const MeshBlock *pmb, const NeighborBlock &nb
5454
const int sender_id = pmb->gid;
5555
const int receiver_id = nb.gid;
5656
const int location_idx = nb.offsets.GetIdx();
57-
int other = (nb.gid == pmb->gid && (btype == BoundaryType::gmg_restrict_recv ||
58-
btype == BoundaryType::gmg_restrict_send))
59-
? 1
60-
: 0;
61-
other += 2 * id;
57+
58+
int gmg_self_comm = (nb.gid == pmb->gid && (btype == BoundaryType::gmg_restrict_recv ||
59+
btype == BoundaryType::gmg_restrict_send))
60+
? 1
61+
: 0;
62+
int other =
63+
gmg_self_comm + 2 * id; // Combine the id and gmg_self_comm into a single tag
6264
return {sender_id, receiver_id, pcv->label(), location_idx, other};
6365
}
6466

@@ -68,11 +70,12 @@ inline Mesh::channel_key_t ReceiveKey(const MeshBlock *pmb, const NeighborBlock
6870
const int receiver_id = pmb->gid;
6971
const int sender_id = nb.gid;
7072
const int location_idx = nb.lcoord_trans.Transform(nb.offsets).GetReverseIdx();
71-
int other = (nb.gid == pmb->gid && (btype == BoundaryType::gmg_restrict_recv ||
72-
btype == BoundaryType::gmg_restrict_send))
73-
? 1
74-
: 0;
75-
other += 2 * id;
73+
int gmg_self_comm = (nb.gid == pmb->gid && (btype == BoundaryType::gmg_restrict_recv ||
74+
btype == BoundaryType::gmg_restrict_send))
75+
? 1
76+
: 0;
77+
int other =
78+
gmg_self_comm + 2 * id; // Combine the id and gmg_self_comm into a single tag
7679
return {sender_id, receiver_id, pcv->label(), location_idx, other};
7780
}
7881

src/bvals/comms/tag_map.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
//========================================================================================
1717

1818
#include <memory>
19+
#include <utility>
1920

2021
#include "bnd_info.hpp"
2122
#include "bvals_utils.hpp"

src/interface/mesh_data.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ void MeshData<T>::SetMeshProperties(Mesh *pmesh) {
4444
ndim_ = pmesh == nullptr ? 0 : pmesh->ndim;
4545
}
4646

47+
template <typename T>
48+
void MeshData<T>::SetBoundBufferId(BoundaryType btype, int id) {
49+
PARTHENON_REQUIRE(id < pmy_mesh_->GetNumberOfCommChannels(btype),
50+
"Trying to set MeshData to communicate on a non-existent channel.");
51+
// We do not enforce symmetry here between associated senders and
52+
// receivers for maximum flexibility.
53+
bound_buffer_ids_[btype] = id;
54+
}
55+
4756
template class MeshData<Real>;
4857

4958
} // namespace parthenon

src/interface/mesh_data.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ class MeshData {
199199
GridIdentifier grid;
200200
int partition{-1};
201201

202-
void SetBoundBufferId(BoundaryType btype, int id) { bound_buffer_ids_[btype] = id; }
202+
void SetBoundBufferId(BoundaryType btype, int id);
203203

204204
int GetBoundBufferId(BoundaryType btype) const {
205205
if (bound_buffer_ids_.count(btype)) return bound_buffer_ids_.at(btype);

src/mesh/mesh-gmg.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <iostream>
2323
#include <limits>
2424
#include <map>
25+
#include <memory>
2526
#include <numeric>
2627
#include <sstream>
2728
#include <string>

src/mesh/mesh.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ Mesh::Mesh(ParameterInput *pin, ApplicationInput *app_in, Packages_t &packages,
188188
max_level = max_level_ref_;
189189
}
190190

191+
if (multigrid) SetNumberOfCommChannels(BoundaryType::gmg_restrict_send, 2);
192+
191193
SetupMPIComms();
192194

193195
RegisterLoadBalancing_(pin);

src/mesh/mesh.hpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <functional>
2727
#include <map>
2828
#include <memory>
29+
#include <set>
2930
#include <string>
3031
#include <tuple>
3132
#include <type_traits>
@@ -306,6 +307,37 @@ class Mesh {
306307
comm_buf_map_t boundary_comm_map;
307308
TagMap tag_map;
308309

310+
// Sets the number of communication buffers that can be in-flight concurrently
311+
// for a given boundary type. This *must* be called before build boundary buffers
312+
// is called internally, so use beyond the defaults with care
313+
void SetNumberOfCommChannels(BoundaryType bound, std::size_t n_channels) {
314+
if (locked_comm_channel_numbers_.count(bound))
315+
PARTHENON_FAIL("Trying to reset the number of comm channels after boundary buffers "
316+
"have been set up.");
317+
if (number_of_comm_channels_.count(bound) &&
318+
number_of_comm_channels_[bound] > n_channels)
319+
PARTHENON_WARN(
320+
"You are reducing the number of comm channels from a previously set value.");
321+
number_of_comm_channels_[bound] = n_channels;
322+
323+
// Need to set the complementary channels to the same value
324+
if (!IsSender(bound))
325+
number_of_comm_channels_[GetAssociatedSender(bound)] = n_channels;
326+
if (!IsReceiver(bound))
327+
number_of_comm_channels_[GetAssociatedReceiver(bound)] = n_channels;
328+
}
329+
330+
void LockCommChannelNumbers(BoundaryType bound) {
331+
locked_comm_channel_numbers_.insert(GetAssociatedSender(bound));
332+
locked_comm_channel_numbers_.insert(GetAssociatedReceiver(bound));
333+
}
334+
335+
std::size_t GetNumberOfCommChannels(BoundaryType bound) const {
336+
if (number_of_comm_channels_.count(bound)) return number_of_comm_channels_.at(bound);
337+
// We default to only having a single communication channel
338+
return 1;
339+
}
340+
309341
std::shared_ptr<CoalescedComms> pcoalesced_comms;
310342

311343
bool TryReallocCommBufferPools();
@@ -365,6 +397,8 @@ class Mesh {
365397
// the last 4x should be std::size_t, but are limited to int by MPI
366398
// Refinement tags used by MeshData checks
367399
ParArray1D<AmrTag> amr_tags;
400+
std::map<BoundaryType, std::size_t> number_of_comm_channels_;
401+
std::set<BoundaryType> locked_comm_channel_numbers_;
368402

369403
std::vector<LogicalLocation> loclist;
370404

src/mesh/meshblock.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ class MeshBlock : public std::enable_shared_from_this<MeshBlock> {
453453
std::forward<Args>(args)...);
454454
}
455455

456+
// Checks if the LogicalLocation of this block is a leaf logical location
456457
bool IsLeafLL() const { return is_leaf_ll_; }
457458

458459
private:

0 commit comments

Comments
 (0)