Skip to content

Commit d868c4f

Browse files
use get_connection_info in Deisa and Bridge
1 parent fcfa1c7 commit d868c4f

File tree

3 files changed

+48
-46
lines changed

3 files changed

+48
-46
lines changed

src/deisa/dask/deisa.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,36 @@
4444
from distributed import Client, Future
4545

4646

47+
def get_connection_info(dask_scheduler_address: str | Client) -> Client:
48+
if isinstance(dask_scheduler_address, Client):
49+
client = dask_scheduler_address
50+
elif isinstance(dask_scheduler_address, str):
51+
try:
52+
client = Client(address=dask_scheduler_address)
53+
except ValueError:
54+
# try scheduler_file
55+
if os.path.isfile(dask_scheduler_address):
56+
client = Client(scheduler_file=dask_scheduler_address)
57+
else:
58+
raise ValueError(
59+
"dask_scheduler_address must be a string containing the address of the scheduler, "
60+
"or a string containing a file name to a dask scheduler file, or a Dask Client object.")
61+
else:
62+
raise ValueError(
63+
"dask_scheduler_address must be a string containing the address of the scheduler, "
64+
"or a string containing a file name to a dask scheduler file, or a Dask Client object.")
65+
66+
return client
67+
68+
4769
class Bridge:
48-
def __init__(self, dask_scheduler_address: str | Client, mpi_comm_size: int, mpi_rank: int,
49-
arrays_metadata: dict[str, dict], **kwargs):
70+
def __init__(self, mpi_comm_size: int, mpi_rank: int, arrays_metadata: dict[str, dict],
71+
get_connection_info: Callable, *args, **kwargs):
5072
"""
5173
Initializes an object to manage communication between an MPI-based distributed
5274
system and a Dask-based framework. The class ensures proper allocation of workers
5375
among processes and instantiates the required communication objects like queues.
5476
55-
:param dask_scheduler_address: Address of the Dask Scheduler as a string or an instance of
56-
Dask Client that facilitates communication with the cluster.
5777
:type dask_scheduler_address: str | Client
5878
5979
:param mpi_comm_size: Total number of MPI processes involved in the computation.
@@ -75,17 +95,14 @@ def __init__(self, dask_scheduler_address: str | Client, mpi_comm_size: int, mpi
7595
}
7696
:type arrays_metadata: dict[str, dict]
7797
98+
:param get_connection_info: A function that returns a connected Dask Client.
99+
:type get_connection_info: Callable
100+
78101
:param kwargs: Currently unused.
79102
:type kwargs: dict
80103
"""
81104

82-
if isinstance(dask_scheduler_address, str):
83-
self.client = Client(dask_scheduler_address)
84-
elif isinstance(dask_scheduler_address, Client):
85-
self.client = dask_scheduler_address
86-
else:
87-
raise ValueError("dask_scheduler_address must be a string or a Dask Client object.")
88-
105+
self.client = get_connection_info()
89106
self.mpi_rank = mpi_rank
90107
self.arrays_metadata = arrays_metadata
91108
self.futures = []
@@ -143,38 +160,21 @@ def publish_data(self, array_name: str, data: np.ndarray, iteration: int):
143160
class Deisa:
144161
SLIDING_WINDOW_THREAD_PREFIX = "deisa_sliding_window_callback_"
145162

146-
def __init__(self, dask_scheduler_address: str | Client, mpi_comm_size: int, nb_workers: int):
163+
def __init__(self, mpi_comm_size, nb_workers, get_connection_info: Callable, *args, **kwargs):
147164
"""
148165
Initializes the distributed processing environment and configures workers using
149166
a Dask scheduler. This class handles setting up a Dask client and ensures the
150167
specified number of workers are available for distributed computation tasks.
151168
152-
:param dask_scheduler_address: Instance of Dask's Client to connect to the cluster,
153-
or address string of the Dask scheduler,
154-
or a string containing a file name to a dask scheduler file.
155169
:param mpi_comm_size: Number of MPI processes for the computation.
156170
:param nb_workers: Expected number of workers to be synchronized with the
157171
Dask client.
172+
:param get_connection_info: A function that returns a connected Dask Client.
173+
:type get_connection_info: Callable
158174
"""
159175
# dask.config.set({"distributed.deploy.lost-worker-timeout": 60, "distributed.workers.memory.spill":0.97, "distributed.workers.memory.target":0.95, "distributed.workers.memory.terminate":0.99 })
160176

161-
if isinstance(dask_scheduler_address, Client):
162-
self.client = dask_scheduler_address
163-
elif isinstance(dask_scheduler_address, str):
164-
try:
165-
self.client = Client(address=dask_scheduler_address)
166-
except ValueError:
167-
# try scheduler_file
168-
if os.path.isfile(dask_scheduler_address):
169-
self.client = Client(scheduler_file=dask_scheduler_address)
170-
else:
171-
raise ValueError(
172-
"dask_scheduler_address must be a string containing the address of the scheduler, "
173-
"or a string containing a file name to a dask scheduler file, or a Dask Client object.")
174-
else:
175-
raise ValueError(
176-
"dask_scheduler_address must be a string containing the address of the scheduler, "
177-
"or a string containing a file name to a dask scheduler file, or a Dask Client object.")
177+
self.client = get_connection_info()
178178

179179
# Wait for all workers to be available.
180180
self.workers = [w_addr for w_addr in self.client.scheduler_info()["workers"].keys()]

test/TestSimulator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
# =============================================================================
2929

3030
import numpy as np
31-
from deisa.common import BridgeInterface
32-
from deisa.dask import Bridge
3331
from distributed import Client
3432

33+
from deisa.dask import Bridge
34+
3535

3636
class TestSimulation:
3737
__test__ = False
@@ -48,8 +48,9 @@ def __init__(self, client: Client, global_grid_size: tuple, mpi_parallelism: tup
4848
assert global_grid_size[1] % mpi_parallelism[1] == 0, "cannot compute local grid size for y dimension"
4949

5050
self.nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1]
51-
self.bridges: list[BridgeInterface] = [Bridge(client, self.nb_mpi_ranks, rank, arrays_metadata) for rank in
52-
range(self.nb_mpi_ranks)]
51+
self.bridges: list[Bridge] = [
52+
Bridge(self.nb_mpi_ranks, rank, arrays_metadata, get_connection_info=lambda: client)
53+
for rank in range(self.nb_mpi_ranks)]
5354

5455
def __gen_data(self, noise_level: int = 0) -> np.ndarray:
5556
# Create coordinate grid

test/test_deisa.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from distributed import Client, LocalCluster, Queue, Variable
4040

4141
from TestSimulator import TestSimulation
42-
from deisa.dask import Deisa
42+
from deisa.dask import Deisa, get_connection_info
4343

4444

4545
@pytest.mark.parametrize('global_shape', [(32, 32), (32, 16), (16, 32)])
@@ -128,30 +128,31 @@ def env_setup_tcp_cluster(self):
128128
def test_deisa_ctor_client(self, env_setup_tcp_cluster):
129129
cluster = env_setup_tcp_cluster
130130
client = Client(cluster)
131-
deisa = Deisa(client, mpi_comm_size=0, nb_workers=0)
131+
deisa = Deisa(mpi_comm_size=0, nb_workers=0, get_connection_info=lambda: client)
132132
assert deisa.client is not None, "Deisa should not be None"
133133
assert deisa.client.scheduler.address == cluster.scheduler_address, "Client should be the same as scheduler"
134134
deisa.close()
135135

136136
def test_deisa_ctor_str(self, env_setup_tcp_cluster):
137137
cluster = env_setup_tcp_cluster
138-
deisa = Deisa('tcp://127.0.0.1:4242', mpi_comm_size=0, nb_workers=0)
138+
deisa = Deisa(mpi_comm_size=0, nb_workers=0,
139+
get_connection_info=lambda: get_connection_info('tcp://127.0.0.1:4242'))
139140
assert deisa.client is not None, "Deisa should not be None"
140141
assert deisa.client.scheduler.address == cluster.scheduler_address, "Client should be the same as scheduler"
141142
deisa.close()
142143

143144
def test_deisa_ctor_scheduler_file(self, env_setup_tcp_cluster):
144145
cluster = env_setup_tcp_cluster
145146
f = os.path.abspath(os.path.dirname(__file__)) + os.path.sep + 'test-scheduler.json'
146-
deisa = Deisa(f, mpi_comm_size=0, nb_workers=0)
147+
deisa = Deisa(mpi_comm_size=0, nb_workers=0, get_connection_info=lambda: get_connection_info(f))
147148
assert deisa.client is not None, "Deisa should not be None"
148149
assert deisa.client.scheduler.address == cluster.scheduler_address, "Client should be the same as scheduler"
149150
deisa.close()
150151

151152
def test_deisa_ctor_scheduler_file_error(self):
152153
with pytest.raises(ValueError) as e:
153154
f = os.path.abspath(os.path.dirname(__file__)) + os.path.sep + 'test-scheduler-error.json'
154-
deisa = Deisa(f, mpi_comm_size=0, nb_workers=0)
155+
deisa = Deisa(mpi_comm_size=0, nb_workers=0, get_connection_info=lambda: get_connection_info(f))
155156

156157

157158
class TestUsingDaskCluster:
@@ -243,7 +244,7 @@ def test_get_dask_array(self, global_grid_size: tuple, mpi_parallelism: tuple, n
243244
nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1]
244245
nb_workers = len(cluster.workers)
245246

246-
deisa = Deisa(client, nb_mpi_ranks, nb_workers)
247+
deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client)
247248
sim = TestSimulation(client,
248249
global_grid_size=global_grid_size,
249250
mpi_parallelism=mpi_parallelism,
@@ -279,7 +280,7 @@ def test_sliding_window_callback_register(self, global_grid_size: tuple, mpi_par
279280
nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1]
280281
nb_workers = len(cluster.workers)
281282

282-
deisa = Deisa(client, nb_mpi_ranks, nb_workers)
283+
deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client)
283284
sim = TestSimulation(client,
284285
global_grid_size=global_grid_size,
285286
mpi_parallelism=mpi_parallelism,
@@ -332,7 +333,7 @@ def test_sliding_window_callback_unregister(self, env_setup):
332333
nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1]
333334
nb_workers = len(cluster.workers)
334335

335-
deisa = Deisa(client, nb_mpi_ranks, nb_workers)
336+
deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client)
336337
sim = TestSimulation(client,
337338
global_grid_size=global_grid_size,
338339
mpi_parallelism=mpi_parallelism,
@@ -379,7 +380,7 @@ def test_sliding_window_callback_throws(self, env_setup):
379380
nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1]
380381
nb_workers = len(cluster.workers)
381382

382-
deisa = Deisa(client, nb_mpi_ranks, nb_workers)
383+
deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client)
383384
sim = TestSimulation(client,
384385
global_grid_size=global_grid_size,
385386
mpi_parallelism=mpi_parallelism,
@@ -450,7 +451,7 @@ def test_sliding_window_map_blocks(self, env_setup):
450451
nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1]
451452
nb_workers = len(cluster.workers)
452453

453-
deisa = Deisa(client, nb_mpi_ranks, nb_workers)
454+
deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client)
454455
sim = TestSimulation(client,
455456
global_grid_size=global_grid_size,
456457
mpi_parallelism=mpi_parallelism,

0 commit comments

Comments
 (0)