77
88import datetime
99from collections .abc import Awaitable , Callable , Iterator
10- from typing import Any , cast
10+ from typing import Any , AsyncIterator , cast
1111from uuid import uuid4
1212
1313import arrow
3636from simcore_service_director_v2 .utils .dask import generate_dask_job_id
3737from simcore_service_director_v2 .utils .db import to_clusters_db
3838from sqlalchemy .dialects .postgresql import insert as pg_insert
39+ from sqlalchemy .ext .asyncio import AsyncEngine
3940
4041
4142@pytest .fixture
@@ -75,12 +76,12 @@ def creator(**pipeline_kwargs) -> CompPipelineAtDB:
7576
7677
7778@pytest .fixture
78- def tasks (
79- postgres_db : sa . engine . Engine ,
80- ) -> Iterator [Callable [..., list [CompTaskAtDB ]]]:
79+ async def create_tasks (
80+ sqlalchemy_async_engine : AsyncEngine ,
81+ ) -> AsyncIterator [Callable [..., Awaitable [ list [CompTaskAtDB ] ]]]:
8182 created_task_ids : list [int ] = []
8283
83- def creator (
84+ async def creator (
8485 user : dict [str , Any ], project : ProjectAtDB , ** overrides_kwargs
8586 ) -> list [CompTaskAtDB ]:
8687 created_tasks : list [CompTaskAtDB ] = []
@@ -132,8 +133,8 @@ def creator(
132133 ),
133134 }
134135 task_config .update (** overrides_kwargs )
135- with postgres_db .connect () as conn :
136- result = conn .execute (
136+ async with sqlalchemy_async_engine .connect () as conn :
137+ result = await conn .execute (
137138 comp_tasks .insert ()
138139 .values (** task_config )
139140 .returning (sa .literal_column ("*" ))
@@ -146,8 +147,8 @@ def creator(
146147 yield creator
147148
148149 # cleanup
149- with postgres_db .connect () as conn :
150- conn .execute (
150+ async with sqlalchemy_async_engine .connect () as conn :
151+ await conn .execute (
151152 comp_tasks .delete ().where (comp_tasks .c .task_id .in_ (created_task_ids ))
152153 )
153154
@@ -186,12 +187,12 @@ def run_metadata(
186187
187188
188189@pytest .fixture
189- def runs (
190- postgres_db : sa . engine . Engine , run_metadata : RunMetadataDict
191- ) -> Iterator [Callable [..., CompRunsAtDB ]]:
190+ async def create_comp_run (
191+ sqlalchemy_async_engine : AsyncEngine , run_metadata : RunMetadataDict
192+ ) -> AsyncIterator [Callable [..., Awaitable [ CompRunsAtDB ] ]]:
192193 created_run_ids : list [int ] = []
193194
194- def creator (
195+ async def _ (
195196 user : dict [str , Any ], project : ProjectAtDB , ** run_kwargs
196197 ) -> CompRunsAtDB :
197198 run_config = {
@@ -203,8 +204,8 @@ def creator(
203204 "use_on_demand_clusters" : False ,
204205 }
205206 run_config .update (** run_kwargs )
206- with postgres_db .connect () as conn :
207- result = conn .execute (
207+ async with sqlalchemy_async_engine .connect () as conn :
208+ result = await conn .execute (
208209 comp_runs .insert ()
209210 .values (** jsonable_encoder (run_config ))
210211 .returning (sa .literal_column ("*" ))
@@ -213,11 +214,13 @@ def creator(
213214 created_run_ids .append (new_run .run_id )
214215 return new_run
215216
216- yield creator
217+ yield _
217218
218219 # cleanup
219- with postgres_db .connect () as conn :
220- conn .execute (comp_runs .delete ().where (comp_runs .c .run_id .in_ (created_run_ids )))
220+ async with sqlalchemy_async_engine .connect () as conn :
221+ await conn .execute (
222+ comp_runs .delete ().where (comp_runs .c .run_id .in_ (created_run_ids ))
223+ )
221224
222225
223226@pytest .fixture
@@ -299,7 +302,7 @@ async def publish_project(
299302 registered_user : Callable [..., dict [str , Any ]],
300303 project : Callable [..., Awaitable [ProjectAtDB ]],
301304 pipeline : Callable [..., CompPipelineAtDB ],
302- tasks : Callable [..., list [CompTaskAtDB ]],
305+ create_tasks : Callable [..., list [CompTaskAtDB ]],
303306 fake_workbench_without_outputs : dict [str , Any ],
304307 fake_workbench_adjacency : dict [str , Any ],
305308) -> Callable [[], Awaitable [PublishedProject ]]:
@@ -313,7 +316,9 @@ async def _() -> PublishedProject:
313316 project_id = f"{ created_project .uuid } " ,
314317 dag_adjacency_list = fake_workbench_adjacency ,
315318 ),
316- tasks = tasks (user = user , project = created_project , state = StateType .PUBLISHED ),
319+ tasks = create_tasks (
320+ user = user , project = created_project , state = StateType .PUBLISHED
321+ ),
317322 )
318323
319324 return _
@@ -331,8 +336,8 @@ async def running_project(
331336 registered_user : Callable [..., dict [str , Any ]],
332337 project : Callable [..., Awaitable [ProjectAtDB ]],
333338 pipeline : Callable [..., CompPipelineAtDB ],
334- tasks : Callable [..., list [CompTaskAtDB ]],
335- runs : Callable [..., CompRunsAtDB ],
339+ create_tasks : Callable [..., list [CompTaskAtDB ]],
340+ create_comp_run : Callable [..., CompRunsAtDB ],
336341 fake_workbench_without_outputs : dict [str , Any ],
337342 fake_workbench_adjacency : dict [str , Any ],
338343) -> RunningProject :
@@ -345,14 +350,14 @@ async def running_project(
345350 project_id = f"{ created_project .uuid } " ,
346351 dag_adjacency_list = fake_workbench_adjacency ,
347352 ),
348- tasks = tasks (
353+ tasks = create_tasks (
349354 user = user ,
350355 project = created_project ,
351356 state = StateType .RUNNING ,
352357 progress = 0.0 ,
353358 start = now_time ,
354359 ),
355- runs = runs (
360+ runs = create_comp_run (
356361 user = user ,
357362 project = created_project ,
358363 started = now_time ,
@@ -367,8 +372,8 @@ async def running_project_mark_for_cancellation(
367372 registered_user : Callable [..., dict [str , Any ]],
368373 project : Callable [..., Awaitable [ProjectAtDB ]],
369374 pipeline : Callable [..., CompPipelineAtDB ],
370- tasks : Callable [..., list [CompTaskAtDB ]],
371- runs : Callable [..., CompRunsAtDB ],
375+ create_tasks : Callable [..., list [CompTaskAtDB ]],
376+ create_comp_run : Callable [..., CompRunsAtDB ],
372377 fake_workbench_without_outputs : dict [str , Any ],
373378 fake_workbench_adjacency : dict [str , Any ],
374379) -> RunningProject :
@@ -381,14 +386,14 @@ async def running_project_mark_for_cancellation(
381386 project_id = f"{ created_project .uuid } " ,
382387 dag_adjacency_list = fake_workbench_adjacency ,
383388 ),
384- tasks = tasks (
389+ tasks = create_tasks (
385390 user = user ,
386391 project = created_project ,
387392 state = StateType .RUNNING ,
388393 progress = 0.0 ,
389394 start = now_time ,
390395 ),
391- runs = runs (
396+ runs = create_comp_run (
392397 user = user ,
393398 project = created_project ,
394399 result = StateType .RUNNING ,
0 commit comments