1212
1313@dataclass
1414class BatchTask :
15+ """Represents a single asynchronous task in a batch."""
1516 task : Callable [[Any ], Awaitable [Any ]]
1617 args : tuple [Any , ...]
1718 kwargs : dict [str , Any ]
@@ -20,13 +21,24 @@ class BatchTask:
2021
2122@dataclass
2223class BatchTaskSync :
24+ """Represents a single synchronous task in a batch."""
2325 task : Callable [..., Any ]
2426 args : tuple [Any , ...]
2527 kwargs : dict [str , Any ]
2628 node : InfrahubNodeSync | None = None
2729
2830 def execute (self , return_exceptions : bool = False ) -> tuple [InfrahubNodeSync | None , Any ]:
29- """Executes the stored task."""
31+ """Executes the stored synchronous task.
32+
33+ Args:
34+ return_exceptions: If True, exceptions are returned instead of raised.
35+
36+ Returns:
37+ A tuple containing the task's node (if any) and the result or exception.
38+
39+ Raises:
40+ Exception: If `return_exceptions` is False and the task raises an exception.
41+ """
3042 result = None
3143 try :
3244 result = self .task (* self .args , ** self .kwargs )
@@ -41,6 +53,16 @@ def execute(self, return_exceptions: bool = False) -> tuple[InfrahubNodeSync | N
4153async def execute_batch_task_in_pool (
4254 task : BatchTask , semaphore : asyncio .Semaphore , return_exceptions : bool = False
4355) -> tuple [InfrahubNode | None , Any ]:
56+ """Executes a BatchTask within a semaphore-controlled pool.
57+
58+ Args:
59+ task: The BatchTask to execute.
60+ semaphore: An asyncio.Semaphore to limit concurrent executions.
61+ return_exceptions: If True, exceptions are returned instead of raised.
62+
63+ Returns:
64+ A tuple containing the task's node (if any) and the result or exception.
65+ """
4466 async with semaphore :
4567 try :
4668 result = await task .task (* task .args , ** task .kwargs )
@@ -53,24 +75,51 @@ async def execute_batch_task_in_pool(
5375
5476
5577class InfrahubBatch :
78+ """Manages and executes a batch of asynchronous tasks concurrently."""
5679 def __init__ (
5780 self ,
5881 semaphore : asyncio .Semaphore | None = None ,
5982 max_concurrent_execution : int = 5 ,
6083 return_exceptions : bool = False ,
6184 ):
85+ """Initializes the InfrahubBatch.
86+
87+ Args:
88+ semaphore: An asyncio.Semaphore to limit concurrent executions.
89+ If None, a new one is created with `max_concurrent_execution`.
90+ max_concurrent_execution: The maximum number of tasks to run concurrently.
91+ Only used if `semaphore` is None.
92+ return_exceptions: If True, exceptions from tasks are returned instead of raised.
93+ """
6294 self ._tasks : list [BatchTask ] = []
6395 self .semaphore = semaphore or asyncio .Semaphore (value = max_concurrent_execution )
6496 self .return_exceptions = return_exceptions
6597
6698 @property
6799 def num_tasks (self ) -> int :
100+ """Returns the number of tasks currently in the batch."""
68101 return len (self ._tasks )
69102
70103 def add (self , * args : Any , task : Callable , node : Any | None = None , ** kwargs : Any ) -> None :
104+ """Adds a new task to the batch.
105+
106+ Args:
107+ task: The callable to be executed.
108+ node: An optional node associated with this task.
109+ *args: Positional arguments to pass to the task.
110+ **kwargs: Keyword arguments to pass to the task.
111+ """
71112 self ._tasks .append (BatchTask (task = task , node = node , args = args , kwargs = kwargs ))
72113
73- async def execute (self ) -> AsyncGenerator :
114+ async def execute (self ) -> AsyncGenerator [tuple [InfrahubNode | None , Any ], None , None ]:
115+ """Executes all tasks in the batch concurrently.
116+
117+ Yields:
118+ A tuple containing the task's node (if any) and the result or exception.
119+
120+ Raises:
121+ Exception: If `return_exceptions` is False and a task raises an exception.
122+ """
74123 tasks = []
75124
76125 for batch_task in self ._tasks :
@@ -90,19 +139,43 @@ async def execute(self) -> AsyncGenerator:
90139
91140
92141class InfrahubBatchSync :
142+ """Manages and executes a batch of synchronous tasks concurrently using a thread pool."""
93143 def __init__ (self , max_concurrent_execution : int = 5 , return_exceptions : bool = False ):
144+ """Initializes the InfrahubBatchSync.
145+
146+ Args:
147+ max_concurrent_execution: The maximum number of tasks to run concurrently in the thread pool.
148+ return_exceptions: If True, exceptions from tasks are returned instead of raised.
149+ """
94150 self ._tasks : list [BatchTaskSync ] = []
95151 self .max_concurrent_execution = max_concurrent_execution
96152 self .return_exceptions = return_exceptions
97153
98154 @property
99155 def num_tasks (self ) -> int :
156+ """Returns the number of tasks currently in the batch."""
100157 return len (self ._tasks )
101158
102159 def add (self , * args : Any , task : Callable [..., Any ], node : Any | None = None , ** kwargs : Any ) -> None :
160+ """Adds a new synchronous task to the batch.
161+
162+ Args:
163+ task: The callable to be executed.
164+ node: An optional node associated with this task.
165+ *args: Positional arguments to pass to the task.
166+ **kwargs: Keyword arguments to pass to the task.
167+ """
103168 self ._tasks .append (BatchTaskSync (task = task , node = node , args = args , kwargs = kwargs ))
104169
105170 def execute (self ) -> Generator [tuple [InfrahubNodeSync | None , Any ], None , None ]:
171+ """Executes all tasks in the batch concurrently using a ThreadPoolExecutor.
172+
173+ Yields:
174+ A tuple containing the task's node (if any) and the result or exception.
175+
176+ Raises:
177+ Exception: If `return_exceptions` is False and a task raises an exception.
178+ """
106179 with ThreadPoolExecutor (max_workers = self .max_concurrent_execution ) as executor :
107180 futures = [executor .submit (task .execute , return_exceptions = self .return_exceptions ) for task in self ._tasks ]
108181 for future in futures :
0 commit comments