5252from simcore_service_webserver .session .plugin import setup_session
5353from simcore_service_webserver .socketio .plugin import setup_socketio
5454from simcore_service_webserver .users .plugin import setup_users
55+ from sqlalchemy .ext .asyncio import AsyncEngine
5556from tenacity .asyncio import AsyncRetrying
5657from tenacity .retry import retry_if_exception_type
5758from tenacity .stop import stop_after_delay
@@ -118,7 +119,7 @@ def user_role_response():
118119
119120@pytest .fixture
120121async def client (
121- postgres_db : sa . engine . Engine ,
122+ sqlalchemy_async_engine : AsyncEngine ,
122123 rabbit_service : RabbitSettings ,
123124 redis_settings : RedisSettings ,
124125 aiohttp_client : Callable ,
@@ -170,29 +171,29 @@ def fake_workbench_adjacency_list(tests_data_dir: Path) -> dict[str, Any]:
170171 return json .load (fp )
171172
172173
173- def _assert_db_contents (
174+ async def _assert_db_contents (
174175 project_id : str ,
175- postgres_db : sa . engine . Engine ,
176+ sqlalchemy_async_engine : AsyncEngine ,
176177 fake_workbench_payload : dict [str , Any ],
177178 fake_workbench_adjacency_list : dict [str , Any ],
178179 check_outputs : bool ,
179180) -> None :
180- with postgres_db .connect () as conn :
181- pipeline_db = conn .execute (
182- sa .select (comp_pipeline ).where (comp_pipeline .c .project_id == project_id )
183- ).fetchone ()
184- assert pipeline_db
181+ async with sqlalchemy_async_engine .connect () as conn :
182+ pipeline_db = (
183+ await conn .execute (
184+ sa .select (comp_pipeline ).where (comp_pipeline .c .project_id == project_id )
185+ )
186+ ).one ()
185187
186- assert pipeline_db [comp_pipeline .c .project_id ] == project_id
187- assert (
188- pipeline_db [comp_pipeline .c .dag_adjacency_list ]
189- == fake_workbench_adjacency_list
190- )
188+ assert pipeline_db .project_id == project_id
189+ assert pipeline_db .dag_adjacency_list == fake_workbench_adjacency_list
191190
192191 # check db comp_tasks
193- tasks_db = conn .execute (
194- sa .select (comp_tasks ).where (comp_tasks .c .project_id == project_id )
195- ).fetchall ()
192+ tasks_db = (
193+ await conn .execute (
194+ sa .select (comp_tasks ).where (comp_tasks .c .project_id == project_id )
195+ )
196+ ).all ()
196197 assert tasks_db
197198
198199 mock_pipeline = fake_workbench_payload
@@ -214,45 +215,43 @@ def _assert_db_contents(
214215NodeIdStr = str
215216
216217
217- def _get_computational_tasks_from_db (
218+ async def _get_computational_tasks_from_db (
218219 project_id : str ,
219- postgres_db : sa . engine . Engine ,
220+ sqlalchemy_async_engine : AsyncEngine ,
220221) -> dict [NodeIdStr , Any ]:
221222 # this check is only there to check the comp_pipeline is there
222- with postgres_db .connect () as conn :
223+ async with sqlalchemy_async_engine .connect () as conn :
223224 assert (
224- conn .execute (
225+ await conn .execute (
225226 sa .select (comp_pipeline ).where (comp_pipeline .c .project_id == project_id )
226- ).fetchone ()
227- is not None
228- ), f"missing pipeline in the database under comp_pipeline { project_id } "
227+ )
228+ ).one (), f"missing pipeline in the database under comp_pipeline { project_id } "
229229
230230 # get the computational tasks
231- tasks_db = conn .execute (
232- sa .select (comp_tasks ).where (
233- (comp_tasks .c .project_id == project_id )
234- & (comp_tasks .c .node_class == NodeClass .COMPUTATIONAL )
231+ tasks_db = (
232+ await conn .execute (
233+ sa .select (comp_tasks ).where (
234+ (comp_tasks .c .project_id == project_id )
235+ & (comp_tasks .c .node_class == NodeClass .COMPUTATIONAL )
236+ )
235237 )
236- ).fetchall ()
238+ ).all ()
237239
238240 print (f"--> tasks from DB: { tasks_db = } " )
239241 return {t .node_id : t for t in tasks_db }
240242
241243
242- def _get_project_workbench_from_db (
244+ async def _get_project_workbench_from_db (
243245 project_id : str ,
244- postgres_db : sa . engine . Engine ,
246+ sqlalchemy_async_engine : AsyncEngine ,
245247) -> dict [str , Any ]:
246248 # this check is only there to check the comp_pipeline is there
247249 print (f"--> looking for project { project_id = } in projects table..." )
248- with postgres_db .connect () as conn :
249- project_in_db = conn .execute (
250- sa .select (projects ).where (projects .c .uuid == project_id )
251- ).fetchone ()
252-
253- assert (
254- project_in_db
255- ), f"missing pipeline in the database under comp_pipeline { project_id } "
250+ async with sqlalchemy_async_engine .connect () as conn :
251+ project_in_db = (
252+ await conn .execute (sa .select (projects ).where (projects .c .uuid == project_id ))
253+ ).one ()
254+
256255 print (
257256 f"<-- found following workbench: { json_dumps (project_in_db .workbench , indent = 2 )} "
258257 )
@@ -295,7 +294,7 @@ async def _assert_and_wait_for_pipeline_state(
295294
296295async def _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects (
297296 project_id : str ,
298- postgres_db : sa . engine . Engine ,
297+ sqlalchemy_async_engine : AsyncEngine ,
299298) -> None :
300299 async for attempt in AsyncRetrying (
301300 reraise = True ,
@@ -307,11 +306,15 @@ async def _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects(
307306 print (
308307 f"--> waiting for pipeline results to move to projects table, attempt { attempt .retry_state .attempt_number } ..."
309308 )
310- comp_tasks_in_db : dict [NodeIdStr , Any ] = _get_computational_tasks_from_db (
311- project_id , postgres_db
309+ comp_tasks_in_db : dict [NodeIdStr , Any ] = (
310+ await _get_computational_tasks_from_db (
311+ project_id , sqlalchemy_async_engine
312+ )
312313 )
313- workbench_in_db : dict [NodeIdStr , Any ] = _get_project_workbench_from_db (
314- project_id , postgres_db
314+ workbench_in_db : dict [NodeIdStr , Any ] = (
315+ await _get_project_workbench_from_db (
316+ project_id , sqlalchemy_async_engine
317+ )
315318 )
316319 for node_id , node_values in comp_tasks_in_db .items ():
317320 assert (
@@ -343,7 +346,7 @@ async def _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects(
343346async def test_start_stop_computation (
344347 client : TestClient ,
345348 sleeper_service : dict [str , str ],
346- postgres_db : sa . engine . Engine ,
349+ sqlalchemy_async_engine : AsyncEngine ,
347350 logged_user : dict [str , Any ],
348351 user_project : dict [str , Any ],
349352 fake_workbench_adjacency_list : dict [str , Any ],
@@ -369,9 +372,9 @@ async def test_start_stop_computation(
369372 assert "pipeline_id" in data
370373 assert data ["pipeline_id" ] == project_id
371374
372- _assert_db_contents (
375+ await _assert_db_contents (
373376 project_id ,
374- postgres_db ,
377+ sqlalchemy_async_engine ,
375378 fake_workbench_payload ,
376379 fake_workbench_adjacency_list ,
377380 check_outputs = False ,
@@ -382,7 +385,7 @@ async def test_start_stop_computation(
382385 )
383386 # we need to wait until the webserver has updated the projects DB before starting another round
384387 await _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects (
385- project_id , postgres_db
388+ project_id , sqlalchemy_async_engine
386389 )
387390 # restart the computation, this should produce a 422 since the computation was complete
388391 resp = await client .post (f"{ url_start } " )
@@ -408,15 +411,15 @@ async def test_start_stop_computation(
408411 )
409412 # we need to wait until the webserver has updated the projects DB
410413 await _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects (
411- project_id , postgres_db
414+ project_id , sqlalchemy_async_engine
412415 )
413416
414417
415418@pytest .mark .parametrize (* user_role_response (), ids = str )
416419async def test_run_pipeline_and_check_state (
417420 client : TestClient ,
418421 sleeper_service : dict [str , str ],
419- postgres_db : sa . engine . Engine ,
422+ sqlalchemy_async_engine : AsyncEngine ,
420423 # logged_user: dict[str, Any],
421424 user_project : dict [str , Any ],
422425 fake_workbench_adjacency_list : dict [str , Any ],
@@ -440,9 +443,9 @@ async def test_run_pipeline_and_check_state(
440443 assert "pipeline_id" in data
441444 assert data ["pipeline_id" ] == project_id
442445
443- _assert_db_contents (
446+ await _assert_db_contents (
444447 project_id ,
445- postgres_db ,
448+ sqlalchemy_async_engine ,
446449 fake_workbench_payload ,
447450 fake_workbench_adjacency_list ,
448451 check_outputs = False ,
@@ -508,8 +511,8 @@ async def test_run_pipeline_and_check_state(
508511 f"--> pipeline completed with state { received_study_state = } ! That's great: { json_dumps (attempt .retry_state .retry_object .statistics )} " ,
509512 )
510513 assert pipeline_state == RunningState .SUCCESS
511- comp_tasks_in_db : dict [NodeIdStr , Any ] = _get_computational_tasks_from_db (
512- project_id , postgres_db
514+ comp_tasks_in_db : dict [NodeIdStr , Any ] = await _get_computational_tasks_from_db (
515+ project_id , sqlalchemy_async_engine
513516 )
514517 is_success = [t .state == StateType .SUCCESS for t in comp_tasks_in_db .values ()]
515518 assert all (is_success ), (
@@ -518,7 +521,7 @@ async def test_run_pipeline_and_check_state(
518521 )
519522 # we need to wait until the webserver has updated the projects DB
520523 await _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects (
521- project_id , postgres_db
524+ project_id , sqlalchemy_async_engine
522525 )
523526
524527 print (f"<-- pipeline completed successfully in { time .monotonic () - start } seconds" )
@@ -530,12 +533,12 @@ async def populated_project_metadata(
530533 logged_user : dict [str , Any ],
531534 user_project : dict [str , Any ],
532535 faker : Faker ,
533- postgres_db : sa . engine . Engine ,
536+ sqlalchemy_async_engine : AsyncEngine ,
534537):
535538 assert client .app
536539 project_uuid = user_project ["uuid" ]
537- with postgres_db . connect () as con :
538- con .execute (
540+ async with sqlalchemy_async_engine . begin () as con :
541+ await con .execute (
539542 projects_metadata .insert ().values (
540543 project_uuid = project_uuid ,
541544 custom = {
@@ -546,15 +549,16 @@ async def populated_project_metadata(
546549 )
547550 )
548551 yield
549- con .execute (projects_metadata .delete ())
550- con .execute (comp_runs_collections .delete ()) # cleanup
552+ async with sqlalchemy_async_engine .begin () as con :
553+ await con .execute (projects_metadata .delete ())
554+ await con .execute (comp_runs_collections .delete ()) # cleanup
551555
552556
553557@pytest .mark .parametrize (* user_role_response (), ids = str )
554558async def test_start_multiple_computation_with_the_same_collection_run_id (
555559 client : TestClient ,
556560 sleeper_service : dict [str , str ],
557- postgres_db : sa . engine . Engine ,
561+ sqlalchemy_async_engine : AsyncEngine ,
558562 populated_project_metadata : None ,
559563 logged_user : dict [str , Any ],
560564 user_project : dict [str , Any ],
@@ -584,7 +588,7 @@ async def test_start_multiple_computation_with_the_same_collection_run_id(
584588async def test_running_computation_sends_progress_updates_via_socketio (
585589 client : TestClient ,
586590 sleeper_service : dict [str , str ],
587- postgres_db : sa . engine . Engine ,
591+ sqlalchemy_async_engine : AsyncEngine ,
588592 logged_user : dict [str , Any ],
589593 user_project : dict [str , Any ],
590594 fake_workbench_adjacency_list : dict [str , Any ],
@@ -611,9 +615,9 @@ async def test_running_computation_sends_progress_updates_via_socketio(
611615 assert "pipeline_id" in data
612616 assert data ["pipeline_id" ] == project_id
613617
614- _assert_db_contents (
618+ await _assert_db_contents (
615619 project_id ,
616- postgres_db ,
620+ sqlalchemy_async_engine ,
617621 user_project ["workbench" ],
618622 fake_workbench_adjacency_list ,
619623 check_outputs = False ,
0 commit comments