Skip to content

Commit 8b5aa97

Browse files
Add handshake
Bridge now waits for analytics (Deisa) to be ready.
1 parent d868c4f commit 8b5aa97

File tree

3 files changed

+183
-63
lines changed

3 files changed

+183
-63
lines changed

src/deisa/dask/deisa.py

Lines changed: 124 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import numpy as np
4242
from dask.array import Array
4343
from dask.distributed import comm, Queue, Variable
44-
from distributed import Client, Future
44+
from distributed import Client, Future, get_client
4545

4646

4747
def get_connection_info(dask_scheduler_address: str | Client) -> Client:
@@ -66,8 +66,116 @@ def get_connection_info(dask_scheduler_address: str | Client) -> Client:
6666
return client
6767

6868

69+
class Handshake:
    """Rendezvous between simulation Bridges and the Deisa analytics client.

    A single ``HandshakeActor`` (a Dask actor, created lazily by whichever
    participant connects first) collects bridge registrations, the arrays
    metadata published by bridge 0, and the analytics-ready signal.  Once
    *all* expected bridges have registered AND analytics has declared itself
    ready, a "go" Variable is set exactly once and every participant blocked
    in ``__wait_for_go`` resumes.
    """

    # Dask Variable holding the future of the singleton HandshakeActor.
    DEISA_HANDSHAKE_ACTOR_FUTURE_VARIABLE = 'deisa_handshake_actor_future'
    # Dask Variable set once to release everyone waiting for "go".
    DEISA_WAIT_FOR_GO_VARIABLE = 'deisa_handshake_wait_for_go'

    class HandshakeActor:
        """Dask actor holding the shared handshake state on a worker."""

        def __init__(self):
            self.bridges = []             # ids of the bridges registered so far
            self.max_bridges = 0          # expected bridge count; 0 means "not known yet"
            self.arrays_metadata = {}     # metadata published by bridge 0
            self.analytics_ready = False  # True once Deisa called set_analytics_ready
            self.client = get_client()    # worker-side client, used to set the "go" Variable

        def add_bridge(self, id: int, max: int) -> None:
            """Register bridge ``id``; ``max`` is the total expected bridge count.

            :raises ValueError: if ``max`` is 0 or inconsistent across bridges.
            :raises RuntimeError: if more than ``max`` bridges register.
            """
            if max == 0:
                raise ValueError('max cannot be 0.')
            elif self.max_bridges == 0:
                self.max_bridges = max
            elif self.max_bridges != max:
                raise ValueError(f'Value {max} for bridge {id} is unexpected. Expecting max={self.max_bridges}.')
            elif len(self.bridges) >= max:
                raise RuntimeError(f'add_bridge cannot be called more than {max} times.')

            self.bridges.append(id)
            # Bug fix: if analytics became ready BEFORE the last bridge
            # registered, set_analytics_ready already ran and nobody else
            # would ever fire the "go" signal, deadlocking all waiters.
            if self.analytics_ready and self.__are_bridges_ready():
                self.__go()

        def set_analytics_ready(self) -> None:
            """Mark analytics as ready; fire "go" if all bridges already registered."""
            self.analytics_ready = True
            if self.__are_bridges_ready():
                self.__go()

        def set_arrays_metadata(self, arrays_metadata: dict) -> None:
            """Store the arrays metadata (published by bridge rank 0)."""
            self.arrays_metadata = arrays_metadata

        def get_arrays_metadata(self) -> dict | Future[dict]:
            # Returns an ActorFuture when called through the actor proxy.
            return self.arrays_metadata

        def get_max_bridges(self) -> int | Future[int]:
            # Returns an ActorFuture when called through the actor proxy.
            return self.max_bridges

        def __are_bridges_ready(self) -> bool | Future[bool]:
            # Ready only once the expected count is known and reached.
            return self.max_bridges != 0 and len(self.bridges) == self.max_bridges

        def __go(self):
            # Release everyone blocked on the "go" Variable; the stored value
            # is irrelevant, only the fact that the Variable is set matters.
            Variable(Handshake.DEISA_WAIT_FOR_GO_VARIABLE, client=self.client).set(None)

    def __init__(self, who: str, client: Client, **kwargs):
        """Join the handshake as either a bridge or the analytics side.

        :param who: ``'bridge'`` or ``'deisa'``.
        :param client: connected Dask client.
        :param kwargs: forwarded to :meth:`start_bridge` / :meth:`start_deisa`.
        :raises ValueError: if ``who`` is neither ``'bridge'`` nor ``'deisa'``.
        """
        self.client = client
        # self.client.direct_to_workers() # TODO
        self.handshake_actor = self.__get_handshake_actor()
        assert self.handshake_actor is not None

        # Bug fix: `is` compared string identity (implementation-dependent,
        # SyntaxWarning on CPython); use equality instead.
        if who == 'bridge':
            self.start_bridge(**kwargs)
        elif who == 'deisa':
            self.start_deisa(**kwargs)
        else:
            raise ValueError("Expecting 'bridge' or 'deisa'.")

    def start_bridge(self, id: int, max: int, arrays_metadata: dict, wait_for_go=True) -> None:
        """
        Bridge must wait for analytics to be ready.
        """
        assert self.handshake_actor is not None
        self.handshake_actor.add_bridge(id, max)

        # Only rank 0 publishes the metadata; all ranks share the same dict.
        if id == 0:
            self.handshake_actor.set_arrays_metadata(arrays_metadata)

        # wait for go
        if wait_for_go:
            self.__wait_for_go()

    def start_deisa(self, wait_for_go=True) -> None:
        """
        When analytics is ready, notify all Bridges
        """
        assert self.handshake_actor is not None
        self.handshake_actor.set_analytics_ready()

        # wait for go
        if wait_for_go:
            self.__wait_for_go()

    def get_arrays_metadata(self) -> dict:
        """Return the arrays metadata published by bridge 0 (blocking)."""
        assert self.handshake_actor is not None
        return self.handshake_actor.get_arrays_metadata().result()

    def get_nb_bridges(self) -> int:
        """Return the total number of bridges taking part in the handshake (blocking)."""
        assert self.handshake_actor is not None
        return self.handshake_actor.get_max_bridges().result()

    def __get_handshake_actor(self) -> HandshakeActor:
        """Fetch the singleton actor, creating it if nobody has yet.

        NOTE(review): get-then-create is not atomic — two participants racing
        here could each create an actor; acceptable only if one side always
        connects first. TODO confirm.
        """
        try:
            return Variable(Handshake.DEISA_HANDSHAKE_ACTOR_FUTURE_VARIABLE, client=self.client).get(timeout=0).result()
        except asyncio.exceptions.TimeoutError:
            actor_future = self.client.submit(Handshake.HandshakeActor, actor=True)
            Variable(Handshake.DEISA_HANDSHAKE_ACTOR_FUTURE_VARIABLE, client=self.client).set(actor_future)
            return actor_future.result()

    def __wait_for_go(self) -> None:
        # Blocks until __go() sets the Variable; no timeout by design.
        Variable(Handshake.DEISA_WAIT_FOR_GO_VARIABLE, client=self.client).get()
175+
69176
class Bridge:
70-
def __init__(self, mpi_comm_size: int, mpi_rank: int, arrays_metadata: dict[str, dict],
177+
def __init__(self, mpi_comm_size: int, mpi_rank: int,
178+
arrays_metadata: dict[str, dict],
71179
get_connection_info: Callable, *args, **kwargs):
72180
"""
73181
Initializes an object to manage communication between an MPI-based distributed
@@ -101,24 +209,14 @@ def __init__(self, mpi_comm_size: int, mpi_rank: int, arrays_metadata: dict[str,
101209
:param kwargs: Currently unused.
102210
:type kwargs: dict
103211
"""
104-
212+
# system_metadata: Callable[[], dict[str, dict]],
105213
self.client = get_connection_info()
106-
self.mpi_rank = mpi_rank
107214
self.arrays_metadata = arrays_metadata
215+
self.mpi_rank = mpi_rank
108216
self.futures = []
109217

110-
# TODO: check this
111-
# Note: Blocking call. Simulation will wait for the analysis code to be run.
112-
# Variable("workers") is set in the Deisa class.
113-
workers = Variable("workers", client=self.client).get()
114-
if mpi_comm_size > len(workers): # more processes than workers
115-
self.workers = [workers[mpi_rank % len(workers)]]
116-
else:
117-
k = len(workers) // mpi_comm_size # more workers than processes
118-
self.workers = workers[mpi_rank * k:mpi_rank * k + k]
119-
120-
if self.mpi_rank == 0:
121-
Queue("Arrays", client=self.client).put(self.arrays_metadata)
218+
# blocking until analytics is ready
219+
Handshake('bridge', self.client, id=mpi_rank, max=mpi_comm_size, arrays_metadata=arrays_metadata, **kwargs)
122220

123221
def publish_data(self, array_name: str, data: np.ndarray, iteration: int):
124222
"""
@@ -138,7 +236,8 @@ def publish_data(self, array_name: str, data: np.ndarray, iteration: int):
138236

139237
assert self.client.status == 'running', "Client is not connected to a scheduler. Please check your connection."
140238

141-
f = self.client.scatter(data, direct=True, workers=self.workers) # send data to workers
239+
# TODO: select workers to send data to. self.client.scatter(data, direct=True, workers=self.workers)
240+
f = self.client.scatter(data, direct=True) # send data to workers
142241

143242
# TODO: this is a memory leak. Find a way to release the futures once they are used to build a dask array in the client code.
144243
self.futures.append(f)
@@ -160,32 +259,24 @@ def publish_data(self, array_name: str, data: np.ndarray, iteration: int):
160259
class Deisa:
161260
SLIDING_WINDOW_THREAD_PREFIX = "deisa_sliding_window_callback_"
162261

163-
def __init__(self, mpi_comm_size, nb_workers, get_connection_info: Callable, *args, **kwargs):
262+
def __init__(self, get_connection_info: Callable, *args, **kwargs):
164263
"""
165264
Initializes the distributed processing environment and configures workers using
166265
a Dask scheduler. This class handles setting up a Dask client and ensures the
167266
specified number of workers are available for distributed computation tasks.
168267
169-
:param mpi_comm_size: Number of MPI processes for the computation.
170-
:param nb_workers: Expected number of workers to be synchronized with the
171-
Dask client.
172268
:param get_connection_info: A function that returns a connected Dask Client.
173269
:type get_connection_info: Callable
174270
"""
175271
# dask.config.set({"distributed.deploy.lost-worker-timeout": 60, "distributed.workers.memory.spill":0.97, "distributed.workers.memory.target":0.95, "distributed.workers.memory.terminate":0.99 })
176272

177-
self.client = get_connection_info()
178-
179-
# Wait for all workers to be available.
180-
self.workers = [w_addr for w_addr in self.client.scheduler_info()["workers"].keys()]
181-
while len(self.workers) != nb_workers:
182-
self.workers = [w_addr for w_addr in self.client.scheduler_info()["workers"].keys()]
273+
self.client: Client = get_connection_info()
183274

184-
Variable("workers", client=self.client).set(self.workers)
275+
# blocking until all bridges are ready
276+
handshake = Handshake('deisa', self.client, **kwargs)
185277

186-
# print(self.workers)
187-
self.mpi_comm_size = mpi_comm_size
188-
self.arrays_metadata = None
278+
self.mpi_comm_size = handshake.get_nb_bridges()
279+
self.arrays_metadata = handshake.get_arrays_metadata()
189280
self.sliding_window_callback_threads: dict[str, threading.Thread] = {}
190281
self.sliding_window_callback_thread_lock = threading.Lock()
191282

@@ -212,8 +303,8 @@ def close(self):
212303
def get_array(self, name: str, timeout=None) -> tuple[Array, int]:
213304
"""Retrieve a Dask array for a given array name."""
214305

215-
if self.arrays_metadata is None:
216-
self.arrays_metadata = Queue("Arrays", client=self.client).get(timeout=timeout)
306+
# if self.arrays_metadata is None:
307+
# self.arrays_metadata = Queue("Arrays", client=self.client).get(timeout=timeout)
217308
# arrays_metadata will look something like this:
218309
# arrays_metadata = {
219310
# 'global_t': {

test/TestSimulator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class TestSimulation:
3737
__test__ = False
3838

3939
def __init__(self, client: Client, global_grid_size: tuple, mpi_parallelism: tuple,
40-
arrays_metadata: dict[str, dict]):
40+
arrays_metadata: dict[str, dict], *args, **kwargs):
4141
self.client = client
4242
self.global_grid_size = global_grid_size
4343
self.mpi_parallelism = mpi_parallelism
@@ -49,7 +49,7 @@ def __init__(self, client: Client, global_grid_size: tuple, mpi_parallelism: tup
4949

5050
self.nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1]
5151
self.bridges: list[Bridge] = [
52-
Bridge(self.nb_mpi_ranks, rank, arrays_metadata, get_connection_info=lambda: client)
52+
Bridge(self.nb_mpi_ranks, rank, arrays_metadata, get_connection_info=lambda: client, *args, **kwargs)
5353
for rank in range(self.nb_mpi_ranks)]
5454

5555
def __gen_data(self, noise_level: int = 0) -> np.ndarray:

0 commit comments

Comments
 (0)