49 changes: 46 additions & 3 deletions infrahub_sdk/batch.py
@@ -1,9 +1,10 @@
import asyncio
from collections.abc import AsyncGenerator, Awaitable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, Generator, Optional

from .node import InfrahubNode
from .node import InfrahubNode, InfrahubNodeSync


@dataclass
@@ -14,13 +15,32 @@
node: Optional[Any] = None


@dataclass
class BatchTaskSync:
task: Callable[..., Any]
args: tuple[Any, ...]
kwargs: dict[str, Any]
node: Optional[InfrahubNodeSync] = None

def execute(self, return_exceptions: bool = False) -> tuple[Optional[InfrahubNodeSync], Any]:
"""Executes the stored task."""
result = None
try:
result = self.task(*self.args, **self.kwargs)
except Exception as exc: # pylint: disable=broad-exception-caught
if return_exceptions:
return self.node, exc
raise exc

return self.node, result


async def execute_batch_task_in_pool(
task: BatchTask, semaphore: asyncio.Semaphore, return_exceptions: bool = False
) -> tuple[Optional[InfrahubNode], Any]:
async with semaphore:
try:
result = await task.task(*task.args, **task.kwargs)

except Exception as exc: # pylint: disable=broad-exception-caught
if return_exceptions:
return (task.node, exc)
@@ -64,3 +84,26 @@
if isinstance(result, Exception) and not self.return_exceptions:
raise result
yield node, result


class InfrahubBatchSync:
def __init__(self, max_concurrent_execution: int = 5, return_exceptions: bool = False):
self._tasks: list[BatchTaskSync] = []
self.max_concurrent_execution = max_concurrent_execution
self.return_exceptions = return_exceptions

@property
def num_tasks(self) -> int:
return len(self._tasks)

def add(self, *args: Any, task: Callable[..., Any], node: Optional[Any] = None, **kwargs: Any) -> None:
self._tasks.append(BatchTaskSync(task=task, node=node, args=args, kwargs=kwargs))

def execute(self) -> Generator[tuple[Optional[InfrahubNodeSync], Any], None, None]:
with ThreadPoolExecutor(max_workers=self.max_concurrent_execution) as executor:
futures = [executor.submit(task.execute, return_exceptions=self.return_exceptions) for task in self._tasks]
for future in futures:
node, result = future.result()
if isinstance(result, Exception) and not self.return_exceptions:
raise result

Codecov / codecov/patch warning: added line infrahub_sdk/batch.py#L108 was not covered by tests.
yield node, result
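
For orientation, here is a minimal usage sketch of the InfrahubBatchSync class added above. The `double` helper and its arguments are hypothetical; the batch API (`add`, `num_tasks`, `execute`) follows the code in this diff.

```python
from infrahub_sdk.batch import InfrahubBatchSync


def double(value: int) -> int:
    # Stand-in for any callable to run in the thread pool (hypothetical helper).
    return value * 2


batch = InfrahubBatchSync(max_concurrent_execution=5, return_exceptions=True)
for i in range(10):
    # Extra keyword arguments are forwarded to the task when it executes.
    batch.add(task=double, value=i)

assert batch.num_tasks == 10

for node, result in batch.execute():
    # node is None here because no InfrahubNodeSync was attached via node=.
    print(result)
```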
15 changes: 11 additions & 4 deletions infrahub_sdk/client.py
Expand Up @@ -23,7 +23,7 @@
import ujson
from typing_extensions import Self

from .batch import InfrahubBatch
from .batch import InfrahubBatch, InfrahubBatchSync
from .branch import (
BranchData,
InfrahubBranchManager,
@@ -1454,9 +1454,6 @@ def delete(self, kind: Union[str, type[SchemaTypeSync]], id: str, branch: Option
node = InfrahubNodeSync(client=self, schema=schema, branch=branch, data={"id": id})
node.delete()

def create_batch(self, return_exceptions: bool = False) -> InfrahubBatch:
raise NotImplementedError("This method hasn't been implemented in the sync client yet.")

def clone(self) -> InfrahubClientSync:
"""Return a cloned version of the client using the same configuration"""
return InfrahubClientSync(config=self.config)
@@ -1955,6 +1952,16 @@ def get(

return results[0]

def create_batch(self, return_exceptions: bool = False) -> InfrahubBatchSync:
"""Create a batch to execute multiple queries concurrently.

Executing the batch uses a thread pool, so the execution order cannot be guaranteed. It is not recommended to use such
a batch to manipulate objects that depend on each other.
"""
return InfrahubBatchSync(
max_concurrent_execution=self.max_concurrent_execution, return_exceptions=return_exceptions
)

def get_list_repositories(
self, branches: Optional[dict[str, BranchData]] = None, kind: str = "CoreGenericRepository"
) -> dict[str, RepositoryData]:
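
The new `create_batch` on the sync client mirrors the async API but drives a thread pool, as described in its docstring above. Below is a hedged sketch of the expected call pattern, assuming `InfrahubClientSync` is importable from the package root and accepts an `address` argument; the address and node kind are placeholders.

```python
from infrahub_sdk import InfrahubClientSync

client = InfrahubClientSync(address="http://localhost:8000")
batch = client.create_batch(return_exceptions=True)

for name in ["JFK1", "ORD1"]:
    obj = client.create(kind="BuiltinLocation", data={"name": {"value": name}})
    batch.add(task=obj.save, node=obj)

for node, result in batch.execute():
    # With return_exceptions=True, failures are yielded instead of raised.
    if isinstance(result, Exception):
        print(f"saving {node} failed: {result}")
```

Because execution happens in a thread pool, results are yielded in submission order while the tasks themselves may complete in any order.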
60 changes: 25 additions & 35 deletions tests/unit/sdk/conftest.py
@@ -55,27 +55,32 @@ async def echo_clients(clients: BothClients) -> AsyncGenerator[BothClients, None


@pytest.fixture
def replace_async_return_annotation():
def return_annotation_map() -> dict[str, str]:
return {
"type[SchemaType]": "type[SchemaTypeSync]",
"SchemaType": "SchemaTypeSync",
"CoreNode": "CoreNodeSync",
"Optional[CoreNode]": "Optional[CoreNodeSync]",
"Union[str, type[SchemaType]]": "Union[str, type[SchemaTypeSync]]",
"Union[InfrahubNode, SchemaType]": "Union[InfrahubNodeSync, SchemaTypeSync]",
"Union[InfrahubNode, SchemaType, None]": "Union[InfrahubNodeSync, SchemaTypeSync, None]",
"Union[list[InfrahubNode], list[SchemaType]]": "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]",
"InfrahubClient": "InfrahubClientSync",
"InfrahubNode": "InfrahubNodeSync",
"list[InfrahubNode]": "list[InfrahubNodeSync]",
"Optional[InfrahubNode]": "Optional[InfrahubNodeSync]",
"Optional[type[SchemaType]]": "Optional[type[SchemaTypeSync]]",
"Optional[Union[CoreNode, SchemaType]]": "Optional[Union[CoreNodeSync, SchemaTypeSync]]",
"InfrahubBatch": "InfrahubBatchSync",
}


@pytest.fixture
def replace_async_return_annotation(return_annotation_map: dict[str, str]):
"""Allows for comparison between sync and async return annotations."""

def replace_annotation(annotation: str) -> str:
replacements = {
"type[SchemaType]": "type[SchemaTypeSync]",
"SchemaType": "SchemaTypeSync",
"CoreNode": "CoreNodeSync",
"Optional[CoreNode]": "Optional[CoreNodeSync]",
"Union[str, type[SchemaType]]": "Union[str, type[SchemaTypeSync]]",
"Union[InfrahubNode, SchemaType]": "Union[InfrahubNodeSync, SchemaTypeSync]",
"Union[InfrahubNode, SchemaType, None]": "Union[InfrahubNodeSync, SchemaTypeSync, None]",
"Union[list[InfrahubNode], list[SchemaType]]": "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]",
"InfrahubClient": "InfrahubClientSync",
"InfrahubNode": "InfrahubNodeSync",
"list[InfrahubNode]": "list[InfrahubNodeSync]",
"Optional[InfrahubNode]": "Optional[InfrahubNodeSync]",
"Optional[type[SchemaType]]": "Optional[type[SchemaTypeSync]]",
"Optional[Union[CoreNode, SchemaType]]": "Optional[Union[CoreNodeSync, SchemaTypeSync]]",
}
return replacements.get(annotation) or annotation
return return_annotation_map.get(annotation) or annotation

return replace_annotation

@@ -95,26 +100,11 @@ def replace_annotations(parameters: Mapping[str, Parameter]) -> tuple[str, str]:


@pytest.fixture
def replace_sync_return_annotation() -> str:
def replace_sync_return_annotation(return_annotation_map: dict[str, str]) -> str:
"""Allows for comparison between sync and async return annotations."""

def replace_annotation(annotation: str) -> str:
replacements = {
"type[SchemaTypeSync]": "type[SchemaType]",
"SchemaTypeSync": "SchemaType",
"CoreNodeSync": "CoreNode",
"Optional[CoreNodeSync]": "Optional[CoreNode]",
"Union[str, type[SchemaTypeSync]]": "Union[str, type[SchemaType]]",
"Union[InfrahubNodeSync, SchemaTypeSync]": "Union[InfrahubNode, SchemaType]",
"Union[InfrahubNodeSync, SchemaTypeSync, None]": "Union[InfrahubNode, SchemaType, None]",
"Union[list[InfrahubNodeSync], list[SchemaTypeSync]]": "Union[list[InfrahubNode], list[SchemaType]]",
"InfrahubClientSync": "InfrahubClient",
"InfrahubNodeSync": "InfrahubNode",
"list[InfrahubNodeSync]": "list[InfrahubNode]",
"Optional[InfrahubNodeSync]": "Optional[InfrahubNode]",
"Optional[type[SchemaTypeSync]]": "Optional[type[SchemaType]]",
"Optional[Union[CoreNodeSync, SchemaTypeSync]]": "Optional[Union[CoreNode, SchemaType]]",
}
replacements = {v: k for k, v in return_annotation_map.items()}
return replacements.get(annotation) or annotation

return replace_annotation
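
The conftest refactor above defines the async-to-sync annotation mapping once in a `return_annotation_map` fixture and derives the reverse direction by inverting it. A standalone illustration of that inversion follows, with the map truncated to a few entries.

```python
# Truncated copy of the shared map; the sync fixture inverts it at runtime.
return_annotation_map = {
    "InfrahubClient": "InfrahubClientSync",
    "InfrahubNode": "InfrahubNodeSync",
    "InfrahubBatch": "InfrahubBatchSync",
}

sync_to_async = {v: k for k, v in return_annotation_map.items()}
assert sync_to_async["InfrahubBatchSync"] == "InfrahubBatch"
```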
145 changes: 107 additions & 38 deletions tests/unit/sdk/test_batch.py
@@ -11,48 +11,117 @@

from tests.unit.sdk.conftest import BothClients

client_types = ["standard", "sync"]


@pytest.mark.parametrize("client_type", client_types)
async def test_batch_execution(clients: BothClients, client_type: str):
r: list[int] = []
tasks_number = 10

if client_type == "standard":

async def test_func() -> int:
return 1

batch = await clients.standard.create_batch()
for _ in range(tasks_number):
batch.add(task=test_func)

assert batch.num_tasks == tasks_number
async for _, result in batch.execute():
r.append(result)
else:

def test_func() -> int:
return 1

batch = clients.sync.create_batch()
for _ in range(tasks_number):
batch.add(task=test_func)

assert batch.num_tasks == tasks_number
for _, result in batch.execute():
r.append(result)

assert r == [1] * tasks_number


@pytest.mark.parametrize("client_type", client_types)
async def test_batch_return_exception(
httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients
httpx_mock: HTTPXMock,
mock_query_mutation_location_create_failed,
mock_schema_query_01,
clients: BothClients,
client_type: str,
): # pylint: disable=unused-argument
batch = await clients.standard.create_batch(return_exceptions=True)
locations = ["JFK1", "JFK1"]
results = []
for location_name in locations:
data = {"name": {"value": location_name, "is_protected": True}}
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
batch.add(task=obj.save, node=obj)
results.append(obj)

result_iter = batch.execute()
# Assert first node success
node, result = await result_iter.__anext__()
assert node == results[0]
assert not isinstance(result, Exception)

# Assert second node failure
node, result = await result_iter.__anext__()
assert node == results[1]
assert isinstance(result, GraphQLError)
assert "An error occurred while executing the GraphQL Query" in str(result)
if client_type == "standard":
batch = await clients.standard.create_batch(return_exceptions=True)
locations = ["JFK1", "JFK1"]
results = []
for location_name in locations:
data = {"name": {"value": location_name, "is_protected": True}}
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
batch.add(task=obj.save, node=obj)
results.append(obj)

result_iter = batch.execute()
# Assert first node success
node, result = await result_iter.__anext__()
assert node == results[0]
assert not isinstance(result, Exception)

# Assert second node failure
node, result = await result_iter.__anext__()
assert node == results[1]
assert isinstance(result, GraphQLError)
assert "An error occurred while executing the GraphQL Query" in str(result)
else:
batch = clients.sync.create_batch(return_exceptions=True)
locations = ["JFK1", "JFK1"]
results = []
for location_name in locations:
data = {"name": {"value": location_name, "is_protected": True}}
obj = clients.sync.create(kind="BuiltinLocation", data=data)
batch.add(task=obj.save, node=obj)
results.append(obj)

results = [r for _, r in batch.execute()]
# Must include an Exception and a GraphQLError
assert len(results) == 2
assert any(isinstance(r, Exception) for r in results)
assert any(isinstance(r, GraphQLError) for r in results)


@pytest.mark.parametrize("client_type", client_types)
async def test_batch_exception(
httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients
httpx_mock: HTTPXMock,
mock_query_mutation_location_create_failed,
mock_schema_query_01,
clients: BothClients,
client_type: str,
): # pylint: disable=unused-argument
batch = await clients.standard.create_batch(return_exceptions=False)
locations = ["JFK1", "JFK1"]
for location_name in locations:
data = {"name": {"value": location_name, "is_protected": True}}
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
batch.add(task=obj.save, node=obj)

with pytest.raises(GraphQLError) as exc:
async for _, _ in batch.execute():
pass
assert "An error occurred while executing the GraphQL Query" in str(exc.value)


async def test_batch_not_implemented_sync(clients: BothClients):
with pytest.raises(NotImplementedError):
clients.sync.create_batch()
if client_type == "standard":
batch = await clients.standard.create_batch(return_exceptions=False)
locations = ["JFK1", "JFK1"]
for location_name in locations:
data = {"name": {"value": location_name, "is_protected": True}}
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
batch.add(task=obj.save, node=obj)

with pytest.raises(GraphQLError) as exc:
async for _, _ in batch.execute():
pass
assert "An error occurred while executing the GraphQL Query" in str(exc.value)
else:
batch = clients.sync.create_batch(return_exceptions=False)
locations = ["JFK1", "JFK1"]
for location_name in locations:
data = {"name": {"value": location_name, "is_protected": True}}
obj = clients.sync.create(kind="BuiltinLocation", data=data)
batch.add(task=obj.save, node=obj)

with pytest.raises(GraphQLError) as exc:
for _, _ in batch.execute():
pass
assert "An error occurred while executing the GraphQL Query" in str(exc.value)