11import asyncio
22from collections .abc import AsyncGenerator , Awaitable
3+ from concurrent .futures import ThreadPoolExecutor
34from dataclasses import dataclass
4- from typing import Any , Callable , Optional
5+ from typing import Any , Callable , Generator , Optional
56
6- from .node import InfrahubNode
7+ from .node import InfrahubNode , InfrahubNodeSync
78
89
910@dataclass
@@ -14,13 +15,32 @@ class BatchTask:
1415 node : Optional [Any ] = None
1516
1617
18+ @dataclass
19+ class BatchTaskSync :
20+ task : Callable [..., Any ]
21+ args : tuple [Any , ...]
22+ kwargs : dict [str , Any ]
23+ node : Optional [InfrahubNodeSync ] = None
24+
25+ def execute (self , return_exceptions : bool = False ) -> tuple [Optional [InfrahubNodeSync ], Any ]:
26+ """Executes the stored task."""
27+ result = None
28+ try :
29+ result = self .task (* self .args , ** self .kwargs )
30+ except Exception as exc :
31+ if return_exceptions :
32+ return (self .node , exc )
33+ raise exc
34+
35+ return self .node , result
36+
37+
1738async def execute_batch_task_in_pool (
1839 task : BatchTask , semaphore : asyncio .Semaphore , return_exceptions : bool = False
1940) -> tuple [Optional [InfrahubNode ], Any ]:
2041 async with semaphore :
2142 try :
2243 result = await task .task (* task .args , ** task .kwargs )
23-
2444 except Exception as exc : # pylint: disable=broad-exception-caught
2545 if return_exceptions :
2646 return (task .node , exc )
@@ -64,3 +84,26 @@ async def execute(self) -> AsyncGenerator:
6484 if isinstance (result , Exception ) and not self .return_exceptions :
6585 raise result
6686 yield node , result
87+
88+
89+ class InfrahubBatchSync :
90+ def __init__ (self , max_concurrent_execution : int = 5 , return_exceptions : bool = False ):
91+ self ._tasks : list [BatchTaskSync ] = []
92+ self .max_concurrent_execution = max_concurrent_execution
93+ self .return_exceptions = return_exceptions
94+
95+ @property
96+ def num_tasks (self ) -> int :
97+ return len (self ._tasks )
98+
99+ def add (self , * args : Any , task : Callable [..., Any ], node : Optional [Any ] = None , ** kwargs : Any ) -> None :
100+ self ._tasks .append (BatchTaskSync (task = task , node = node , args = args , kwargs = kwargs ))
101+
102+ def execute (self ) -> Generator [tuple [Optional [InfrahubNodeSync ], Any ], None , None ]:
103+ with ThreadPoolExecutor (max_workers = self .max_concurrent_execution ) as executor :
104+ futures = [executor .submit (task .execute ) for task in self ._tasks ]
105+ for future in futures :
106+ node , result = future .result ()
107+ if isinstance (result , Exception ) and not self .return_exceptions :
108+ raise result
109+ yield node , result
0 commit comments