|
39 | 39 | from distributed import Client, LocalCluster, Queue, Variable |
40 | 40 |
|
41 | 41 | from TestSimulator import TestSimulation |
42 | | -from deisa.dask import Deisa |
| 42 | +from deisa.dask import Deisa, get_connection_info |
43 | 43 |
|
44 | 44 |
|
45 | 45 | @pytest.mark.parametrize('global_shape', [(32, 32), (32, 16), (16, 32)]) |
@@ -128,30 +128,31 @@ def env_setup_tcp_cluster(self): |
128 | 128 | def test_deisa_ctor_client(self, env_setup_tcp_cluster): |
129 | 129 | cluster = env_setup_tcp_cluster |
130 | 130 | client = Client(cluster) |
131 | | - deisa = Deisa(client, mpi_comm_size=0, nb_workers=0) |
| 131 | + deisa = Deisa(mpi_comm_size=0, nb_workers=0, get_connection_info=lambda: client) |
132 | 132 | assert deisa.client is not None, "Deisa should not be None" |
133 | 133 | assert deisa.client.scheduler.address == cluster.scheduler_address, "Client should be the same as scheduler" |
134 | 134 | deisa.close() |
135 | 135 |
|
136 | 136 | def test_deisa_ctor_str(self, env_setup_tcp_cluster): |
137 | 137 | cluster = env_setup_tcp_cluster |
138 | | - deisa = Deisa('tcp://127.0.0.1:4242', mpi_comm_size=0, nb_workers=0) |
| 138 | + deisa = Deisa(mpi_comm_size=0, nb_workers=0, |
| 139 | + get_connection_info=lambda: get_connection_info('tcp://127.0.0.1:4242')) |
139 | 140 | assert deisa.client is not None, "Deisa should not be None" |
140 | 141 | assert deisa.client.scheduler.address == cluster.scheduler_address, "Client should be the same as scheduler" |
141 | 142 | deisa.close() |
142 | 143 |
|
143 | 144 | def test_deisa_ctor_scheduler_file(self, env_setup_tcp_cluster): |
144 | 145 | cluster = env_setup_tcp_cluster |
145 | 146 | f = os.path.abspath(os.path.dirname(__file__)) + os.path.sep + 'test-scheduler.json' |
146 | | - deisa = Deisa(f, mpi_comm_size=0, nb_workers=0) |
| 147 | + deisa = Deisa(mpi_comm_size=0, nb_workers=0, get_connection_info=lambda: get_connection_info(f)) |
147 | 148 | assert deisa.client is not None, "Deisa should not be None" |
148 | 149 | assert deisa.client.scheduler.address == cluster.scheduler_address, "Client should be the same as scheduler" |
149 | 150 | deisa.close() |
150 | 151 |
|
151 | 152 | def test_deisa_ctor_scheduler_file_error(self): |
152 | 153 | with pytest.raises(ValueError) as e: |
153 | 154 | f = os.path.abspath(os.path.dirname(__file__)) + os.path.sep + 'test-scheduler-error.json' |
154 | | - deisa = Deisa(f, mpi_comm_size=0, nb_workers=0) |
| 155 | + deisa = Deisa(mpi_comm_size=0, nb_workers=0, get_connection_info=lambda: get_connection_info(f)) |
155 | 156 |
|
156 | 157 |
|
157 | 158 | class TestUsingDaskCluster: |
@@ -243,7 +244,7 @@ def test_get_dask_array(self, global_grid_size: tuple, mpi_parallelism: tuple, n |
243 | 244 | nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1] |
244 | 245 | nb_workers = len(cluster.workers) |
245 | 246 |
|
246 | | - deisa = Deisa(client, nb_mpi_ranks, nb_workers) |
| 247 | + deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client) |
247 | 248 | sim = TestSimulation(client, |
248 | 249 | global_grid_size=global_grid_size, |
249 | 250 | mpi_parallelism=mpi_parallelism, |
@@ -279,7 +280,7 @@ def test_sliding_window_callback_register(self, global_grid_size: tuple, mpi_par |
279 | 280 | nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1] |
280 | 281 | nb_workers = len(cluster.workers) |
281 | 282 |
|
282 | | - deisa = Deisa(client, nb_mpi_ranks, nb_workers) |
| 283 | + deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client) |
283 | 284 | sim = TestSimulation(client, |
284 | 285 | global_grid_size=global_grid_size, |
285 | 286 | mpi_parallelism=mpi_parallelism, |
@@ -332,7 +333,7 @@ def test_sliding_window_callback_unregister(self, env_setup): |
332 | 333 | nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1] |
333 | 334 | nb_workers = len(cluster.workers) |
334 | 335 |
|
335 | | - deisa = Deisa(client, nb_mpi_ranks, nb_workers) |
| 336 | + deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client) |
336 | 337 | sim = TestSimulation(client, |
337 | 338 | global_grid_size=global_grid_size, |
338 | 339 | mpi_parallelism=mpi_parallelism, |
@@ -379,7 +380,7 @@ def test_sliding_window_callback_throws(self, env_setup): |
379 | 380 | nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1] |
380 | 381 | nb_workers = len(cluster.workers) |
381 | 382 |
|
382 | | - deisa = Deisa(client, nb_mpi_ranks, nb_workers) |
| 383 | + deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client) |
383 | 384 | sim = TestSimulation(client, |
384 | 385 | global_grid_size=global_grid_size, |
385 | 386 | mpi_parallelism=mpi_parallelism, |
@@ -450,7 +451,7 @@ def test_sliding_window_map_blocks(self, env_setup): |
450 | 451 | nb_mpi_ranks = mpi_parallelism[0] * mpi_parallelism[1] |
451 | 452 | nb_workers = len(cluster.workers) |
452 | 453 |
|
453 | | - deisa = Deisa(client, nb_mpi_ranks, nb_workers) |
| 454 | + deisa = Deisa(nb_mpi_ranks, nb_workers, get_connection_info=lambda: client) |
454 | 455 | sim = TestSimulation(client, |
455 | 456 | global_grid_size=global_grid_size, |
456 | 457 | mpi_parallelism=mpi_parallelism, |
|
0 commit comments