Skip to content

Commit 42ee591

Browse files
authored
Merge pull request #299 from opsmill/dga-20250309-cont-task-enum-support
Add integration tests for Task and add support for Enum
2 parents 2ea79fe + 86d8e1d commit 42ee591

File tree

7 files changed

+215
-23
lines changed

7 files changed

+215
-23
lines changed

changelog/18.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add support for Enum in GraphQL query and mutation.

infrahub_sdk/graphql.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,37 @@
11
from __future__ import annotations
22

3+
from enum import Enum
34
from typing import Any
45

56
from pydantic import BaseModel
67

78
VARIABLE_TYPE_MAPPING = ((str, "String!"), (int, "Int!"), (float, "Float!"), (bool, "Boolean!"))
89

910

10-
def convert_to_graphql_as_string(value: str | bool | list) -> str:
11+
def convert_to_graphql_as_string(value: str | bool | list | BaseModel | Enum | Any, convert_enum: bool = False) -> str: # noqa: PLR0911
1112
if isinstance(value, str) and value.startswith("$"):
1213
return value
14+
if isinstance(value, Enum):
15+
if convert_enum:
16+
return convert_to_graphql_as_string(value=value.value, convert_enum=True)
17+
return value.name
1318
if isinstance(value, str):
1419
return f'"{value}"'
1520
if isinstance(value, bool):
1621
return repr(value).lower()
1722
if isinstance(value, list):
18-
values_as_string = [convert_to_graphql_as_string(item) for item in value]
23+
values_as_string = [convert_to_graphql_as_string(value=item, convert_enum=convert_enum) for item in value]
1924
return "[" + ", ".join(values_as_string) + "]"
2025
if isinstance(value, BaseModel):
2126
data = value.model_dump()
22-
return "{ " + ", ".join(f"{key}: {convert_to_graphql_as_string(val)}" for key, val in data.items()) + " }"
27+
return (
28+
"{ "
29+
+ ", ".join(
30+
f"{key}: {convert_to_graphql_as_string(value=val, convert_enum=convert_enum)}"
31+
for key, val in data.items()
32+
)
33+
+ " }"
34+
)
2335

2436
return str(value)
2537

@@ -38,7 +50,7 @@ def render_variables_to_string(data: dict[str, type[str | int | float | bool]])
3850
return ", ".join([f"{key}: {value}" for key, value in vars_dict.items()])
3951

4052

41-
def render_query_block(data: dict, offset: int = 4, indentation: int = 4) -> list[str]:
53+
def render_query_block(data: dict, offset: int = 4, indentation: int = 4, convert_enum: bool = False) -> list[str]:
4254
FILTERS_KEY = "@filters"
4355
ALIAS_KEY = "@alias"
4456
KEYWORDS_TO_SKIP = [FILTERS_KEY, ALIAS_KEY]
@@ -60,25 +72,36 @@ def render_query_block(data: dict, offset: int = 4, indentation: int = 4) -> lis
6072

6173
if value.get(FILTERS_KEY):
6274
filters_str = ", ".join(
63-
[f"{key2}: {convert_to_graphql_as_string(value2)}" for key2, value2 in value[FILTERS_KEY].items()]
75+
[
76+
f"{key2}: {convert_to_graphql_as_string(value=value2, convert_enum=convert_enum)}"
77+
for key2, value2 in value[FILTERS_KEY].items()
78+
]
6479
)
6580
lines.append(f"{offset_str}{key_str}({filters_str}) " + "{")
6681
else:
6782
lines.append(f"{offset_str}{key_str} " + "{")
6883

69-
lines.extend(render_query_block(data=value, offset=offset + indentation, indentation=indentation))
84+
lines.extend(
85+
render_query_block(
86+
data=value, offset=offset + indentation, indentation=indentation, convert_enum=convert_enum
87+
)
88+
)
7089
lines.append(offset_str + "}")
7190

7291
return lines
7392

7493

75-
def render_input_block(data: dict, offset: int = 4, indentation: int = 4) -> list[str]:
94+
def render_input_block(data: dict, offset: int = 4, indentation: int = 4, convert_enum: bool = False) -> list[str]:
7695
offset_str = " " * offset
7796
lines = []
7897
for key, value in data.items():
7998
if isinstance(value, dict):
8099
lines.append(f"{offset_str}{key}: " + "{")
81-
lines.extend(render_input_block(data=value, offset=offset + indentation, indentation=indentation))
100+
lines.extend(
101+
render_input_block(
102+
data=value, offset=offset + indentation, indentation=indentation, convert_enum=convert_enum
103+
)
104+
)
82105
lines.append(offset_str + "}")
83106
elif isinstance(value, list):
84107
lines.append(f"{offset_str}{key}: " + "[")
@@ -90,14 +113,17 @@ def render_input_block(data: dict, offset: int = 4, indentation: int = 4) -> lis
90113
data=item,
91114
offset=offset + indentation + indentation,
92115
indentation=indentation,
116+
convert_enum=convert_enum,
93117
)
94118
)
95119
lines.append(f"{offset_str}{' ' * indentation}" + "},")
96120
else:
97-
lines.append(f"{offset_str}{' ' * indentation}{convert_to_graphql_as_string(item)},")
121+
lines.append(
122+
f"{offset_str}{' ' * indentation}{convert_to_graphql_as_string(value=item, convert_enum=convert_enum)},"
123+
)
98124
lines.append(offset_str + "]")
99125
else:
100-
lines.append(f"{offset_str}{key}: {convert_to_graphql_as_string(value)}")
126+
lines.append(f"{offset_str}{key}: {convert_to_graphql_as_string(value=value, convert_enum=convert_enum)}")
101127
return lines
102128

103129

@@ -127,9 +153,13 @@ def render_first_line(self) -> str:
127153
class Query(BaseGraphQLQuery):
128154
query_type = "query"
129155

130-
def render(self) -> str:
156+
def render(self, convert_enum: bool = False) -> str:
131157
lines = [self.render_first_line()]
132-
lines.extend(render_query_block(data=self.query, indentation=self.indentation, offset=self.indentation))
158+
lines.extend(
159+
render_query_block(
160+
data=self.query, indentation=self.indentation, offset=self.indentation, convert_enum=convert_enum
161+
)
162+
)
133163
lines.append("}")
134164

135165
return "\n" + "\n".join(lines) + "\n"
@@ -143,14 +173,15 @@ def __init__(self, *args: Any, mutation: str, input_data: dict, **kwargs: Any):
143173
self.mutation = mutation
144174
super().__init__(*args, **kwargs)
145175

146-
def render(self) -> str:
176+
def render(self, convert_enum: bool = False) -> str:
147177
lines = [self.render_first_line()]
148178
lines.append(" " * self.indentation + f"{self.mutation}(")
149179
lines.extend(
150180
render_input_block(
151181
data=self.input_data,
152182
indentation=self.indentation,
153183
offset=self.indentation * 2,
184+
convert_enum=convert_enum,
154185
)
155186
)
156187
lines.append(" " * self.indentation + "){")
@@ -159,6 +190,7 @@ def render(self) -> str:
159190
data=self.query,
160191
indentation=self.indentation,
161192
offset=self.indentation * 2,
193+
convert_enum=convert_enum,
162194
)
163195
)
164196
lines.append(" " * self.indentation + "}")

infrahub_sdk/task/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,11 @@
1+
from __future__ import annotations
12

3+
from .models import Task, TaskFilter, TaskLog, TaskRelatedNode, TaskState
4+
5+
__all__ = [
6+
"Task",
7+
"TaskFilter",
8+
"TaskLog",
9+
"TaskRelatedNode",
10+
"TaskState",
11+
]

infrahub_sdk/task/manager.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515

1616
class InfraHubTaskManagerBase:
17+
@classmethod
1718
def _generate_query(
18-
self,
19+
cls,
1920
filters: TaskFilter | None = None,
2021
include_logs: bool = False,
2122
include_related_nodes: bool = False,
@@ -69,7 +70,8 @@ def _generate_query(
6970

7071
return Query(query=query)
7172

72-
def _generate_count_query(self, filters: TaskFilter | None = None) -> Query:
73+
@classmethod
74+
def _generate_count_query(cls, filters: TaskFilter | None = None) -> Query:
7375
query: dict[str, Any] = {
7476
"InfrahubTask": {
7577
"count": None,
@@ -98,7 +100,9 @@ async def count(self, filters: TaskFilter | None = None) -> int:
98100
"""
99101

100102
query = self._generate_count_query(filters=filters)
101-
response = await self.client.execute_graphql(query=query.render(), tracker="query-tasks-count")
103+
response = await self.client.execute_graphql(
104+
query=query.render(convert_enum=False), tracker="query-tasks-count"
105+
)
102106
return int(response["InfrahubTask"]["count"])
103107

104108
async def all(
@@ -236,7 +240,7 @@ async def process_page(
236240
"""
237241

238242
response = await client.execute_graphql(
239-
query=query.render(),
243+
query=query.render(convert_enum=False),
240244
tracker=f"query-tasks-page{page_number}",
241245
timeout=timeout,
242246
)
@@ -298,6 +302,7 @@ async def process_non_batch(
298302
limit=self.client.pagination_size,
299303
include_logs=include_logs,
300304
include_related_nodes=include_related_nodes,
305+
count=True,
301306
)
302307
new_tasks, count = await self.process_page(
303308
client=self.client, query=query, page_number=page_number, timeout=timeout
@@ -330,7 +335,7 @@ def count(self, filters: TaskFilter | None = None) -> int:
330335
"""
331336

332337
query = self._generate_count_query(filters=filters)
333-
response = self.client.execute_graphql(query=query.render(), tracker="query-tasks-count")
338+
response = self.client.execute_graphql(query=query.render(convert_enum=False), tracker="query-tasks-count")
334339
return int(response["InfrahubTask"]["count"])
335340

336341
def all(
@@ -468,7 +473,7 @@ def process_page(
468473
"""
469474

470475
response = client.execute_graphql(
471-
query=query.render(),
476+
query=query.render(convert_enum=False),
472477
tracker=f"query-tasks-page{page_number}",
473478
timeout=timeout,
474479
)
@@ -530,6 +535,7 @@ def process_non_batch(
530535
limit=self.client.pagination_size,
531536
include_logs=include_logs,
532537
include_related_nodes=include_related_nodes,
538+
count=True,
533539
)
534540
new_tasks, count = self.process_page(
535541
client=self.client, query=query, page_number=page_number, timeout=timeout

tests/integration/test_infrahub_client.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from infrahub_sdk.exceptions import BranchNotFoundError, URLNotFoundError
99
from infrahub_sdk.node import InfrahubNode
1010
from infrahub_sdk.schema import ProfileSchemaAPI
11+
from infrahub_sdk.task.models import Task, TaskFilter, TaskLog, TaskState
1112
from infrahub_sdk.testing.docker import TestInfrahubDockerClient
1213
from infrahub_sdk.testing.schemas.animal import TESTING_ANIMAL, TESTING_CAT, TESTING_DOG, TESTING_PERSON, SchemaAnimal
1314

@@ -31,6 +32,13 @@ async def base_dataset(
3132
):
3233
await client.branch.create(branch_name="branch01")
3334

35+
@pytest.fixture
36+
async def set_pagination_size3(self, client: InfrahubClient):
37+
original_pagination_size = client.pagination_size
38+
client.pagination_size = 3
39+
yield
40+
client.pagination_size = original_pagination_size
41+
3442
async def test_query_branches(self, client: InfrahubClient, base_dataset):
3543
branches = await client.branch.all()
3644
main = await client.branch.get(branch_name="main")
@@ -162,6 +170,44 @@ async def test_create_generic_rel_with_hfid(
162170
person_sophia = await client.get(kind=TESTING_PERSON, id=person_sophia.id, prefetch_relationships=True)
163171
assert person_sophia.favorite_animal.id == cat_luna.id
164172

173+
async def test_task_query(self, client: InfrahubClient, base_dataset, set_pagination_size3):
174+
nbr_tasks = await client.task.count()
175+
assert nbr_tasks
176+
177+
tasks = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED]))
178+
assert tasks
179+
task_ids = [task.id for task in tasks]
180+
181+
# Query Tasks using Parallel mode
182+
tasks_parallel = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED]), parallel=True)
183+
assert tasks_parallel
184+
assert len(tasks_parallel) == len(tasks)
185+
186+
# Query Tasks by ID
187+
tasks_parallel_filtered = await client.task.filter(filter=TaskFilter(ids=task_ids[:2]), parallel=True)
188+
assert tasks_parallel_filtered
189+
assert len(tasks_parallel_filtered) == 2
190+
191+
# Query individual Task
192+
task = await client.task.get(id=tasks[0].id)
193+
assert task
194+
assert isinstance(task, Task)
195+
assert task.logs == []
196+
197+
# Wait for Task completion
198+
task = await client.task.wait_for_completion(id=tasks[0].id)
199+
assert task
200+
assert isinstance(task, Task)
201+
202+
# Query Tasks with logs
203+
tasks = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED]), include_logs=True)
204+
all_logs = [log for task in tasks for log in task.logs]
205+
assert all_logs
206+
assert isinstance(all_logs[0], TaskLog)
207+
assert all_logs[0].message
208+
assert all_logs[0].timestamp
209+
assert all_logs[0].severity
210+
165211
# async def test_get_generic_filter_source(self, client: InfrahubClient, base_dataset):
166212
# admin = await client.get(kind="CoreAccount", name__value="admin")
167213

0 commit comments

Comments
 (0)