
Commit 20caa9e

moving to asyncpg
1 parent 887dcc6 commit 20caa9e

File tree: 1 file changed, +67 −63 lines changed

services/web/server/tests/integration/01/test_computation.py (67 additions & 63 deletions)
@@ -52,6 +52,7 @@
 from simcore_service_webserver.session.plugin import setup_session
 from simcore_service_webserver.socketio.plugin import setup_socketio
 from simcore_service_webserver.users.plugin import setup_users
+from sqlalchemy.ext.asyncio import AsyncEngine
 from tenacity.asyncio import AsyncRetrying
 from tenacity.retry import retry_if_exception_type
 from tenacity.stop import stop_after_delay
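
Note: the new AsyncEngine import backs the `sqlalchemy_async_engine` fixture used below, which comes from the shared pytest plugins and is not defined in this diff. A minimal sketch of an equivalent fixture, assuming a local asyncpg DSN (placeholder values, not the project's actual configuration):

import pytest
from collections.abc import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine


@pytest.fixture
async def sqlalchemy_async_engine() -> AsyncIterator[AsyncEngine]:
    # the "+asyncpg" URL scheme selects the asyncpg driver
    engine = create_async_engine("postgresql+asyncpg://test:test@127.0.0.1:5432/testdb")
    yield engine
    await engine.dispose()  # return pooled connections at teardown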
@@ -118,7 +119,7 @@ def user_role_response():
 
 @pytest.fixture
 async def client(
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
     rabbit_service: RabbitSettings,
     redis_settings: RedisSettings,
     aiohttp_client: Callable,
@@ -170,29 +171,29 @@ def fake_workbench_adjacency_list(tests_data_dir: Path) -> dict[str, Any]:
         return json.load(fp)
 
 
-def _assert_db_contents(
+async def _assert_db_contents(
     project_id: str,
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
     fake_workbench_payload: dict[str, Any],
     fake_workbench_adjacency_list: dict[str, Any],
     check_outputs: bool,
 ) -> None:
-    with postgres_db.connect() as conn:
-        pipeline_db = conn.execute(
-            sa.select(comp_pipeline).where(comp_pipeline.c.project_id == project_id)
-        ).fetchone()
-        assert pipeline_db
+    async with sqlalchemy_async_engine.connect() as conn:
+        pipeline_db = (
+            await conn.execute(
+                sa.select(comp_pipeline).where(comp_pipeline.c.project_id == project_id)
+            )
+        ).one()
 
-        assert pipeline_db[comp_pipeline.c.project_id] == project_id
-        assert (
-            pipeline_db[comp_pipeline.c.dag_adjacency_list]
-            == fake_workbench_adjacency_list
-        )
+        assert pipeline_db.project_id == project_id
+        assert pipeline_db.dag_adjacency_list == fake_workbench_adjacency_list
 
         # check db comp_tasks
-        tasks_db = conn.execute(
-            sa.select(comp_tasks).where(comp_tasks.c.project_id == project_id)
-        ).fetchall()
+        tasks_db = (
+            await conn.execute(
+                sa.select(comp_tasks).where(comp_tasks.c.project_id == project_id)
+            )
+        ).all()
         assert tasks_db
 
         mock_pipeline = fake_workbench_payload
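
This hunk shows the pattern repeated throughout the migration: `conn.execute()` is now awaited, but the result it returns is buffered, so row extraction (`.one()`, `.all()`) stays synchronous. `.one()` also subsumes the old `fetchone()` + `assert` pair, since it raises when zero or several rows match, and rows are read by attribute (`pipeline_db.project_id`) instead of by column object. A condensed before/after sketch, with `query` as a stand-in for the selects above:

# before (sync): fetchone() may return None, hence the extra assert
# with postgres_db.connect() as conn:
#     row = conn.execute(query).fetchone()
#     assert row

# after (async): execute() is awaited; .one() raises NoResultFound or
# MultipleResultsFound instead of returning None
async with sqlalchemy_async_engine.connect() as conn:
    row = (await conn.execute(query)).one()
    print(row.project_id)  # Row fields are attribute-accessible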
@@ -214,45 +215,43 @@ def _assert_db_contents(
 NodeIdStr = str
 
 
-def _get_computational_tasks_from_db(
+async def _get_computational_tasks_from_db(
     project_id: str,
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
 ) -> dict[NodeIdStr, Any]:
     # this check is only there to check the comp_pipeline is there
-    with postgres_db.connect() as conn:
+    async with sqlalchemy_async_engine.connect() as conn:
         assert (
-            conn.execute(
+            await conn.execute(
                 sa.select(comp_pipeline).where(comp_pipeline.c.project_id == project_id)
-            ).fetchone()
-            is not None
-        ), f"missing pipeline in the database under comp_pipeline {project_id}"
+            )
+        ).one(), f"missing pipeline in the database under comp_pipeline {project_id}"
 
         # get the computational tasks
-        tasks_db = conn.execute(
-            sa.select(comp_tasks).where(
-                (comp_tasks.c.project_id == project_id)
-                & (comp_tasks.c.node_class == NodeClass.COMPUTATIONAL)
+        tasks_db = (
+            await conn.execute(
+                sa.select(comp_tasks).where(
+                    (comp_tasks.c.project_id == project_id)
+                    & (comp_tasks.c.node_class == NodeClass.COMPUTATIONAL)
+                )
             )
-        ).fetchall()
+        ).all()
 
         print(f"--> tasks from DB: {tasks_db=}")
         return {t.node_id: t for t in tasks_db}
 
 
-def _get_project_workbench_from_db(
+async def _get_project_workbench_from_db(
     project_id: str,
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
 ) -> dict[str, Any]:
     # this check is only there to check the comp_pipeline is there
     print(f"--> looking for project {project_id=} in projects table...")
-    with postgres_db.connect() as conn:
-        project_in_db = conn.execute(
-            sa.select(projects).where(projects.c.uuid == project_id)
-        ).fetchone()
-
-        assert (
-            project_in_db
-        ), f"missing pipeline in the database under comp_pipeline {project_id}"
+    async with sqlalchemy_async_engine.connect() as conn:
+        project_in_db = (
+            await conn.execute(sa.select(projects).where(projects.c.uuid == project_id))
+        ).one()
+
         print(
             f"<-- found following workbench: {json_dumps(project_in_db.workbench, indent=2)}"
         )
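
One subtlety in the rewritten assertion: `.one()` raises `NoResultFound` before the `assert` comparison runs, so the f-string message after the comma can no longer be reached for a missing row. If the custom message matters, `.one_or_none()` would preserve it (a possible alternative, not what this commit does):

row = (await conn.execute(query)).one_or_none()  # returns None instead of raising
assert row is not None, f"missing pipeline in the database under comp_pipeline {project_id}"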
@@ -295,7 +294,7 @@ async def _assert_and_wait_for_pipeline_state(
 
 async def _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects(
     project_id: str,
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
 ) -> None:
     async for attempt in AsyncRetrying(
         reraise=True,
@@ -307,11 +306,15 @@ async def _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects(
         print(
             f"--> waiting for pipeline results to move to projects table, attempt {attempt.retry_state.attempt_number}..."
         )
-        comp_tasks_in_db: dict[NodeIdStr, Any] = _get_computational_tasks_from_db(
-            project_id, postgres_db
+        comp_tasks_in_db: dict[NodeIdStr, Any] = (
+            await _get_computational_tasks_from_db(
+                project_id, sqlalchemy_async_engine
+            )
         )
-        workbench_in_db: dict[NodeIdStr, Any] = _get_project_workbench_from_db(
-            project_id, postgres_db
+        workbench_in_db: dict[NodeIdStr, Any] = (
+            await _get_project_workbench_from_db(
+                project_id, sqlalchemy_async_engine
+            )
         )
         for node_id, node_values in comp_tasks_in_db.items():
             assert (
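
The extra parentheses around the awaited helpers are only line-length formatting; each assignment is semantically a plain awaited call, e.g.:

comp_tasks_in_db = await _get_computational_tasks_from_db(project_id, sqlalchemy_async_engine)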
@@ -343,7 +346,7 @@ async def _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects(
 async def test_start_stop_computation(
     client: TestClient,
     sleeper_service: dict[str, str],
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
     logged_user: dict[str, Any],
     user_project: dict[str, Any],
     fake_workbench_adjacency_list: dict[str, Any],
@@ -369,9 +372,9 @@ async def test_start_stop_computation(
     assert "pipeline_id" in data
     assert data["pipeline_id"] == project_id
 
-    _assert_db_contents(
+    await _assert_db_contents(
         project_id,
-        postgres_db,
+        sqlalchemy_async_engine,
         fake_workbench_payload,
         fake_workbench_adjacency_list,
         check_outputs=False,
@@ -382,7 +385,7 @@ async def test_start_stop_computation(
     )
     # we need to wait until the webserver has updated the projects DB before starting another round
     await _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects(
-        project_id, postgres_db
+        project_id, sqlalchemy_async_engine
     )
     # restart the computation, this should produce a 422 since the computation was complete
     resp = await client.post(f"{url_start}")
@@ -408,15 +411,15 @@ async def test_start_stop_computation(
     )
     # we need to wait until the webserver has updated the projects DB
     await _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects(
-        project_id, postgres_db
+        project_id, sqlalchemy_async_engine
     )
 
 
 @pytest.mark.parametrize(*user_role_response(), ids=str)
 async def test_run_pipeline_and_check_state(
     client: TestClient,
     sleeper_service: dict[str, str],
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
     # logged_user: dict[str, Any],
     user_project: dict[str, Any],
     fake_workbench_adjacency_list: dict[str, Any],
@@ -440,9 +443,9 @@ async def test_run_pipeline_and_check_state(
     assert "pipeline_id" in data
     assert data["pipeline_id"] == project_id
 
-    _assert_db_contents(
+    await _assert_db_contents(
         project_id,
-        postgres_db,
+        sqlalchemy_async_engine,
         fake_workbench_payload,
         fake_workbench_adjacency_list,
         check_outputs=False,
@@ -508,8 +511,8 @@ async def test_run_pipeline_and_check_state(
             f"--> pipeline completed with state {received_study_state=}! That's great: {json_dumps(attempt.retry_state.retry_object.statistics)}",
         )
         assert pipeline_state == RunningState.SUCCESS
-    comp_tasks_in_db: dict[NodeIdStr, Any] = _get_computational_tasks_from_db(
-        project_id, postgres_db
+    comp_tasks_in_db: dict[NodeIdStr, Any] = await _get_computational_tasks_from_db(
+        project_id, sqlalchemy_async_engine
     )
     is_success = [t.state == StateType.SUCCESS for t in comp_tasks_in_db.values()]
     assert all(is_success), (
@@ -518,7 +521,7 @@ async def test_run_pipeline_and_check_state(
     )
     # we need to wait until the webserver has updated the projects DB
     await _assert_and_wait_for_comp_task_states_to_be_transmitted_in_projects(
-        project_id, postgres_db
+        project_id, sqlalchemy_async_engine
     )
 
     print(f"<-- pipeline completed successfully in {time.monotonic() - start} seconds")
@@ -530,12 +533,12 @@ async def populated_project_metadata(
     logged_user: dict[str, Any],
     user_project: dict[str, Any],
     faker: Faker,
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
 ):
     assert client.app
     project_uuid = user_project["uuid"]
-    with postgres_db.connect() as con:
-        con.execute(
+    async with sqlalchemy_async_engine.begin() as con:
+        await con.execute(
             projects_metadata.insert().values(
                 project_uuid=project_uuid,
                 custom={
@@ -546,15 +549,16 @@ async def populated_project_metadata(
             )
         )
     yield
-    con.execute(projects_metadata.delete())
-    con.execute(comp_runs_collections.delete())  # cleanup
+    async with sqlalchemy_async_engine.begin() as con:
+        await con.execute(projects_metadata.delete())
+        await con.execute(comp_runs_collections.delete())  # cleanup
 
 
 @pytest.mark.parametrize(*user_role_response(), ids=str)
 async def test_start_multiple_computation_with_the_same_collection_run_id(
     client: TestClient,
     sleeper_service: dict[str, str],
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
     populated_project_metadata: None,
     logged_user: dict[str, Any],
     user_project: dict[str, Any],
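
The fixture switches from `connect()` to `begin()`: `begin()` wraps the block in a transaction that commits on clean exit (and rolls back on error), which matters because SQLAlchemy 2.x-style connections have no autocommit, so an INSERT/DELETE issued via plain `connect()` would be rolled back when the connection closes. A sketch of the difference:

# begin(): commits automatically when the block exits
async with sqlalchemy_async_engine.begin() as con:
    await con.execute(projects_metadata.delete())

# connect(): the DELETE only sticks with an explicit commit
async with sqlalchemy_async_engine.connect() as con:
    await con.execute(projects_metadata.delete())
    await con.commit()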
@@ -584,7 +588,7 @@ async def test_start_multiple_computation_with_the_same_collection_run_id(
 async def test_running_computation_sends_progress_updates_via_socketio(
     client: TestClient,
     sleeper_service: dict[str, str],
-    postgres_db: sa.engine.Engine,
+    sqlalchemy_async_engine: AsyncEngine,
     logged_user: dict[str, Any],
     user_project: dict[str, Any],
     fake_workbench_adjacency_list: dict[str, Any],
@@ -611,9 +615,9 @@ async def test_running_computation_sends_progress_updates_via_socketio(
     assert "pipeline_id" in data
     assert data["pipeline_id"] == project_id
 
-    _assert_db_contents(
+    await _assert_db_contents(
         project_id,
-        postgres_db,
+        sqlalchemy_async_engine,
         user_project["workbench"],
         fake_workbench_adjacency_list,
         check_outputs=False,
