Skip to content

Commit 97915bd

Browse files
committed
forget about heterogeneous inter-volume trace pairs for now
1 parent e1b2da7 commit 97915bd

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

grudge/trace_pair.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -692,47 +692,47 @@ def __init__(self,
692692
self.local_part_id = local_part_id
693693
self.remote_part_id = remote_part_id
694694

695-
from pytato import make_distributed_recv, staple_distributed_send
695+
from pytato import (
696+
make_distributed_recv,
697+
make_distributed_send,
698+
DistributedSendRefHolder)
699+
700+
# TODO: This currently assumes that local_bdry_data and
701+
# remote_bdry_data_template have the same structure. This is not true
702+
# in general. Find a way to staple the sends appropriately when the number
703+
# of recvs is not equal to the number of sends
704+
assert type(local_bdry_data) == type(remote_bdry_data_template)
705+
706+
sends = {}
696707

697-
# Staple the sends to a bunch of dummy arrays of zeros
698708
def send_single_array(key, local_subary):
699709
if isinstance(local_subary, Number):
700-
return 0
710+
return
701711
else:
702712
ary_tag = (comm_tag, key)
703-
return staple_distributed_send(
704-
local_subary, dest_rank=remote_rank, comm_tag=ary_tag,
705-
stapled_to=actx.zeros_like(local_subary))
713+
sends[key] = make_distributed_send(
714+
local_subary, dest_rank=remote_rank, comm_tag=ary_tag)
706715

707716
def recv_single_array(key, remote_subary_template):
708717
if isinstance(remote_subary_template, Number):
709718
# NOTE: Assumes that the same number is passed on every rank
710-
return remote_subary_template
719+
return Number
711720
else:
712721
ary_tag = (comm_tag, key)
713-
return make_distributed_recv(
714-
src_rank=remote_rank, comm_tag=ary_tag,
715-
shape=remote_subary_template.shape,
716-
dtype=remote_subary_template.dtype)
722+
return DistributedSendRefHolder(
723+
sends[key],
724+
make_distributed_recv(
725+
src_rank=remote_rank, comm_tag=ary_tag,
726+
shape=remote_subary_template.shape,
727+
dtype=remote_subary_template.dtype))
717728

718729
from arraycontext.container.traversal import rec_keyed_map_array_container
719-
zeros_like_local_bdry_data = rec_keyed_map_array_container(
720-
send_single_array, local_bdry_data)
721-
unswapped_remote_bdry_data = rec_keyed_map_array_container(
722-
recv_single_array, remote_bdry_data_template)
723730

724-
# Sum up the dummy zeros
725-
zero = actx.np.sum(zeros_like_local_bdry_data)
731+
rec_keyed_map_array_container(send_single_array, local_bdry_data)
732+
self.local_bdry_data = local_bdry_data
726733

727-
# Add the dummy zeros and hope that the caller proceeds to actually
728-
# use some of this data on every rank...
729-
from arraycontext import rec_map_array_container
730-
self.local_bdry_data = rec_map_array_container(
731-
lambda x: x + zero,
732-
local_bdry_data)
733-
self.unswapped_remote_bdry_data = rec_map_array_container(
734-
lambda x: x + zero,
735-
unswapped_remote_bdry_data)
734+
self.unswapped_remote_bdry_data = rec_keyed_map_array_container(
735+
recv_single_array, remote_bdry_data_template)
736736

737737
def finish(self):
738738
remote_to_local = self.dcoll._inter_partition_connections[

0 commit comments

Comments
 (0)