Skip to content

Commit 1a38e2d

Browse files
committed
tag communication by destination volume
1 parent d5576fb commit 1a38e2d

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

grudge/trace_pair.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,17 @@ def __init__(self,
570570
self.local_bdry_data = local_bdry_data
571571
self.remote_bdry_data_template = remote_bdry_data_template
572572

573-
self.comm_tag = self.base_comm_tag
574-
comm_tag = _sym_tag_to_num_tag(comm_tag)
575-
if comm_tag is not None:
576-
self.comm_tag += comm_tag
573+
def _generate_num_comm_tag(sym_comm_tag):
574+
result = self.base_comm_tag
575+
num_comm_tag = _sym_tag_to_num_tag(sym_comm_tag)
576+
if num_comm_tag is not None:
577+
result += num_comm_tag
578+
return result
579+
580+
send_sym_comm_tag = (remote_part_id.volume_tag, comm_tag)
581+
recv_sym_comm_tag = (local_part_id.volume_tag, comm_tag)
582+
self.send_comm_tag = _generate_num_comm_tag(send_sym_comm_tag)
583+
self.recv_comm_tag = _generate_num_comm_tag(recv_sym_comm_tag)
577584
del comm_tag
578585

579586
# NOTE: mpi4py currently (2021-11-03) holds a reference to the send
@@ -588,7 +595,7 @@ def send_single_array(key, local_subary):
588595
if not isinstance(local_subary, Number):
589596
local_subary_np = to_numpy(local_subary, actx)
590597
self.send_reqs.append(
591-
comm.Isend(local_subary_np, remote_rank, tag=self.comm_tag))
598+
comm.Isend(local_subary_np, remote_rank, tag=self.send_comm_tag))
592599
self.send_data.append(local_subary_np)
593600
return local_subary
594601

@@ -601,7 +608,8 @@ def recv_single_array(key, remote_subary_template):
601608
remote_subary_template.shape,
602609
remote_subary_template.dtype)
603610
self.recv_reqs.append(
604-
comm.Irecv(remote_subary_np, remote_rank, tag=self.comm_tag))
611+
comm.Irecv(remote_subary_np, remote_rank,
612+
tag=self.recv_comm_tag))
605613
self.recv_data[key] = remote_subary_np
606614
return remote_subary_template
607615

@@ -702,7 +710,7 @@ def send_single_array(key, local_subary):
702710
if isinstance(local_subary, Number):
703711
return
704712
else:
705-
ary_tag = (comm_tag, key)
713+
ary_tag = (remote_part_id.volume_tag, comm_tag, key)
706714
sends[key] = make_distributed_send(
707715
local_subary, dest_rank=remote_rank, comm_tag=ary_tag)
708716

@@ -711,7 +719,7 @@ def recv_single_array(key, remote_subary_template):
711719
# NOTE: Assumes that the same number is passed on every rank
712720
return remote_subary_template
713721
else:
714-
ary_tag = (comm_tag, key)
722+
ary_tag = (local_part_id.volume_tag, comm_tag, key)
715723
return DistributedSendRefHolder(
716724
sends[key],
717725
make_distributed_recv(

0 commit comments

Comments
 (0)