Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/18.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for Enum values in GraphQL queries and mutations.
58 changes: 45 additions & 13 deletions infrahub_sdk/graphql.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
from __future__ import annotations

from enum import Enum
from typing import Any

from pydantic import BaseModel

VARIABLE_TYPE_MAPPING = ((str, "String!"), (int, "Int!"), (float, "Float!"), (bool, "Boolean!"))


def convert_to_graphql_as_string(value: str | bool | list | BaseModel | Enum | Any, convert_enum: bool = False) -> str:
    """Render a Python value as a GraphQL literal string.

    Args:
        value: The value to render. Strings starting with ``$`` are treated as
            GraphQL variable references and passed through untouched.
        convert_enum: When ``True``, Enum members are rendered from their
            underlying ``value`` (recursively, so e.g. a string value becomes a
            quoted string). When ``False``, the member ``name`` is emitted as a
            bare GraphQL enum literal.

    Returns:
        The GraphQL representation of ``value``. Unrecognized types fall back
        to ``str(value)``.
    """
    # Single assignment/return instead of many early returns, which also
    # removes the need for a `noqa: PLR0911` suppression.
    if isinstance(value, str) and value.startswith("$"):
        # GraphQL variable reference (e.g. "$name"): keep as-is, unquoted.
        result = value
    elif isinstance(value, Enum):
        # Checked before plain `str` so str-backed enums are not quoted as text.
        result = convert_to_graphql_as_string(value=value.value, convert_enum=True) if convert_enum else value.name
    elif isinstance(value, str):
        result = f'"{value}"'
    elif isinstance(value, bool):
        # GraphQL booleans are lowercase; must be tested before generic fallbacks.
        result = repr(value).lower()
    elif isinstance(value, list):
        items_as_string = [convert_to_graphql_as_string(value=item, convert_enum=convert_enum) for item in value]
        result = "[" + ", ".join(items_as_string) + "]"
    elif isinstance(value, BaseModel):
        # Serialize a pydantic model as a GraphQL input object.
        data = value.model_dump()
        result = (
            "{ "
            + ", ".join(
                f"{key}: {convert_to_graphql_as_string(value=val, convert_enum=convert_enum)}"
                for key, val in data.items()
            )
            + " }"
        )
    else:
        result = str(value)

    return result

Expand All @@ -38,7 +50,7 @@ def render_variables_to_string(data: dict[str, type[str | int | float | bool]])
return ", ".join([f"{key}: {value}" for key, value in vars_dict.items()])


def render_query_block(data: dict, offset: int = 4, indentation: int = 4) -> list[str]:
def render_query_block(data: dict, offset: int = 4, indentation: int = 4, convert_enum: bool = False) -> list[str]:
FILTERS_KEY = "@filters"
ALIAS_KEY = "@alias"
KEYWORDS_TO_SKIP = [FILTERS_KEY, ALIAS_KEY]
Expand All @@ -60,25 +72,36 @@ def render_query_block(data: dict, offset: int = 4, indentation: int = 4) -> lis

if value.get(FILTERS_KEY):
filters_str = ", ".join(
[f"{key2}: {convert_to_graphql_as_string(value2)}" for key2, value2 in value[FILTERS_KEY].items()]
[
f"{key2}: {convert_to_graphql_as_string(value=value2, convert_enum=convert_enum)}"
for key2, value2 in value[FILTERS_KEY].items()
]
)
lines.append(f"{offset_str}{key_str}({filters_str}) " + "{")
else:
lines.append(f"{offset_str}{key_str} " + "{")

lines.extend(render_query_block(data=value, offset=offset + indentation, indentation=indentation))
lines.extend(
render_query_block(
data=value, offset=offset + indentation, indentation=indentation, convert_enum=convert_enum
)
)
lines.append(offset_str + "}")

return lines


def render_input_block(data: dict, offset: int = 4, indentation: int = 4) -> list[str]:
def render_input_block(data: dict, offset: int = 4, indentation: int = 4, convert_enum: bool = False) -> list[str]:
offset_str = " " * offset
lines = []
for key, value in data.items():
if isinstance(value, dict):
lines.append(f"{offset_str}{key}: " + "{")
lines.extend(render_input_block(data=value, offset=offset + indentation, indentation=indentation))
lines.extend(
render_input_block(
data=value, offset=offset + indentation, indentation=indentation, convert_enum=convert_enum
)
)
lines.append(offset_str + "}")
elif isinstance(value, list):
lines.append(f"{offset_str}{key}: " + "[")
Expand All @@ -90,14 +113,17 @@ def render_input_block(data: dict, offset: int = 4, indentation: int = 4) -> lis
data=item,
offset=offset + indentation + indentation,
indentation=indentation,
convert_enum=convert_enum,
)
)
lines.append(f"{offset_str}{' ' * indentation}" + "},")
else:
lines.append(f"{offset_str}{' ' * indentation}{convert_to_graphql_as_string(item)},")
lines.append(
f"{offset_str}{' ' * indentation}{convert_to_graphql_as_string(value=item, convert_enum=convert_enum)},"
)
lines.append(offset_str + "]")
else:
lines.append(f"{offset_str}{key}: {convert_to_graphql_as_string(value)}")
lines.append(f"{offset_str}{key}: {convert_to_graphql_as_string(value=value, convert_enum=convert_enum)}")
return lines


Expand Down Expand Up @@ -127,9 +153,13 @@ def render_first_line(self) -> str:
class Query(BaseGraphQLQuery):
query_type = "query"

def render(self, convert_enum: bool = False) -> str:
    """Render the complete GraphQL query document as a string.

    Args:
        convert_enum: Forwarded to the query-block renderer; controls how
            Enum members inside filters are serialized.

    Returns:
        The query text, surrounded by leading/trailing newlines.
    """
    rendered_block = render_query_block(
        data=self.query,
        indentation=self.indentation,
        offset=self.indentation,
        convert_enum=convert_enum,
    )
    lines = [self.render_first_line(), *rendered_block, "}"]
    return "\n" + "\n".join(lines) + "\n"
Expand All @@ -143,14 +173,15 @@ def __init__(self, *args: Any, mutation: str, input_data: dict, **kwargs: Any):
self.mutation = mutation
super().__init__(*args, **kwargs)

def render(self) -> str:
def render(self, convert_enum: bool = False) -> str:
lines = [self.render_first_line()]
lines.append(" " * self.indentation + f"{self.mutation}(")
lines.extend(
render_input_block(
data=self.input_data,
indentation=self.indentation,
offset=self.indentation * 2,
convert_enum=convert_enum,
)
)
lines.append(" " * self.indentation + "){")
Expand All @@ -159,6 +190,7 @@ def render(self) -> str:
data=self.query,
indentation=self.indentation,
offset=self.indentation * 2,
convert_enum=convert_enum,
)
)
lines.append(" " * self.indentation + "}")
Expand Down
10 changes: 10 additions & 0 deletions infrahub_sdk/task/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,11 @@
from __future__ import annotations

from .models import Task, TaskFilter, TaskLog, TaskRelatedNode, TaskState

__all__ = [
"Task",
"TaskFilter",
"TaskLog",
"TaskRelatedNode",
"TaskState",
]
18 changes: 12 additions & 6 deletions infrahub_sdk/task/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@


class InfraHubTaskManagerBase:
@classmethod
def _generate_query(
self,
cls,
filters: TaskFilter | None = None,
include_logs: bool = False,
include_related_nodes: bool = False,
Expand Down Expand Up @@ -69,7 +70,8 @@

return Query(query=query)

def _generate_count_query(self, filters: TaskFilter | None = None) -> Query:
@classmethod
def _generate_count_query(cls, filters: TaskFilter | None = None) -> Query:
query: dict[str, Any] = {
"InfrahubTask": {
"count": None,
Expand Down Expand Up @@ -98,7 +100,9 @@
"""

query = self._generate_count_query(filters=filters)
response = await self.client.execute_graphql(query=query.render(), tracker="query-tasks-count")
response = await self.client.execute_graphql(
query=query.render(convert_enum=False), tracker="query-tasks-count"
)
return int(response["InfrahubTask"]["count"])

async def all(
Expand Down Expand Up @@ -236,7 +240,7 @@
"""

response = await client.execute_graphql(
query=query.render(),
query=query.render(convert_enum=False),
tracker=f"query-tasks-page{page_number}",
timeout=timeout,
)
Expand Down Expand Up @@ -298,6 +302,7 @@
limit=self.client.pagination_size,
include_logs=include_logs,
include_related_nodes=include_related_nodes,
count=True,
)
new_tasks, count = await self.process_page(
client=self.client, query=query, page_number=page_number, timeout=timeout
Expand Down Expand Up @@ -330,7 +335,7 @@
"""

query = self._generate_count_query(filters=filters)
response = self.client.execute_graphql(query=query.render(), tracker="query-tasks-count")
response = self.client.execute_graphql(query=query.render(convert_enum=False), tracker="query-tasks-count")

Check warning on line 338 in infrahub_sdk/task/manager.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/task/manager.py#L338

Added line #L338 was not covered by tests
return int(response["InfrahubTask"]["count"])

def all(
Expand Down Expand Up @@ -468,7 +473,7 @@
"""

response = client.execute_graphql(
query=query.render(),
query=query.render(convert_enum=False),
tracker=f"query-tasks-page{page_number}",
timeout=timeout,
)
Expand Down Expand Up @@ -530,6 +535,7 @@
limit=self.client.pagination_size,
include_logs=include_logs,
include_related_nodes=include_related_nodes,
count=True,
)
new_tasks, count = self.process_page(
client=self.client, query=query, page_number=page_number, timeout=timeout
Expand Down
46 changes: 46 additions & 0 deletions tests/integration/test_infrahub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from infrahub_sdk.exceptions import BranchNotFoundError, URLNotFoundError
from infrahub_sdk.node import InfrahubNode
from infrahub_sdk.schema import ProfileSchemaAPI
from infrahub_sdk.task.models import Task, TaskFilter, TaskLog, TaskState
from infrahub_sdk.testing.docker import TestInfrahubDockerClient
from infrahub_sdk.testing.schemas.animal import TESTING_ANIMAL, TESTING_CAT, TESTING_DOG, TESTING_PERSON, SchemaAnimal

Expand All @@ -31,6 +32,13 @@ async def base_dataset(
):
await client.branch.create(branch_name="branch01")

@pytest.fixture
async def set_pagination_size3(self, client: InfrahubClient):
    """Temporarily force the client's pagination size to 3 for the duration of a test.

    The small page size makes multi-page code paths (e.g. task pagination)
    exercisable with only a handful of records. The original value is restored
    during fixture teardown, after the yield.
    """
    original_pagination_size = client.pagination_size
    client.pagination_size = 3
    yield
    client.pagination_size = original_pagination_size

async def test_query_branches(self, client: InfrahubClient, base_dataset):
branches = await client.branch.all()
main = await client.branch.get(branch_name="main")
Expand Down Expand Up @@ -162,6 +170,44 @@ async def test_create_generic_rel_with_hfid(
person_sophia = await client.get(kind=TESTING_PERSON, id=person_sophia.id, prefetch_relationships=True)
assert person_sophia.favorite_animal.id == cat_luna.id

async def test_task_query(self, client: InfrahubClient, base_dataset, set_pagination_size3):
    """End-to-end checks of the task query API against a live Infrahub instance.

    Covers: counting tasks, filtering by state (serial and parallel modes),
    filtering by id, fetching a single task, waiting for completion, and
    including logs. Runs with pagination size 3 (via set_pagination_size3)
    so the filter calls exercise multi-page retrieval.
    """
    nbr_tasks = await client.task.count()
    assert nbr_tasks

    tasks = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED]))
    assert tasks
    task_ids = [task.id for task in tasks]

    # Query Tasks using Parallel mode; result set must match the serial query.
    tasks_parallel = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED]), parallel=True)
    assert tasks_parallel
    assert len(tasks_parallel) == len(tasks)

    # Query Tasks by ID (first two ids from the state-filtered result).
    tasks_parallel_filtered = await client.task.filter(filter=TaskFilter(ids=task_ids[:2]), parallel=True)
    assert tasks_parallel_filtered
    assert len(tasks_parallel_filtered) == 2

    # Query individual Task; logs are empty because include_logs was not requested.
    task = await client.task.get(id=tasks[0].id)
    assert task
    assert isinstance(task, Task)
    assert task.logs == []

    # Wait for Task completion
    task = await client.task.wait_for_completion(id=tasks[0].id)
    assert task
    assert isinstance(task, Task)

    # Query Tasks with logs
    tasks = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED]), include_logs=True)
    all_logs = [log for task in tasks for log in task.logs]
    assert all_logs
    assert isinstance(all_logs[0], TaskLog)
    assert all_logs[0].message
    assert all_logs[0].timestamp
    assert all_logs[0].severity

# async def test_get_generic_filter_source(self, client: InfrahubClient, base_dataset):
# admin = await client.get(kind="CoreAccount", name__value="admin")

Expand Down
Loading