diff --git a/infrahub_sdk/batch.py b/infrahub_sdk/batch.py
index fa219585..43304a10 100644
--- a/infrahub_sdk/batch.py
+++ b/infrahub_sdk/batch.py
@@ -1,9 +1,10 @@
 import asyncio
 from collections.abc import AsyncGenerator, Awaitable
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Generator, Optional
 
-from .node import InfrahubNode
+from .node import InfrahubNode, InfrahubNodeSync
 
 
 @dataclass
@@ -14,13 +15,32 @@ class BatchTask:
     node: Optional[Any] = None
 
+@dataclass
+class BatchTaskSync:
+    task: Callable[..., Any]
+    args: tuple[Any, ...]
+    kwargs: dict[str, Any]
+    node: Optional[InfrahubNodeSync] = None
+
+    def execute(self, return_exceptions: bool = False) -> tuple[Optional[InfrahubNodeSync], Any]:
+        """Executes the stored task."""
+        result = None
+        try:
+            result = self.task(*self.args, **self.kwargs)
+        except Exception as exc:  # pylint: disable=broad-exception-caught
+            if return_exceptions:
+                return self.node, exc
+            raise exc
+
+        return self.node, result
+
 
 async def execute_batch_task_in_pool(
     task: BatchTask, semaphore: asyncio.Semaphore, return_exceptions: bool = False
 ) -> tuple[Optional[InfrahubNode], Any]:
     async with semaphore:
         try:
             result = await task.task(*task.args, **task.kwargs)
         except Exception as exc:  # pylint: disable=broad-exception-caught
             if return_exceptions:
                 return (task.node, exc)
             raise exc
@@ -64,3 +84,26 @@ async def execute(self) -> AsyncGenerator:
             if isinstance(result, Exception) and not self.return_exceptions:
                 raise result
             yield node, result
+
+
+class InfrahubBatchSync:
+    def __init__(self, max_concurrent_execution: int = 5, return_exceptions: bool = False):
+        self._tasks: list[BatchTaskSync] = []
+        self.max_concurrent_execution = max_concurrent_execution
+        self.return_exceptions = return_exceptions
+
+    @property
+    def num_tasks(self) -> int:
+        return len(self._tasks)
+
+    def add(self, *args: Any, task: Callable[..., Any], node: Optional[Any] = None, **kwargs: Any) -> None:
+        self._tasks.append(BatchTaskSync(task=task, node=node, args=args, kwargs=kwargs))
+
+    def execute(self) -> Generator[tuple[Optional[InfrahubNodeSync], Any], None, None]:
+        with ThreadPoolExecutor(max_workers=self.max_concurrent_execution) as executor:
+            futures = [executor.submit(task.execute, return_exceptions=self.return_exceptions) for task in self._tasks]
+            for future in futures:
+                node, result = future.result()
+                if isinstance(result, Exception) and not self.return_exceptions:
+                    raise result
+                yield node, result
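For reference, a minimal usage sketch of the new `InfrahubBatchSync` on its own. The `fetch_data` task and the URLs are hypothetical placeholders, not part of this change; only the `InfrahubBatchSync` API shown in the diff is assumed.

```python
from infrahub_sdk.batch import InfrahubBatchSync


def fetch_data(url: str) -> str:
    # Hypothetical task: stands in for any callable you want to run in the batch.
    return f"fetched {url}"


batch = InfrahubBatchSync(max_concurrent_execution=5, return_exceptions=True)
for url in ["https://a.example", "https://b.example"]:
    batch.add(task=fetch_data, url=url)

# execute() submits every task to a ThreadPoolExecutor and yields (node, result)
# tuples; with return_exceptions=True, a failing task yields its exception
# instead of raising it.
for node, result in batch.execute():
    print(node, result)
```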
diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py
index cdf7c08f..00bdd19c 100644
--- a/infrahub_sdk/client.py
+++ b/infrahub_sdk/client.py
@@ -23,7 +23,7 @@
 import ujson
 from typing_extensions import Self
 
-from .batch import InfrahubBatch
+from .batch import InfrahubBatch, InfrahubBatchSync
 from .branch import (
     BranchData,
     InfrahubBranchManager,
@@ -1454,9 +1454,6 @@ def delete(self, kind: Union[str, type[SchemaTypeSync]], id: str, branch: Optional[str] = None) -> None:
         node = InfrahubNodeSync(client=self, schema=schema, branch=branch, data={"id": id})
         node.delete()
 
-    def create_batch(self, return_exceptions: bool = False) -> InfrahubBatch:
-        raise NotImplementedError("This method hasn't been implemented in the sync client yet.")
-
     def clone(self) -> InfrahubClientSync:
         """Return a cloned version of the client using the same configuration"""
         return InfrahubClientSync(config=self.config)
@@ -1955,6 +1952,16 @@ def get(
         return results[0]
 
+    def create_batch(self, return_exceptions: bool = False) -> InfrahubBatchSync:
+        """Create a batch to execute multiple queries concurrently.
+
+        Executing the batch will be performed using a thread pool, meaning the execution order of the tasks cannot be
+        guaranteed. It is not recommended to use such a batch to manipulate objects that depend on each other.
+        """
+        return InfrahubBatchSync(
+            max_concurrent_execution=self.max_concurrent_execution, return_exceptions=return_exceptions
+        )
+
     def get_list_repositories(
         self, branches: Optional[dict[str, BranchData]] = None, kind: str = "CoreGenericRepository"
     ) -> dict[str, RepositoryData]:
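As a usage illustration of the new sync `create_batch`, here is a short sketch. It assumes `client` is an already-configured `InfrahubClientSync` instance; the node kind and data shape mirror the tests below and are only examples.

```python
# Assumption: `client` is an already-configured InfrahubClientSync instance.
batch = client.create_batch(return_exceptions=True)

for name in ["atl1", "ord1"]:
    site = client.create(kind="BuiltinLocation", data={"name": {"value": name}})
    batch.add(task=site.save, node=site)

# Tasks run in a thread pool, so do not rely on them executing in the order
# they were added; avoid batching objects that depend on each other.
for node, result in batch.execute():
    print(node, result)
```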
"SchemaType", - "CoreNodeSync": "CoreNode", - "Optional[CoreNodeSync]": "Optional[CoreNode]", - "Union[str, type[SchemaTypeSync]]": "Union[str, type[SchemaType]]", - "Union[InfrahubNodeSync, SchemaTypeSync]": "Union[InfrahubNode, SchemaType]", - "Union[InfrahubNodeSync, SchemaTypeSync, None]": "Union[InfrahubNode, SchemaType, None]", - "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]": "Union[list[InfrahubNode], list[SchemaType]]", - "InfrahubClientSync": "InfrahubClient", - "InfrahubNodeSync": "InfrahubNode", - "list[InfrahubNodeSync]": "list[InfrahubNode]", - "Optional[InfrahubNodeSync]": "Optional[InfrahubNode]", - "Optional[type[SchemaTypeSync]]": "Optional[type[SchemaType]]", - "Optional[Union[CoreNodeSync, SchemaTypeSync]]": "Optional[Union[CoreNode, SchemaType]]", - } + replacements = {v: k for k, v in return_annotation_map.items()} return replacements.get(annotation) or annotation return replace_annotation diff --git a/tests/unit/sdk/test_batch.py b/tests/unit/sdk/test_batch.py index 83beaa53..8b498170 100644 --- a/tests/unit/sdk/test_batch.py +++ b/tests/unit/sdk/test_batch.py @@ -11,48 +11,117 @@ from tests.unit.sdk.conftest import BothClients +client_types = ["standard", "sync"] + +@pytest.mark.parametrize("client_type", client_types) +async def test_batch_execution(clients: BothClients, client_type: str): + r: list[int] = [] + tasks_number = 10 + + if client_type == "standard": + + async def test_func() -> int: + return 1 + + batch = await clients.standard.create_batch() + for _ in range(tasks_number): + batch.add(task=test_func) + + assert batch.num_tasks == tasks_number + async for _, result in batch.execute(): + r.append(result) + else: + + def test_func() -> int: + return 1 + + batch = clients.sync.create_batch() + for _ in range(tasks_number): + batch.add(task=test_func) + + assert batch.num_tasks == tasks_number + for _, result in batch.execute(): + r.append(result) + + assert r == [1] * tasks_number + + +@pytest.mark.parametrize("client_type", client_types) async def test_batch_return_exception( - httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients + httpx_mock: HTTPXMock, + mock_query_mutation_location_create_failed, + mock_schema_query_01, + clients: BothClients, + client_type: str, ): # pylint: disable=unused-argument - batch = await clients.standard.create_batch(return_exceptions=True) - locations = ["JFK1", "JFK1"] - results = [] - for location_name in locations: - data = {"name": {"value": location_name, "is_protected": True}} - obj = await clients.standard.create(kind="BuiltinLocation", data=data) - batch.add(task=obj.save, node=obj) - results.append(obj) - - result_iter = batch.execute() - # Assert first node success - node, result = await result_iter.__anext__() - assert node == results[0] - assert not isinstance(result, Exception) - - # Assert second node failure - node, result = await result_iter.__anext__() - assert node == results[1] - assert isinstance(result, GraphQLError) - assert "An error occurred while executing the GraphQL Query" in str(result) + if client_type == "standard": + batch = await clients.standard.create_batch(return_exceptions=True) + locations = ["JFK1", "JFK1"] + results = [] + for location_name in locations: + data = {"name": {"value": location_name, "is_protected": True}} + obj = await clients.standard.create(kind="BuiltinLocation", data=data) + batch.add(task=obj.save, node=obj) + results.append(obj) + + result_iter = batch.execute() + # Assert first node success + node, 
diff --git a/tests/unit/sdk/test_batch.py b/tests/unit/sdk/test_batch.py
index 83beaa53..8b498170 100644
--- a/tests/unit/sdk/test_batch.py
+++ b/tests/unit/sdk/test_batch.py
@@ -11,48 +11,117 @@
 from tests.unit.sdk.conftest import BothClients
 
 
+client_types = ["standard", "sync"]
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_batch_execution(clients: BothClients, client_type: str):
+    r: list[int] = []
+    tasks_number = 10
+
+    if client_type == "standard":
+
+        async def test_func() -> int:
+            return 1
+
+        batch = await clients.standard.create_batch()
+        for _ in range(tasks_number):
+            batch.add(task=test_func)
+
+        assert batch.num_tasks == tasks_number
+        async for _, result in batch.execute():
+            r.append(result)
+    else:
+
+        def test_func() -> int:
+            return 1
+
+        batch = clients.sync.create_batch()
+        for _ in range(tasks_number):
+            batch.add(task=test_func)
+
+        assert batch.num_tasks == tasks_number
+        for _, result in batch.execute():
+            r.append(result)
+
+    assert r == [1] * tasks_number
+
+
+@pytest.mark.parametrize("client_type", client_types)
 async def test_batch_return_exception(
-    httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients
+    httpx_mock: HTTPXMock,
+    mock_query_mutation_location_create_failed,
+    mock_schema_query_01,
+    clients: BothClients,
+    client_type: str,
 ):  # pylint: disable=unused-argument
-    batch = await clients.standard.create_batch(return_exceptions=True)
-    locations = ["JFK1", "JFK1"]
-    results = []
-    for location_name in locations:
-        data = {"name": {"value": location_name, "is_protected": True}}
-        obj = await clients.standard.create(kind="BuiltinLocation", data=data)
-        batch.add(task=obj.save, node=obj)
-        results.append(obj)
-
-    result_iter = batch.execute()
-    # Assert first node success
-    node, result = await result_iter.__anext__()
-    assert node == results[0]
-    assert not isinstance(result, Exception)
-
-    # Assert second node failure
-    node, result = await result_iter.__anext__()
-    assert node == results[1]
-    assert isinstance(result, GraphQLError)
-    assert "An error occurred while executing the GraphQL Query" in str(result)
+    if client_type == "standard":
+        batch = await clients.standard.create_batch(return_exceptions=True)
+        locations = ["JFK1", "JFK1"]
+        results = []
+        for location_name in locations:
+            data = {"name": {"value": location_name, "is_protected": True}}
+            obj = await clients.standard.create(kind="BuiltinLocation", data=data)
+            batch.add(task=obj.save, node=obj)
+            results.append(obj)
+
+        result_iter = batch.execute()
+        # Assert first node success
+        node, result = await result_iter.__anext__()
+        assert node == results[0]
+        assert not isinstance(result, Exception)
+        # Assert second node failure
+        node, result = await result_iter.__anext__()
+        assert node == results[1]
+        assert isinstance(result, GraphQLError)
+        assert "An error occurred while executing the GraphQL Query" in str(result)
+    else:
+        batch = clients.sync.create_batch(return_exceptions=True)
+        locations = ["JFK1", "JFK1"]
+        results = []
+        for location_name in locations:
+            data = {"name": {"value": location_name, "is_protected": True}}
+            obj = clients.sync.create(kind="BuiltinLocation", data=data)
+            batch.add(task=obj.save, node=obj)
+            results.append(obj)
+        results = [r for _, r in batch.execute()]
+        # Must have one exception and one GraphQLError
+        assert len(results) == 2
+        assert any(isinstance(r, Exception) for r in results)
+        assert any(isinstance(r, GraphQLError) for r in results)
+
+
+@pytest.mark.parametrize("client_type", client_types)
 async def test_batch_exception(
-    httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients
+    httpx_mock: HTTPXMock,
+    mock_query_mutation_location_create_failed,
+    mock_schema_query_01,
+    clients: BothClients,
+    client_type: str,
 ):  # pylint: disable=unused-argument
-    batch = await clients.standard.create_batch(return_exceptions=False)
-    locations = ["JFK1", "JFK1"]
-    for location_name in locations:
-        data = {"name": {"value": location_name, "is_protected": True}}
-        obj = await clients.standard.create(kind="BuiltinLocation", data=data)
-        batch.add(task=obj.save, node=obj)
-
-    with pytest.raises(GraphQLError) as exc:
-        async for _, _ in batch.execute():
-            pass
-    assert "An error occurred while executing the GraphQL Query" in str(exc.value)
-
-
-async def test_batch_not_implemented_sync(clients: BothClients):
-    with pytest.raises(NotImplementedError):
-        clients.sync.create_batch()
+    if client_type == "standard":
+        batch = await clients.standard.create_batch(return_exceptions=False)
+        locations = ["JFK1", "JFK1"]
+        for location_name in locations:
+            data = {"name": {"value": location_name, "is_protected": True}}
+            obj = await clients.standard.create(kind="BuiltinLocation", data=data)
+            batch.add(task=obj.save, node=obj)
+
+        with pytest.raises(GraphQLError) as exc:
+            async for _, _ in batch.execute():
+                pass
+        assert "An error occurred while executing the GraphQL Query" in str(exc.value)
+    else:
+        batch = clients.sync.create_batch(return_exceptions=False)
+        locations = ["JFK1", "JFK1"]
+        for location_name in locations:
+            data = {"name": {"value": location_name, "is_protected": True}}
+            obj = clients.sync.create(kind="BuiltinLocation", data=data)
+            batch.add(task=obj.save, node=obj)
+
+        with pytest.raises(GraphQLError) as exc:
+            for _, _ in batch.execute():
+                pass
+        assert "An error occurred while executing the GraphQL Query" in str(exc.value)
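The looser assertions in the sync branch of `test_batch_return_exception` appear to follow from how the thread pool interacts with the mocked responses: results are yielded in the order tasks were added (the futures list is consumed in submission order), but the tasks run concurrently, so which task consumes the failing mocked response is not deterministic. A small sketch of that yield-order property, independent of Infrahub itself (the `slow`/`fast` helpers are hypothetical):

```python
import time

from infrahub_sdk.batch import InfrahubBatchSync


def slow() -> str:
    time.sleep(0.2)
    return "slow"


def fast() -> str:
    return "fast"


batch = InfrahubBatchSync(max_concurrent_execution=2)
batch.add(task=slow)
batch.add(task=fast)

# Even though `fast` finishes first, results are yielded in submission order,
# because execute() walks the futures list in the order the tasks were added.
assert [result for _, result in batch.execute()] == ["slow", "fast"]
```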