@@ -692,47 +692,47 @@ def __init__(self,
692692 self .local_part_id = local_part_id
693693 self .remote_part_id = remote_part_id
694694
695- from pytato import make_distributed_recv , staple_distributed_send
695+ from pytato import (
696+ make_distributed_recv ,
697+ make_distributed_send ,
698+ DistributedSendRefHolder )
699+
700+ # TODO: This currently assumes that local_bdry_data and
701+ # remote_bdry_data_template have the same structure. This is not true
702+ # in general. Find a way to staple the sends appropriately when the number
703+ # of recvs is not equal to the number of sends
704+ assert type (local_bdry_data ) == type (remote_bdry_data_template )
705+
706+ sends = {}
696707
697- # Staple the sends to a bunch of dummy arrays of zeros
698708 def send_single_array (key , local_subary ):
699709 if isinstance (local_subary , Number ):
700- return 0
710+ return
701711 else :
702712 ary_tag = (comm_tag , key )
703- return staple_distributed_send (
704- local_subary , dest_rank = remote_rank , comm_tag = ary_tag ,
705- stapled_to = actx .zeros_like (local_subary ))
713+ sends [key ] = make_distributed_send (
714+ local_subary , dest_rank = remote_rank , comm_tag = ary_tag )
706715
707716 def recv_single_array (key , remote_subary_template ):
708717 if isinstance (remote_subary_template , Number ):
709718 # NOTE: Assumes that the same number is passed on every rank
710- return remote_subary_template
719+ return Number
711720 else :
712721 ary_tag = (comm_tag , key )
713- return make_distributed_recv (
714- src_rank = remote_rank , comm_tag = ary_tag ,
715- shape = remote_subary_template .shape ,
716- dtype = remote_subary_template .dtype )
722+ return DistributedSendRefHolder (
723+ sends [key ],
724+ make_distributed_recv (
725+ src_rank = remote_rank , comm_tag = ary_tag ,
726+ shape = remote_subary_template .shape ,
727+ dtype = remote_subary_template .dtype ))
717728
718729 from arraycontext .container .traversal import rec_keyed_map_array_container
719- zeros_like_local_bdry_data = rec_keyed_map_array_container (
720- send_single_array , local_bdry_data )
721- unswapped_remote_bdry_data = rec_keyed_map_array_container (
722- recv_single_array , remote_bdry_data_template )
723730
724- # Sum up the dummy zeros
725- zero = actx . np . sum ( zeros_like_local_bdry_data )
731+ rec_keyed_map_array_container ( send_single_array , local_bdry_data )
732+ self . local_bdry_data = local_bdry_data
726733
727- # Add the dummy zeros and hope that the caller proceeds to actually
728- # use some of this data on every rank...
729- from arraycontext import rec_map_array_container
730- self .local_bdry_data = rec_map_array_container (
731- lambda x : x + zero ,
732- local_bdry_data )
733- self .unswapped_remote_bdry_data = rec_map_array_container (
734- lambda x : x + zero ,
735- unswapped_remote_bdry_data )
734+ self .unswapped_remote_bdry_data = rec_keyed_map_array_container (
735+ recv_single_array , remote_bdry_data_template )
736736
737737 def finish (self ):
738738 remote_to_local = self .dcoll ._inter_partition_connections [
0 commit comments