4141import numpy as np
4242from dask .array import Array
4343from dask .distributed import comm , Queue , Variable
44- from distributed import Client , Future
44+ from distributed import Client , Future , get_client
4545
4646
4747def get_connection_info (dask_scheduler_address : str | Client ) -> Client :
@@ -66,8 +66,116 @@ def get_connection_info(dask_scheduler_address: str | Client) -> Client:
6666 return client
6767
6868
69+ class Handshake :
70+ DEISA_HANDSHAKE_ACTOR_FUTURE_VARIABLE = 'deisa_handshake_actor_future'
71+ DEISA_WAIT_FOR_GO_VARIABLE = 'deisa_handshake_wait_for_go'
72+
73+ class HandshakeActor :
74+ bridges = []
75+ max_bridges = 0
76+ arrays_metadata = {}
77+ analytics_ready = False
78+
79+ def __init__ (self ):
80+ self .bridges = []
81+ self .max_bridges = 0
82+ self .arrays_metadata = {}
83+ self .analytics_ready = False
84+ self .client = get_client ()
85+
86+ def add_bridge (self , id : int , max : int ) -> None :
87+ if max == 0 :
88+ raise ValueError ('max cannot be 0.' )
89+ elif self .max_bridges == 0 :
90+ self .max_bridges = max
91+ elif self .max_bridges != max :
92+ raise ValueError (f'Value { max } for bridge { id } is unexpected. Expecting max={ self .max_bridges } .' )
93+ elif len (self .bridges ) >= max :
94+ raise RuntimeError (f'add_bridge cannot be called more than { max } times.' )
95+
96+ self .bridges .append (id )
97+
98+ def set_analytics_ready (self ) -> None :
99+ self .analytics_ready = True
100+ if self .__are_bridges_ready ():
101+ self .__go ()
102+
103+ def set_arrays_metadata (self , arrays_metadata : dict ) -> None :
104+ self .arrays_metadata = arrays_metadata
105+
106+ def get_arrays_metadata (self ) -> dict | Future [dict ]:
107+ return self .arrays_metadata
108+
109+ def get_max_bridges (self ) -> int | Future [int ]:
110+ return self .max_bridges
111+
112+ def __are_bridges_ready (self ) -> bool | Future [bool ]:
113+ return self .max_bridges != 0 and len (self .bridges ) == self .max_bridges
114+
115+ def __go (self ):
116+ Variable (Handshake .DEISA_WAIT_FOR_GO_VARIABLE , client = self .client ).set (None )
117+
118+ def __init__ (self , who : str , client : Client , ** kwargs ):
119+ self .client = client
120+ # self.client.direct_to_workers() # TODO
121+ self .handshake_actor = self .__get_handshake_actor ()
122+ assert self .handshake_actor is not None
123+
124+ if who is 'bridge' :
125+ self .start_bridge (** kwargs )
126+ elif who is 'deisa' :
127+ self .start_deisa (** kwargs )
128+ else :
129+ raise ValueError ("Expecting 'bridge' or 'deisa'." )
130+
131+ def start_bridge (self , id : int , max : int , arrays_metadata : dict , wait_for_go = True ) -> None :
132+ """
133+ Bridge must wait for analytics to be ready.
134+ """
135+ assert self .handshake_actor is not None
136+ self .handshake_actor .add_bridge (id , max )
137+
138+ if id == 0 :
139+ self .handshake_actor .set_arrays_metadata (arrays_metadata )
140+
141+ # wait for go
142+ if wait_for_go :
143+ self .__wait_for_go ()
144+
145+ def start_deisa (self , wait_for_go = True ) -> None :
146+ """
147+ When analytics is ready, notify all Bridges
148+ """
149+ assert self .handshake_actor is not None
150+ self .handshake_actor .set_analytics_ready ()
151+
152+ # wait for go
153+ if wait_for_go :
154+ self .__wait_for_go ()
155+
156+ def get_arrays_metadata (self ) -> dict :
157+ assert self .handshake_actor is not None
158+ return self .handshake_actor .get_arrays_metadata ().result ()
159+
160+ def get_nb_bridges (self ) -> int :
161+ assert self .handshake_actor is not None
162+ return self .handshake_actor .get_max_bridges ().result ()
163+
164+ def __get_handshake_actor (self ) -> HandshakeActor :
165+ try :
166+ return Variable (Handshake .DEISA_HANDSHAKE_ACTOR_FUTURE_VARIABLE , client = self .client ).get (timeout = 0 ).result ()
167+ except asyncio .exceptions .TimeoutError :
168+ actor_future = self .client .submit (Handshake .HandshakeActor , actor = True )
169+ Variable (Handshake .DEISA_HANDSHAKE_ACTOR_FUTURE_VARIABLE , client = self .client ).set (actor_future )
170+ return actor_future .result ()
171+
172+ def __wait_for_go (self ) -> None :
173+ Variable (Handshake .DEISA_WAIT_FOR_GO_VARIABLE , client = self .client ).get ()
174+
175+
69176class Bridge :
70- def __init__ (self , mpi_comm_size : int , mpi_rank : int , arrays_metadata : dict [str , dict ],
177+ def __init__ (self , mpi_comm_size : int , mpi_rank : int ,
178+ arrays_metadata : dict [str , dict ],
71179 get_connection_info : Callable , * args , ** kwargs ):
72180 """
73181 Initializes an object to manage communication between an MPI-based distributed
@@ -101,24 +209,14 @@ def __init__(self, mpi_comm_size: int, mpi_rank: int, arrays_metadata: dict[str,
101209 :param kwargs: Currently unused.
102210 :type kwargs: dict
103211 """
104-
212+ # system_metadata: Callable[[], dict[str, dict]],
105213 self .client = get_connection_info ()
106- self .mpi_rank = mpi_rank
107214 self .arrays_metadata = arrays_metadata
215+ self .mpi_rank = mpi_rank
108216 self .futures = []
109217
110- # TODO: check this
111- # Note: Blocking call. Simulation will wait for the analysis code to be run.
112- # Variable("workers") is set in the Deisa class.
113- workers = Variable ("workers" , client = self .client ).get ()
114- if mpi_comm_size > len (workers ): # more processes than workers
115- self .workers = [workers [mpi_rank % len (workers )]]
116- else :
117- k = len (workers ) // mpi_comm_size # more workers than processes
118- self .workers = workers [mpi_rank * k :mpi_rank * k + k ]
119-
120- if self .mpi_rank == 0 :
121- Queue ("Arrays" , client = self .client ).put (self .arrays_metadata )
218+ # blocking until analytics is ready
219+ Handshake ('bridge' , self .client , id = mpi_rank , max = mpi_comm_size , arrays_metadata = arrays_metadata , ** kwargs )
122220
123221 def publish_data (self , array_name : str , data : np .ndarray , iteration : int ):
124222 """
@@ -138,7 +236,8 @@ def publish_data(self, array_name: str, data: np.ndarray, iteration: int):
138236
139237 assert self .client .status == 'running' , "Client is not connected to a scheduler. Please check your connection."
140238
141- f = self .client .scatter (data , direct = True , workers = self .workers ) # send data to workers
239+ # TODO: select workers to send data to. self.client.scatter(data, direct=True, workers=self.workers)
240+ f = self .client .scatter (data , direct = True ) # send data to workers
142241
143242 # TODO: this is a memory leak. Find a way to release the futures once they are used to build a dask array in the client code.
144243 self .futures .append (f )
@@ -160,32 +259,24 @@ def publish_data(self, array_name: str, data: np.ndarray, iteration: int):
160259class Deisa :
161260 SLIDING_WINDOW_THREAD_PREFIX = "deisa_sliding_window_callback_"
162261
163- def __init__ (self , mpi_comm_size , nb_workers , get_connection_info : Callable , * args , ** kwargs ):
262+ def __init__ (self , get_connection_info : Callable , * args , ** kwargs ):
164263 """
165264 Initializes the distributed processing environment and configures workers using
166265 a Dask scheduler. This class handles setting up a Dask client and ensures the
167266 specified number of workers are available for distributed computation tasks.
168267
169- :param mpi_comm_size: Number of MPI processes for the computation.
170- :param nb_workers: Expected number of workers to be synchronized with the
171- Dask client.
172268 :param get_connection_info: A function that returns a connected Dask Client.
173269 :type get_connection_info: Callable
174270 """
175271 # dask.config.set({"distributed.deploy.lost-worker-timeout": 60, "distributed.workers.memory.spill":0.97, "distributed.workers.memory.target":0.95, "distributed.workers.memory.terminate":0.99 })
176272
177- self .client = get_connection_info ()
178-
179- # Wait for all workers to be available.
180- self .workers = [w_addr for w_addr in self .client .scheduler_info ()["workers" ].keys ()]
181- while len (self .workers ) != nb_workers :
182- self .workers = [w_addr for w_addr in self .client .scheduler_info ()["workers" ].keys ()]
273+ self .client : Client = get_connection_info ()
183274
184- Variable ("workers" , client = self .client ).set (self .workers )
275+ # blocking until all bridges are ready
276+ handshake = Handshake ('deisa' , self .client , ** kwargs )
185277
186- # print(self.workers)
187- self .mpi_comm_size = mpi_comm_size
188- self .arrays_metadata = None
278+ self .mpi_comm_size = handshake .get_nb_bridges ()
279+ self .arrays_metadata = handshake .get_arrays_metadata ()
189280 self .sliding_window_callback_threads : dict [str , threading .Thread ] = {}
190281 self .sliding_window_callback_thread_lock = threading .Lock ()
191282
@@ -212,8 +303,8 @@ def close(self):
212303 def get_array (self , name : str , timeout = None ) -> tuple [Array , int ]:
213304 """Retrieve a Dask array for a given array name."""
214305
215- if self .arrays_metadata is None :
216- self .arrays_metadata = Queue ("Arrays" , client = self .client ).get (timeout = timeout )
306+ # if self.arrays_metadata is None:
307+ # self.arrays_metadata = Queue("Arrays", client=self.client).get(timeout=timeout)
217308 # arrays_metadata will look something like this:
218309 # arrays_metadata = {
219310 # 'global_t': {
0 commit comments