66
77
88import datetime
9- from collections .abc import Awaitable , Callable , Iterator
10- from typing import Any , AsyncIterator , cast
9+ from collections .abc import AsyncIterator , Awaitable , Callable
10+ from typing import Any , cast
1111from uuid import uuid4
1212
1313import arrow
4040
4141
4242@pytest .fixture
43- def pipeline (
44- postgres_db : sa . engine . Engine ,
45- ) -> Iterator [Callable [..., CompPipelineAtDB ]]:
43+ async def create_pipeline (
44+ sqlalchemy_async_engine : AsyncEngine ,
45+ ) -> AsyncIterator [Callable [..., Awaitable [ CompPipelineAtDB ] ]]:
4646 created_pipeline_ids : list [str ] = []
4747
48- def creator (** pipeline_kwargs ) -> CompPipelineAtDB :
48+ async def _ (** pipeline_kwargs ) -> CompPipelineAtDB :
4949 pipeline_config = {
5050 "project_id" : f"{ uuid4 ()} " ,
5151 "dag_adjacency_list" : {},
5252 "state" : StateType .NOT_STARTED ,
5353 }
5454 pipeline_config .update (** pipeline_kwargs )
55- with postgres_db .begin () as conn :
56- result = conn .execute (
55+ async with sqlalchemy_async_engine .begin () as conn :
56+ result = await conn .execute (
5757 comp_pipeline .insert ()
5858 .values (** pipeline_config )
5959 .returning (sa .literal_column ("*" ))
@@ -64,11 +64,11 @@ def creator(**pipeline_kwargs) -> CompPipelineAtDB:
6464 created_pipeline_ids .append (f"{ new_pipeline .project_id } " )
6565 return new_pipeline
6666
67- yield creator
67+ yield _
6868
6969 # cleanup
70- with postgres_db .connect () as conn :
71- conn .execute (
70+ async with sqlalchemy_async_engine .connect () as conn :
71+ await conn .execute (
7272 comp_pipeline .delete ().where (
7373 comp_pipeline .c .project_id .in_ (created_pipeline_ids )
7474 )
@@ -81,7 +81,7 @@ async def create_tasks(
8181) -> AsyncIterator [Callable [..., Awaitable [list [CompTaskAtDB ]]]]:
8282 created_task_ids : list [int ] = []
8383
84- async def creator (
84+ async def _ (
8585 user : dict [str , Any ], project : ProjectAtDB , ** overrides_kwargs
8686 ) -> list [CompTaskAtDB ]:
8787 created_tasks : list [CompTaskAtDB ] = []
@@ -144,7 +144,7 @@ async def creator(
144144 created_task_ids .extend ([t .task_id for t in created_tasks if t .task_id ])
145145 return created_tasks
146146
147- yield creator
147+ yield _
148148
149149 # cleanup
150150 async with sqlalchemy_async_engine .connect () as conn :
@@ -224,29 +224,37 @@ async def _(
224224
225225
226226@pytest .fixture
227- def cluster (
228- postgres_db : sa . engine . Engine ,
229- ) -> Iterator [Callable [..., Cluster ]]:
227+ async def create_cluster (
228+ sqlalchemy_async_engine : AsyncEngine ,
229+ ) -> AsyncIterator [Callable [..., Awaitable [ Cluster ] ]]:
230230 created_cluster_ids : list [str ] = []
231231
232- def creator (user : dict [str , Any ], ** cluster_kwargs ) -> Cluster :
232+ async def _ (user : dict [str , Any ], ** cluster_kwargs ) -> Cluster :
233+ assert "json_schema_extra" in Cluster .model_config
234+ assert isinstance (Cluster .model_config ["json_schema_extra" ], dict )
235+ assert isinstance (Cluster .model_config ["json_schema_extra" ]["examples" ], list )
236+ assert isinstance (
237+ Cluster .model_config ["json_schema_extra" ]["examples" ][1 ], dict
238+ )
233239 cluster_config = Cluster .model_config ["json_schema_extra" ]["examples" ][1 ]
234240 cluster_config ["owner" ] = user ["primary_gid" ]
235241 cluster_config .update (** cluster_kwargs )
236242 new_cluster = Cluster .model_validate (cluster_config )
237243 assert new_cluster
238244
239- with postgres_db .connect () as conn :
245+ async with sqlalchemy_async_engine .connect () as conn :
240246 # insert basic cluster
241- created_cluster = conn .execute (
242- sa .insert (clusters )
243- .values (to_clusters_db (new_cluster , only_update = False ))
244- .returning (sa .literal_column ("*" ))
247+ created_cluster = (
248+ await conn .execute (
249+ sa .insert (clusters )
250+ .values (to_clusters_db (new_cluster , only_update = False ))
251+ .returning (sa .literal_column ("*" ))
252+ )
245253 ).one ()
246254 created_cluster_ids .append (created_cluster .id )
247255 if "access_rights" in cluster_kwargs :
248256 for gid , rights in cluster_kwargs ["access_rights" ].items ():
249- conn .execute (
257+ await conn .execute (
250258 pg_insert (cluster_to_groups )
251259 .values (
252260 cluster_id = created_cluster .id ,
@@ -259,7 +267,7 @@ def creator(user: dict[str, Any], **cluster_kwargs) -> Cluster:
259267 )
260268 )
261269 access_rights_in_db = {}
262- for row in conn .execute (
270+ for row in await conn .execute (
263271 sa .select (
264272 cluster_to_groups .c .gid ,
265273 cluster_to_groups .c .read ,
@@ -287,12 +295,11 @@ def creator(user: dict[str, Any], **cluster_kwargs) -> Cluster:
287295 thumbnail = None ,
288296 )
289297
290- yield creator
298+ yield _
291299
292300 # cleanup
293- with postgres_db .connect () as conn :
294- conn .execute (
295- # pylint: disable=no-value-for-parameter
301+ async with sqlalchemy_async_engine .connect () as conn :
302+ await conn .execute (
296303 clusters .delete ().where (clusters .c .id .in_ (created_cluster_ids ))
297304 )
298305
@@ -301,8 +308,8 @@ def creator(user: dict[str, Any], **cluster_kwargs) -> Cluster:
301308async def publish_project (
302309 registered_user : Callable [..., dict [str , Any ]],
303310 project : Callable [..., Awaitable [ProjectAtDB ]],
304- pipeline : Callable [..., CompPipelineAtDB ],
305- create_tasks : Callable [..., list [CompTaskAtDB ]],
311+ create_pipeline : Callable [..., Awaitable [ CompPipelineAtDB ] ],
312+ create_tasks : Callable [..., Awaitable [ list [CompTaskAtDB ] ]],
306313 fake_workbench_without_outputs : dict [str , Any ],
307314 fake_workbench_adjacency : dict [str , Any ],
308315) -> Callable [[], Awaitable [PublishedProject ]]:
@@ -312,11 +319,11 @@ async def _() -> PublishedProject:
312319 created_project = await project (user , workbench = fake_workbench_without_outputs )
313320 return PublishedProject (
314321 project = created_project ,
315- pipeline = pipeline (
322+ pipeline = await create_pipeline (
316323 project_id = f"{ created_project .uuid } " ,
317324 dag_adjacency_list = fake_workbench_adjacency ,
318325 ),
319- tasks = create_tasks (
326+ tasks = await create_tasks (
320327 user = user , project = created_project , state = StateType .PUBLISHED
321328 ),
322329 )
@@ -335,9 +342,9 @@ async def published_project(
335342async def running_project (
336343 registered_user : Callable [..., dict [str , Any ]],
337344 project : Callable [..., Awaitable [ProjectAtDB ]],
338- pipeline : Callable [..., CompPipelineAtDB ],
339- create_tasks : Callable [..., list [CompTaskAtDB ]],
340- create_comp_run : Callable [..., CompRunsAtDB ],
345+ create_pipeline : Callable [..., Awaitable [ CompPipelineAtDB ] ],
346+ create_tasks : Callable [..., Awaitable [ list [CompTaskAtDB ] ]],
347+ create_comp_run : Callable [..., Awaitable [ CompRunsAtDB ] ],
341348 fake_workbench_without_outputs : dict [str , Any ],
342349 fake_workbench_adjacency : dict [str , Any ],
343350) -> RunningProject :
@@ -346,18 +353,18 @@ async def running_project(
346353 now_time = arrow .utcnow ().datetime
347354 return RunningProject (
348355 project = created_project ,
349- pipeline = pipeline (
356+ pipeline = await create_pipeline (
350357 project_id = f"{ created_project .uuid } " ,
351358 dag_adjacency_list = fake_workbench_adjacency ,
352359 ),
353- tasks = create_tasks (
360+ tasks = await create_tasks (
354361 user = user ,
355362 project = created_project ,
356363 state = StateType .RUNNING ,
357364 progress = 0.0 ,
358365 start = now_time ,
359366 ),
360- runs = create_comp_run (
367+ runs = await create_comp_run (
361368 user = user ,
362369 project = created_project ,
363370 started = now_time ,
@@ -371,9 +378,9 @@ async def running_project(
371378async def running_project_mark_for_cancellation (
372379 registered_user : Callable [..., dict [str , Any ]],
373380 project : Callable [..., Awaitable [ProjectAtDB ]],
374- pipeline : Callable [..., CompPipelineAtDB ],
375- create_tasks : Callable [..., list [CompTaskAtDB ]],
376- create_comp_run : Callable [..., CompRunsAtDB ],
381+ create_pipeline : Callable [..., Awaitable [ CompPipelineAtDB ] ],
382+ create_tasks : Callable [..., Awaitable [ list [CompTaskAtDB ] ]],
383+ create_comp_run : Callable [..., Awaitable [ CompRunsAtDB ] ],
377384 fake_workbench_without_outputs : dict [str , Any ],
378385 fake_workbench_adjacency : dict [str , Any ],
379386) -> RunningProject :
@@ -382,18 +389,18 @@ async def running_project_mark_for_cancellation(
382389 now_time = arrow .utcnow ().datetime
383390 return RunningProject (
384391 project = created_project ,
385- pipeline = pipeline (
392+ pipeline = await create_pipeline (
386393 project_id = f"{ created_project .uuid } " ,
387394 dag_adjacency_list = fake_workbench_adjacency ,
388395 ),
389- tasks = create_tasks (
396+ tasks = await create_tasks (
390397 user = user ,
391398 project = created_project ,
392399 state = StateType .RUNNING ,
393400 progress = 0.0 ,
394401 start = now_time ,
395402 ),
396- runs = create_comp_run (
403+ runs = await create_comp_run (
397404 user = user ,
398405 project = created_project ,
399406 result = StateType .RUNNING ,
0 commit comments