Skip to content

Commit d49fc62

Browse files
committed
Add batch feature to sync client
1 parent 4df0d61 commit d49fc62

File tree

2 files changed

+52
-7
lines changed

2 files changed

+52
-7
lines changed

infrahub_sdk/batch.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, Awaitable
3+
from concurrent.futures import ThreadPoolExecutor
34
from dataclasses import dataclass
4-
from typing import Any, Callable, Optional
5+
from typing import Any, Callable, Generator, Optional
56

6-
from .node import InfrahubNode
7+
from .node import InfrahubNode, InfrahubNodeSync
78

89

910
@dataclass
@@ -14,13 +15,32 @@ class BatchTask:
1415
node: Optional[Any] = None
1516

1617

18+
@dataclass
19+
class BatchTaskSync:
20+
task: Callable[..., Any]
21+
args: tuple[Any, ...]
22+
kwargs: dict[str, Any]
23+
node: Optional[InfrahubNodeSync] = None
24+
25+
def execute(self, return_exceptions: bool = False) -> tuple[Optional[InfrahubNodeSync], Any]:
26+
"""Executes the stored task."""
27+
result = None
28+
try:
29+
result = self.task(*self.args, **self.kwargs)
30+
except Exception as exc:
31+
if return_exceptions:
32+
return (self.node, exc)
33+
raise exc
34+
35+
return self.node, result
36+
37+
1738
async def execute_batch_task_in_pool(
1839
task: BatchTask, semaphore: asyncio.Semaphore, return_exceptions: bool = False
1940
) -> tuple[Optional[InfrahubNode], Any]:
2041
async with semaphore:
2142
try:
2243
result = await task.task(*task.args, **task.kwargs)
23-
2444
except Exception as exc: # pylint: disable=broad-exception-caught
2545
if return_exceptions:
2646
return (task.node, exc)
@@ -64,3 +84,26 @@ async def execute(self) -> AsyncGenerator:
6484
if isinstance(result, Exception) and not self.return_exceptions:
6585
raise result
6686
yield node, result
87+
88+
89+
class InfrahubBatchSync:
90+
def __init__(self, max_concurrent_execution: int = 5, return_exceptions: bool = False):
91+
self._tasks: list[BatchTaskSync] = []
92+
self.max_concurrent_execution = max_concurrent_execution
93+
self.return_exceptions = return_exceptions
94+
95+
@property
96+
def num_tasks(self) -> int:
97+
return len(self._tasks)
98+
99+
def add(self, *args: Any, task: Callable[..., Any], node: Optional[Any] = None, **kwargs: Any) -> None:
100+
self._tasks.append(BatchTaskSync(task=task, node=node, args=args, kwargs=kwargs))
101+
102+
def execute(self) -> Generator[tuple[Optional[InfrahubNodeSync], Any], None, None]:
103+
with ThreadPoolExecutor(max_workers=self.max_concurrent_execution) as executor:
104+
futures = [executor.submit(task.execute) for task in self._tasks]
105+
for future in futures:
106+
node, result = future.result()
107+
if isinstance(result, Exception) and not self.return_exceptions:
108+
raise result
109+
yield node, result

infrahub_sdk/client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import ujson
2424
from typing_extensions import Self
2525

26-
from .batch import InfrahubBatch
26+
from .batch import InfrahubBatch, InfrahubBatchSync
2727
from .branch import (
2828
BranchData,
2929
InfrahubBranchManager,
@@ -1454,9 +1454,6 @@ def delete(self, kind: Union[str, type[SchemaTypeSync]], id: str, branch: Option
14541454
node = InfrahubNodeSync(client=self, schema=schema, branch=branch, data={"id": id})
14551455
node.delete()
14561456

1457-
def create_batch(self, return_exceptions: bool = False) -> InfrahubBatch:
1458-
raise NotImplementedError("This method hasn't been implemented in the sync client yet.")
1459-
14601457
def clone(self) -> InfrahubClientSync:
14611458
"""Return a cloned version of the client using the same configuration"""
14621459
return InfrahubClientSync(config=self.config)
@@ -1955,6 +1952,11 @@ def get(
19551952

19561953
return results[0]
19571954

1955+
def create_batch(self, return_exceptions: bool = False) -> InfrahubBatchSync:
1956+
return InfrahubBatchSync(
1957+
max_concurrent_execution=self.max_concurrent_execution, return_exceptions=return_exceptions
1958+
)
1959+
19581960
def get_list_repositories(
19591961
self, branches: Optional[dict[str, BranchData]] = None, kind: str = "CoreGenericRepository"
19601962
) -> dict[str, RepositoryData]:

0 commit comments

Comments
 (0)