1- import asyncio
21import contextlib
32import logging
43from collections .abc import AsyncIterator , Callable
@@ -132,7 +131,7 @@ async def _start_tasks(
132131 RunningState .PENDING ,
133132 )
134133 # each task is started independently
135- results : list [list [PublishedComputationTask ]] = await asyncio . gather (
134+ results : list [list [PublishedComputationTask ]] = await limited_gather (
136135 * (
137136 client .send_computation_tasks (
138137 user_id = user_id ,
@@ -147,17 +146,21 @@ async def _start_tasks(
147146 )
148147 for node_id , task in scheduled_tasks .items ()
149148 ),
149+ log = _logger ,
150+ limit = MAX_CONCURRENT_PIPELINE_SCHEDULING ,
150151 )
151152
152153 # update the database so we do have the correct job_ids there
153- await asyncio . gather (
154+ await limited_gather (
154155 * (
155156 comp_tasks_repo .update_project_task_job_id (
156157 project_id , task .node_id , comp_run .run_id , task .job_id
157158 )
158159 for task_sents in results
159160 for task in task_sents
160- )
161+ ),
162+ log = _logger ,
163+ limit = MAX_CONCURRENT_PIPELINE_SCHEDULING ,
161164 )
162165
163166 async def _get_tasks_status (
@@ -279,23 +282,24 @@ async def _stop_tasks(
279282 run_id = comp_run .run_id ,
280283 run_metadata = comp_run .metadata ,
281284 ) as client :
282- await asyncio . gather (
283- * [
285+ await limited_gather (
286+ * (
284287 client .abort_computation_task (t .job_id )
285288 for t in tasks
286289 if t .job_id
287- ]
290+ ),
291+ log = _logger ,
292+ limit = MAX_CONCURRENT_PIPELINE_SCHEDULING ,
288293 )
289294 # tasks that have no-worker must be unpublished as these are blocking forever
290- tasks_with_no_worker = [
291- t for t in tasks if t .state is RunningState .WAITING_FOR_RESOURCES
292- ]
293- await asyncio .gather (
294- * [
295+ await limited_gather (
296+ * (
295297 client .release_task_result (t .job_id )
296- for t in tasks_with_no_worker
297- if t .job_id
298- ]
298+ for t in tasks
299+ if t .state is RunningState .WAITING_FOR_RESOURCES and t .job_id
300+ ),
301+ log = _logger ,
302+ limit = MAX_CONCURRENT_PIPELINE_SCHEDULING ,
299303 )
300304
301305 async def _process_completed_tasks (
@@ -313,12 +317,14 @@ async def _process_completed_tasks(
313317 run_id = comp_run .run_id ,
314318 run_metadata = comp_run .metadata ,
315319 ) as client :
316- tasks_results = await asyncio . gather (
317- * [
320+ tasks_results = await limited_gather (
321+ * (
318322 client .get_task_result (t .current .job_id or "undefined" )
319323 for t in tasks
320- ],
321- return_exceptions = True ,
324+ ),
325+ reraise = True ,
326+ log = _logger ,
327+ limit = MAX_CONCURRENT_PIPELINE_SCHEDULING ,
322328 )
323329 async for future in limited_as_completed (
324330 (
0 commit comments