Skip to content

Commit a156879

Browse files
authored
IHS-74 Add batch feature to sync client (#169)
1 parent 4df0d61 commit a156879

File tree

4 files changed

+189
-80
lines changed

4 files changed

+189
-80
lines changed

infrahub_sdk/batch.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, Awaitable
3+
from concurrent.futures import ThreadPoolExecutor
34
from dataclasses import dataclass
4-
from typing import Any, Callable, Optional
5+
from typing import Any, Callable, Generator, Optional
56

6-
from .node import InfrahubNode
7+
from .node import InfrahubNode, InfrahubNodeSync
78

89

910
@dataclass
@@ -14,13 +15,32 @@ class BatchTask:
1415
node: Optional[Any] = None
1516

1617

18+
@dataclass
19+
class BatchTaskSync:
20+
task: Callable[..., Any]
21+
args: tuple[Any, ...]
22+
kwargs: dict[str, Any]
23+
node: Optional[InfrahubNodeSync] = None
24+
25+
def execute(self, return_exceptions: bool = False) -> tuple[Optional[InfrahubNodeSync], Any]:
26+
"""Executes the stored task."""
27+
result = None
28+
try:
29+
result = self.task(*self.args, **self.kwargs)
30+
except Exception as exc: # pylint: disable=broad-exception-caught
31+
if return_exceptions:
32+
return self.node, exc
33+
raise exc
34+
35+
return self.node, result
36+
37+
1738
async def execute_batch_task_in_pool(
1839
task: BatchTask, semaphore: asyncio.Semaphore, return_exceptions: bool = False
1940
) -> tuple[Optional[InfrahubNode], Any]:
2041
async with semaphore:
2142
try:
2243
result = await task.task(*task.args, **task.kwargs)
23-
2444
except Exception as exc: # pylint: disable=broad-exception-caught
2545
if return_exceptions:
2646
return (task.node, exc)
@@ -64,3 +84,26 @@ async def execute(self) -> AsyncGenerator:
6484
if isinstance(result, Exception) and not self.return_exceptions:
6585
raise result
6686
yield node, result
87+
88+
89+
class InfrahubBatchSync:
90+
def __init__(self, max_concurrent_execution: int = 5, return_exceptions: bool = False):
91+
self._tasks: list[BatchTaskSync] = []
92+
self.max_concurrent_execution = max_concurrent_execution
93+
self.return_exceptions = return_exceptions
94+
95+
@property
96+
def num_tasks(self) -> int:
97+
return len(self._tasks)
98+
99+
def add(self, *args: Any, task: Callable[..., Any], node: Optional[Any] = None, **kwargs: Any) -> None:
100+
self._tasks.append(BatchTaskSync(task=task, node=node, args=args, kwargs=kwargs))
101+
102+
def execute(self) -> Generator[tuple[Optional[InfrahubNodeSync], Any], None, None]:
103+
with ThreadPoolExecutor(max_workers=self.max_concurrent_execution) as executor:
104+
futures = [executor.submit(task.execute, return_exceptions=self.return_exceptions) for task in self._tasks]
105+
for future in futures:
106+
node, result = future.result()
107+
if isinstance(result, Exception) and not self.return_exceptions:
108+
raise result
109+
yield node, result

infrahub_sdk/client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import ujson
2424
from typing_extensions import Self
2525

26-
from .batch import InfrahubBatch
26+
from .batch import InfrahubBatch, InfrahubBatchSync
2727
from .branch import (
2828
BranchData,
2929
InfrahubBranchManager,
@@ -1454,9 +1454,6 @@ def delete(self, kind: Union[str, type[SchemaTypeSync]], id: str, branch: Option
14541454
node = InfrahubNodeSync(client=self, schema=schema, branch=branch, data={"id": id})
14551455
node.delete()
14561456

1457-
def create_batch(self, return_exceptions: bool = False) -> InfrahubBatch:
1458-
raise NotImplementedError("This method hasn't been implemented in the sync client yet.")
1459-
14601457
def clone(self) -> InfrahubClientSync:
14611458
"""Return a cloned version of the client using the same configuration"""
14621459
return InfrahubClientSync(config=self.config)
@@ -1955,6 +1952,16 @@ def get(
19551952

19561953
return results[0]
19571954

1955+
def create_batch(self, return_exceptions: bool = False) -> InfrahubBatchSync:
1956+
"""Create a batch to execute multiple queries concurrently.
1957+
1958+
Executing the batch will be performed using a thread pool, meaning it cannot guarantee the execution order. It is not recommended to use such
1959+
batch to manipulate objects that depend on each others.
1960+
"""
1961+
return InfrahubBatchSync(
1962+
max_concurrent_execution=self.max_concurrent_execution, return_exceptions=return_exceptions
1963+
)
1964+
19581965
def get_list_repositories(
19591966
self, branches: Optional[dict[str, BranchData]] = None, kind: str = "CoreGenericRepository"
19601967
) -> dict[str, RepositoryData]:

tests/unit/sdk/conftest.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -55,27 +55,32 @@ async def echo_clients(clients: BothClients) -> AsyncGenerator[BothClients, None
5555

5656

5757
@pytest.fixture
58-
def replace_async_return_annotation():
58+
def return_annotation_map() -> dict[str, str]:
59+
return {
60+
"type[SchemaType]": "type[SchemaTypeSync]",
61+
"SchemaType": "SchemaTypeSync",
62+
"CoreNode": "CoreNodeSync",
63+
"Optional[CoreNode]": "Optional[CoreNodeSync]",
64+
"Union[str, type[SchemaType]]": "Union[str, type[SchemaTypeSync]]",
65+
"Union[InfrahubNode, SchemaType]": "Union[InfrahubNodeSync, SchemaTypeSync]",
66+
"Union[InfrahubNode, SchemaType, None]": "Union[InfrahubNodeSync, SchemaTypeSync, None]",
67+
"Union[list[InfrahubNode], list[SchemaType]]": "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]",
68+
"InfrahubClient": "InfrahubClientSync",
69+
"InfrahubNode": "InfrahubNodeSync",
70+
"list[InfrahubNode]": "list[InfrahubNodeSync]",
71+
"Optional[InfrahubNode]": "Optional[InfrahubNodeSync]",
72+
"Optional[type[SchemaType]]": "Optional[type[SchemaTypeSync]]",
73+
"Optional[Union[CoreNode, SchemaType]]": "Optional[Union[CoreNodeSync, SchemaTypeSync]]",
74+
"InfrahubBatch": "InfrahubBatchSync",
75+
}
76+
77+
78+
@pytest.fixture
79+
def replace_async_return_annotation(return_annotation_map: dict[str, str]):
5980
"""Allows for comparison between sync and async return annotations."""
6081

6182
def replace_annotation(annotation: str) -> str:
62-
replacements = {
63-
"type[SchemaType]": "type[SchemaTypeSync]",
64-
"SchemaType": "SchemaTypeSync",
65-
"CoreNode": "CoreNodeSync",
66-
"Optional[CoreNode]": "Optional[CoreNodeSync]",
67-
"Union[str, type[SchemaType]]": "Union[str, type[SchemaTypeSync]]",
68-
"Union[InfrahubNode, SchemaType]": "Union[InfrahubNodeSync, SchemaTypeSync]",
69-
"Union[InfrahubNode, SchemaType, None]": "Union[InfrahubNodeSync, SchemaTypeSync, None]",
70-
"Union[list[InfrahubNode], list[SchemaType]]": "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]",
71-
"InfrahubClient": "InfrahubClientSync",
72-
"InfrahubNode": "InfrahubNodeSync",
73-
"list[InfrahubNode]": "list[InfrahubNodeSync]",
74-
"Optional[InfrahubNode]": "Optional[InfrahubNodeSync]",
75-
"Optional[type[SchemaType]]": "Optional[type[SchemaTypeSync]]",
76-
"Optional[Union[CoreNode, SchemaType]]": "Optional[Union[CoreNodeSync, SchemaTypeSync]]",
77-
}
78-
return replacements.get(annotation) or annotation
83+
return return_annotation_map.get(annotation) or annotation
7984

8085
return replace_annotation
8186

@@ -95,26 +100,11 @@ def replace_annotations(parameters: Mapping[str, Parameter]) -> tuple[str, str]:
95100

96101

97102
@pytest.fixture
98-
def replace_sync_return_annotation() -> str:
103+
def replace_sync_return_annotation(return_annotation_map: dict[str, str]) -> str:
99104
"""Allows for comparison between sync and async return annotations."""
100105

101106
def replace_annotation(annotation: str) -> str:
102-
replacements = {
103-
"type[SchemaTypeSync]": "type[SchemaType]",
104-
"SchemaTypeSync": "SchemaType",
105-
"CoreNodeSync": "CoreNode",
106-
"Optional[CoreNodeSync]": "Optional[CoreNode]",
107-
"Union[str, type[SchemaTypeSync]]": "Union[str, type[SchemaType]]",
108-
"Union[InfrahubNodeSync, SchemaTypeSync]": "Union[InfrahubNode, SchemaType]",
109-
"Union[InfrahubNodeSync, SchemaTypeSync, None]": "Union[InfrahubNode, SchemaType, None]",
110-
"Union[list[InfrahubNodeSync], list[SchemaTypeSync]]": "Union[list[InfrahubNode], list[SchemaType]]",
111-
"InfrahubClientSync": "InfrahubClient",
112-
"InfrahubNodeSync": "InfrahubNode",
113-
"list[InfrahubNodeSync]": "list[InfrahubNode]",
114-
"Optional[InfrahubNodeSync]": "Optional[InfrahubNode]",
115-
"Optional[type[SchemaTypeSync]]": "Optional[type[SchemaType]]",
116-
"Optional[Union[CoreNodeSync, SchemaTypeSync]]": "Optional[Union[CoreNode, SchemaType]]",
117-
}
107+
replacements = {v: k for k, v in return_annotation_map.items()}
118108
return replacements.get(annotation) or annotation
119109

120110
return replace_annotation

tests/unit/sdk/test_batch.py

Lines changed: 107 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,48 +11,117 @@
1111

1212
from tests.unit.sdk.conftest import BothClients
1313

14+
client_types = ["standard", "sync"]
1415

16+
17+
@pytest.mark.parametrize("client_type", client_types)
18+
async def test_batch_execution(clients: BothClients, client_type: str):
19+
r: list[int] = []
20+
tasks_number = 10
21+
22+
if client_type == "standard":
23+
24+
async def test_func() -> int:
25+
return 1
26+
27+
batch = await clients.standard.create_batch()
28+
for _ in range(tasks_number):
29+
batch.add(task=test_func)
30+
31+
assert batch.num_tasks == tasks_number
32+
async for _, result in batch.execute():
33+
r.append(result)
34+
else:
35+
36+
def test_func() -> int:
37+
return 1
38+
39+
batch = clients.sync.create_batch()
40+
for _ in range(tasks_number):
41+
batch.add(task=test_func)
42+
43+
assert batch.num_tasks == tasks_number
44+
for _, result in batch.execute():
45+
r.append(result)
46+
47+
assert r == [1] * tasks_number
48+
49+
50+
@pytest.mark.parametrize("client_type", client_types)
1551
async def test_batch_return_exception(
16-
httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients
52+
httpx_mock: HTTPXMock,
53+
mock_query_mutation_location_create_failed,
54+
mock_schema_query_01,
55+
clients: BothClients,
56+
client_type: str,
1757
): # pylint: disable=unused-argument
18-
batch = await clients.standard.create_batch(return_exceptions=True)
19-
locations = ["JFK1", "JFK1"]
20-
results = []
21-
for location_name in locations:
22-
data = {"name": {"value": location_name, "is_protected": True}}
23-
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
24-
batch.add(task=obj.save, node=obj)
25-
results.append(obj)
26-
27-
result_iter = batch.execute()
28-
# Assert first node success
29-
node, result = await result_iter.__anext__()
30-
assert node == results[0]
31-
assert not isinstance(result, Exception)
32-
33-
# Assert second node failure
34-
node, result = await result_iter.__anext__()
35-
assert node == results[1]
36-
assert isinstance(result, GraphQLError)
37-
assert "An error occurred while executing the GraphQL Query" in str(result)
58+
if client_type == "standard":
59+
batch = await clients.standard.create_batch(return_exceptions=True)
60+
locations = ["JFK1", "JFK1"]
61+
results = []
62+
for location_name in locations:
63+
data = {"name": {"value": location_name, "is_protected": True}}
64+
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
65+
batch.add(task=obj.save, node=obj)
66+
results.append(obj)
67+
68+
result_iter = batch.execute()
69+
# Assert first node success
70+
node, result = await result_iter.__anext__()
71+
assert node == results[0]
72+
assert not isinstance(result, Exception)
3873

74+
# Assert second node failure
75+
node, result = await result_iter.__anext__()
76+
assert node == results[1]
77+
assert isinstance(result, GraphQLError)
78+
assert "An error occurred while executing the GraphQL Query" in str(result)
79+
else:
80+
batch = clients.sync.create_batch(return_exceptions=True)
81+
locations = ["JFK1", "JFK1"]
82+
results = []
83+
for location_name in locations:
84+
data = {"name": {"value": location_name, "is_protected": True}}
85+
obj = clients.sync.create(kind="BuiltinLocation", data=data)
86+
batch.add(task=obj.save, node=obj)
87+
results.append(obj)
3988

89+
results = [r for _, r in batch.execute()]
90+
# Must have one exception and one graphqlerror
91+
assert len(results) == 2
92+
assert any(isinstance(r, Exception) for r in results)
93+
assert any(isinstance(r, GraphQLError) for r in results)
94+
95+
96+
@pytest.mark.parametrize("client_type", client_types)
4097
async def test_batch_exception(
41-
httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients
98+
httpx_mock: HTTPXMock,
99+
mock_query_mutation_location_create_failed,
100+
mock_schema_query_01,
101+
clients: BothClients,
102+
client_type: str,
42103
): # pylint: disable=unused-argument
43-
batch = await clients.standard.create_batch(return_exceptions=False)
44-
locations = ["JFK1", "JFK1"]
45-
for location_name in locations:
46-
data = {"name": {"value": location_name, "is_protected": True}}
47-
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
48-
batch.add(task=obj.save, node=obj)
49-
50-
with pytest.raises(GraphQLError) as exc:
51-
async for _, _ in batch.execute():
52-
pass
53-
assert "An error occurred while executing the GraphQL Query" in str(exc.value)
54-
55-
56-
async def test_batch_not_implemented_sync(clients: BothClients):
57-
with pytest.raises(NotImplementedError):
58-
clients.sync.create_batch()
104+
if client_type == "standard":
105+
batch = await clients.standard.create_batch(return_exceptions=False)
106+
locations = ["JFK1", "JFK1"]
107+
for location_name in locations:
108+
data = {"name": {"value": location_name, "is_protected": True}}
109+
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
110+
batch.add(task=obj.save, node=obj)
111+
112+
with pytest.raises(GraphQLError) as exc:
113+
async for _, _ in batch.execute():
114+
pass
115+
assert "An error occurred while executing the GraphQL Query" in str(exc.value)
116+
else:
117+
batch = clients.sync.create_batch(return_exceptions=False)
118+
locations = ["JFK1", "JFK1"]
119+
for location_name in locations:
120+
data = {"name": {"value": location_name, "is_protected": True}}
121+
obj = clients.sync.create(kind="BuiltinLocation", data=data)
122+
batch.add(task=obj.save, node=obj)
123+
124+
with pytest.raises(GraphQLError) as exc:
125+
for _, _ in batch.execute():
126+
pass
127+
assert "An error occurred while executing the GraphQL Query" in str(exc.value)

0 commit comments

Comments
 (0)