Skip to content

Commit cb59360

Browse files
committed
Add unit tests
1 parent d099d74 commit cb59360

File tree

2 files changed

+116
-36
lines changed

2 files changed

+116
-36
lines changed

infrahub_sdk/batch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def execute(self, return_exceptions: bool = False) -> tuple[Optional[InfrahubNod
2727
result = None
2828
try:
2929
result = self.task(*self.args, **self.kwargs)
30-
except Exception as exc:
30+
except Exception as exc: # pylint: disable=broad-exception-caught
3131
if return_exceptions:
32-
return (self.node, exc)
32+
return self.node, exc
3333
raise exc
3434

3535
return self.node, result
@@ -101,7 +101,7 @@ def add(self, *args: Any, task: Callable[..., Any], node: Optional[Any] = None,
101101

102102
def execute(self) -> Generator[tuple[Optional[InfrahubNodeSync], Any], None, None]:
103103
with ThreadPoolExecutor(max_workers=self.max_concurrent_execution) as executor:
104-
futures = [executor.submit(task.execute) for task in self._tasks]
104+
futures = [executor.submit(task.execute, return_exceptions=self.return_exceptions) for task in self._tasks]
105105
for future in futures:
106106
node, result = future.result()
107107
if isinstance(result, Exception) and not self.return_exceptions:

tests/unit/sdk/test_batch.py

Lines changed: 113 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,43 +11,123 @@
1111

1212
from tests.unit.sdk.conftest import BothClients
1313

14+
client_types = ["standard", "sync"]
1415

16+
17+
@pytest.mark.parametrize("client_type", client_types)
18+
async def test_batch_execution(clients: BothClients, client_type: str):
19+
r: list[int] = []
20+
tasks_number = 10
21+
22+
if client_type == "standard":
23+
24+
async def test_func() -> int:
25+
return 1
26+
27+
batch = await clients.standard.create_batch()
28+
for _ in range(tasks_number):
29+
batch.add(task=test_func)
30+
31+
assert batch.num_tasks == tasks_number
32+
async for _, result in batch.execute():
33+
r.append(result)
34+
else:
35+
36+
def test_func() -> int:
37+
return 1
38+
39+
batch = clients.sync.create_batch()
40+
for _ in range(tasks_number):
41+
batch.add(task=test_func)
42+
43+
assert batch.num_tasks == tasks_number
44+
for _, result in batch.execute():
45+
r.append(result)
46+
47+
assert r == [1] * tasks_number
48+
49+
50+
@pytest.mark.parametrize("client_type", client_types)
1551
async def test_batch_return_exception(
16-
httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients
52+
httpx_mock: HTTPXMock,
53+
mock_query_mutation_location_create_failed,
54+
mock_schema_query_01,
55+
clients: BothClients,
56+
client_type: str,
1757
): # pylint: disable=unused-argument
18-
batch = await clients.standard.create_batch(return_exceptions=True)
19-
locations = ["JFK1", "JFK1"]
20-
results = []
21-
for location_name in locations:
22-
data = {"name": {"value": location_name, "is_protected": True}}
23-
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
24-
batch.add(task=obj.save, node=obj)
25-
results.append(obj)
26-
27-
result_iter = batch.execute()
28-
# Assert first node success
29-
node, result = await result_iter.__anext__()
30-
assert node == results[0]
31-
assert not isinstance(result, Exception)
32-
33-
# Assert second node failure
34-
node, result = await result_iter.__anext__()
35-
assert node == results[1]
36-
assert isinstance(result, GraphQLError)
37-
assert "An error occurred while executing the GraphQL Query" in str(result)
58+
if client_type == "standard":
59+
batch = await clients.standard.create_batch(return_exceptions=True)
60+
locations = ["JFK1", "JFK1"]
61+
results = []
62+
for location_name in locations:
63+
data = {"name": {"value": location_name, "is_protected": True}}
64+
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
65+
batch.add(task=obj.save, node=obj)
66+
results.append(obj)
67+
68+
result_iter = batch.execute()
69+
# Assert first node success
70+
node, result = await result_iter.__anext__()
71+
assert node == results[0]
72+
assert not isinstance(result, Exception)
73+
74+
# Assert second node failure
75+
node, result = await result_iter.__anext__()
76+
assert node == results[1]
77+
assert isinstance(result, GraphQLError)
78+
assert "An error occurred while executing the GraphQL Query" in str(result)
79+
else:
80+
batch = clients.sync.create_batch(return_exceptions=True)
81+
locations = ["JFK2", "JFK2"]
82+
results = []
83+
for location_name in locations:
84+
data = {"name": {"value": location_name, "is_protected": True}}
85+
obj = clients.sync.create(kind="BuiltinLocation", data=data)
86+
batch.add(task=obj.save, node=obj)
87+
results.append(obj)
3888

89+
result_iter = batch.execute()
90+
# Assert first node success
91+
node, result = next(result_iter)
92+
assert node == results[0]
93+
assert not isinstance(result, Exception)
3994

95+
# Assert second node failure
96+
node, result = next(result_iter)
97+
assert node == results[1]
98+
assert isinstance(result, GraphQLError)
99+
assert "An error occurred while executing the GraphQL Query" in str(result)
100+
101+
102+
@pytest.mark.parametrize("client_type", client_types)
40103
async def test_batch_exception(
41-
httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients
104+
httpx_mock: HTTPXMock,
105+
mock_query_mutation_location_create_failed,
106+
mock_schema_query_01,
107+
clients: BothClients,
108+
client_type: str,
42109
): # pylint: disable=unused-argument
43-
batch = await clients.standard.create_batch(return_exceptions=False)
44-
locations = ["JFK1", "JFK1"]
45-
for location_name in locations:
46-
data = {"name": {"value": location_name, "is_protected": True}}
47-
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
48-
batch.add(task=obj.save, node=obj)
49-
50-
with pytest.raises(GraphQLError) as exc:
51-
async for _, _ in batch.execute():
52-
pass
53-
assert "An error occurred while executing the GraphQL Query" in str(exc.value)
110+
if client_type == "standard":
111+
batch = await clients.standard.create_batch(return_exceptions=False)
112+
locations = ["JFK1", "JFK1"]
113+
for location_name in locations:
114+
data = {"name": {"value": location_name, "is_protected": True}}
115+
obj = await clients.standard.create(kind="BuiltinLocation", data=data)
116+
batch.add(task=obj.save, node=obj)
117+
118+
with pytest.raises(GraphQLError) as exc:
119+
async for _, _ in batch.execute():
120+
pass
121+
assert "An error occurred while executing the GraphQL Query" in str(exc.value)
122+
else:
123+
batch = clients.sync.create_batch(return_exceptions=False)
124+
locations = ["JFK2", "JFK2"]
125+
for location_name in locations:
126+
data = {"name": {"value": location_name, "is_protected": True}}
127+
obj = clients.sync.create(kind="BuiltinLocation", data=data)
128+
batch.add(task=obj.save, node=obj)
129+
130+
with pytest.raises(GraphQLError) as exc:
131+
for _, _ in batch.execute():
132+
pass
133+
assert "An error occurred while executing the GraphQL Query" in str(exc.value)

0 commit comments

Comments
 (0)