diff --git a/changelog/18.added.md b/changelog/18.added.md new file mode 100644 index 00000000..e0400493 --- /dev/null +++ b/changelog/18.added.md @@ -0,0 +1 @@ +Add support for Enum in GraphQL query and mutation. \ No newline at end of file diff --git a/infrahub_sdk/graphql.py b/infrahub_sdk/graphql.py index 9b7722da..abb12d8b 100644 --- a/infrahub_sdk/graphql.py +++ b/infrahub_sdk/graphql.py @@ -1,5 +1,6 @@ from __future__ import annotations +from enum import Enum from typing import Any from pydantic import BaseModel @@ -7,19 +8,30 @@ VARIABLE_TYPE_MAPPING = ((str, "String!"), (int, "Int!"), (float, "Float!"), (bool, "Boolean!")) -def convert_to_graphql_as_string(value: str | bool | list) -> str: +def convert_to_graphql_as_string(value: str | bool | list | BaseModel | Enum | Any, convert_enum: bool = False) -> str: # noqa: PLR0911 if isinstance(value, str) and value.startswith("$"): return value + if isinstance(value, Enum): + if convert_enum: + return convert_to_graphql_as_string(value=value.value, convert_enum=True) + return value.name if isinstance(value, str): return f'"{value}"' if isinstance(value, bool): return repr(value).lower() if isinstance(value, list): - values_as_string = [convert_to_graphql_as_string(item) for item in value] + values_as_string = [convert_to_graphql_as_string(value=item, convert_enum=convert_enum) for item in value] return "[" + ", ".join(values_as_string) + "]" if isinstance(value, BaseModel): data = value.model_dump() - return "{ " + ", ".join(f"{key}: {convert_to_graphql_as_string(val)}" for key, val in data.items()) + " }" + return ( + "{ " + + ", ".join( + f"{key}: {convert_to_graphql_as_string(value=val, convert_enum=convert_enum)}" + for key, val in data.items() + ) + + " }" + ) return str(value) @@ -38,7 +50,7 @@ def render_variables_to_string(data: dict[str, type[str | int | float | bool]]) return ", ".join([f"{key}: {value}" for key, value in vars_dict.items()]) -def render_query_block(data: dict, offset: int = 4, indentation: int = 4) -> list[str]: +def render_query_block(data: dict, offset: int = 4, indentation: int = 4, convert_enum: bool = False) -> list[str]: FILTERS_KEY = "@filters" ALIAS_KEY = "@alias" KEYWORDS_TO_SKIP = [FILTERS_KEY, ALIAS_KEY] @@ -60,25 +72,36 @@ def render_query_block(data: dict, offset: int = 4, indentation: int = 4) -> lis if value.get(FILTERS_KEY): filters_str = ", ".join( - [f"{key2}: {convert_to_graphql_as_string(value2)}" for key2, value2 in value[FILTERS_KEY].items()] + [ + f"{key2}: {convert_to_graphql_as_string(value=value2, convert_enum=convert_enum)}" + for key2, value2 in value[FILTERS_KEY].items() + ] ) lines.append(f"{offset_str}{key_str}({filters_str}) " + "{") else: lines.append(f"{offset_str}{key_str} " + "{") - lines.extend(render_query_block(data=value, offset=offset + indentation, indentation=indentation)) + lines.extend( + render_query_block( + data=value, offset=offset + indentation, indentation=indentation, convert_enum=convert_enum + ) + ) lines.append(offset_str + "}") return lines -def render_input_block(data: dict, offset: int = 4, indentation: int = 4) -> list[str]: +def render_input_block(data: dict, offset: int = 4, indentation: int = 4, convert_enum: bool = False) -> list[str]: offset_str = " " * offset lines = [] for key, value in data.items(): if isinstance(value, dict): lines.append(f"{offset_str}{key}: " + "{") - lines.extend(render_input_block(data=value, offset=offset + indentation, indentation=indentation)) + lines.extend( + render_input_block( + data=value, offset=offset + indentation, indentation=indentation, convert_enum=convert_enum + ) + ) lines.append(offset_str + "}") elif isinstance(value, list): lines.append(f"{offset_str}{key}: " + "[") @@ -90,14 +113,17 @@ def render_input_block(data: dict, offset: int = 4, indentation: int = 4) -> lis data=item, offset=offset + indentation + indentation, indentation=indentation, + convert_enum=convert_enum, ) ) lines.append(f"{offset_str}{' ' * indentation}" + "},") else: - lines.append(f"{offset_str}{' ' * indentation}{convert_to_graphql_as_string(item)},") + lines.append( + f"{offset_str}{' ' * indentation}{convert_to_graphql_as_string(value=item, convert_enum=convert_enum)}," + ) lines.append(offset_str + "]") else: - lines.append(f"{offset_str}{key}: {convert_to_graphql_as_string(value)}") + lines.append(f"{offset_str}{key}: {convert_to_graphql_as_string(value=value, convert_enum=convert_enum)}") return lines @@ -127,9 +153,13 @@ def render_first_line(self) -> str: class Query(BaseGraphQLQuery): query_type = "query" - def render(self) -> str: + def render(self, convert_enum: bool = False) -> str: lines = [self.render_first_line()] - lines.extend(render_query_block(data=self.query, indentation=self.indentation, offset=self.indentation)) + lines.extend( + render_query_block( + data=self.query, indentation=self.indentation, offset=self.indentation, convert_enum=convert_enum + ) + ) lines.append("}") return "\n" + "\n".join(lines) + "\n" @@ -143,7 +173,7 @@ def __init__(self, *args: Any, mutation: str, input_data: dict, **kwargs: Any): self.mutation = mutation super().__init__(*args, **kwargs) - def render(self) -> str: + def render(self, convert_enum: bool = False) -> str: lines = [self.render_first_line()] lines.append(" " * self.indentation + f"{self.mutation}(") lines.extend( @@ -151,6 +181,7 @@ def render(self) -> str: data=self.input_data, indentation=self.indentation, offset=self.indentation * 2, + convert_enum=convert_enum, ) ) lines.append(" " * self.indentation + "){") @@ -159,6 +190,7 @@ def render(self) -> str: data=self.query, indentation=self.indentation, offset=self.indentation * 2, + convert_enum=convert_enum, ) ) lines.append(" " * self.indentation + "}") diff --git a/infrahub_sdk/task/__init__.py b/infrahub_sdk/task/__init__.py index 8b137891..60180315 100644 --- a/infrahub_sdk/task/__init__.py +++ b/infrahub_sdk/task/__init__.py @@ -1 +1,11 @@ +from __future__ import annotations +from .models import Task, TaskFilter, TaskLog, TaskRelatedNode, TaskState + +__all__ = [ + "Task", + "TaskFilter", + "TaskLog", + "TaskRelatedNode", + "TaskState", +] diff --git a/infrahub_sdk/task/manager.py b/infrahub_sdk/task/manager.py index dddce05a..910030dd 100644 --- a/infrahub_sdk/task/manager.py +++ b/infrahub_sdk/task/manager.py @@ -14,8 +14,9 @@ class InfraHubTaskManagerBase: + @classmethod def _generate_query( - self, + cls, filters: TaskFilter | None = None, include_logs: bool = False, include_related_nodes: bool = False, @@ -69,7 +70,8 @@ def _generate_query( return Query(query=query) - def _generate_count_query(self, filters: TaskFilter | None = None) -> Query: + @classmethod + def _generate_count_query(cls, filters: TaskFilter | None = None) -> Query: query: dict[str, Any] = { "InfrahubTask": { "count": None, @@ -98,7 +100,9 @@ async def count(self, filters: TaskFilter | None = None) -> int: """ query = self._generate_count_query(filters=filters) - response = await self.client.execute_graphql(query=query.render(), tracker="query-tasks-count") + response = await self.client.execute_graphql( + query=query.render(convert_enum=False), tracker="query-tasks-count" + ) return int(response["InfrahubTask"]["count"]) async def all( @@ -236,7 +240,7 @@ async def process_page( """ response = await client.execute_graphql( - query=query.render(), + query=query.render(convert_enum=False), tracker=f"query-tasks-page{page_number}", timeout=timeout, ) @@ -298,6 +302,7 @@ async def process_non_batch( limit=self.client.pagination_size, include_logs=include_logs, include_related_nodes=include_related_nodes, + count=True, ) new_tasks, count = await self.process_page( client=self.client, query=query, page_number=page_number, timeout=timeout @@ -330,7 +335,7 @@ def count(self, filters: TaskFilter | None = None) -> int: """ query = self._generate_count_query(filters=filters) - response = self.client.execute_graphql(query=query.render(), tracker="query-tasks-count") + response = self.client.execute_graphql(query=query.render(convert_enum=False), tracker="query-tasks-count") return int(response["InfrahubTask"]["count"]) def all( @@ -468,7 +473,7 @@ def process_page( """ response = client.execute_graphql( - query=query.render(), + query=query.render(convert_enum=False), tracker=f"query-tasks-page{page_number}", timeout=timeout, ) @@ -530,6 +535,7 @@ def process_non_batch( limit=self.client.pagination_size, include_logs=include_logs, include_related_nodes=include_related_nodes, + count=True, ) new_tasks, count = self.process_page( client=self.client, query=query, page_number=page_number, timeout=timeout diff --git a/tests/integration/test_infrahub_client.py b/tests/integration/test_infrahub_client.py index 20af4a35..6eea3e80 100644 --- a/tests/integration/test_infrahub_client.py +++ b/tests/integration/test_infrahub_client.py @@ -8,6 +8,7 @@ from infrahub_sdk.exceptions import BranchNotFoundError, URLNotFoundError from infrahub_sdk.node import InfrahubNode from infrahub_sdk.schema import ProfileSchemaAPI +from infrahub_sdk.task.models import Task, TaskFilter, TaskLog, TaskState from infrahub_sdk.testing.docker import TestInfrahubDockerClient from infrahub_sdk.testing.schemas.animal import TESTING_ANIMAL, TESTING_CAT, TESTING_DOG, TESTING_PERSON, SchemaAnimal @@ -31,6 +32,13 @@ async def base_dataset( ): await client.branch.create(branch_name="branch01") + @pytest.fixture + async def set_pagination_size3(self, client: InfrahubClient): + original_pagination_size = client.pagination_size + client.pagination_size = 3 + yield + client.pagination_size = original_pagination_size + async def test_query_branches(self, client: InfrahubClient, base_dataset): branches = await client.branch.all() main = await client.branch.get(branch_name="main") @@ -162,6 +170,44 @@ async def test_create_generic_rel_with_hfid( person_sophia = await client.get(kind=TESTING_PERSON, id=person_sophia.id, prefetch_relationships=True) assert person_sophia.favorite_animal.id == cat_luna.id + async def test_task_query(self, client: InfrahubClient, base_dataset, set_pagination_size3): + nbr_tasks = await client.task.count() + assert nbr_tasks + + tasks = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED])) + assert tasks + task_ids = [task.id for task in tasks] + + # Query Tasks using Parallel mode + tasks_parallel = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED]), parallel=True) + assert tasks_parallel + assert len(tasks_parallel) == len(tasks) + + # Query Tasks by ID + tasks_parallel_filtered = await client.task.filter(filter=TaskFilter(ids=task_ids[:2]), parallel=True) + assert tasks_parallel_filtered + assert len(tasks_parallel_filtered) == 2 + + # Query individual Task + task = await client.task.get(id=tasks[0].id) + assert task + assert isinstance(task, Task) + assert task.logs == [] + + # Wait for Task completion + task = await client.task.wait_for_completion(id=tasks[0].id) + assert task + assert isinstance(task, Task) + + # Query Tasks with logs + tasks = await client.task.filter(filter=TaskFilter(state=[TaskState.COMPLETED]), include_logs=True) + all_logs = [log for task in tasks for log in task.logs] + assert all_logs + assert isinstance(all_logs[0], TaskLog) + assert all_logs[0].message + assert all_logs[0].timestamp + assert all_logs[0].severity + # async def test_get_generic_filter_source(self, client: InfrahubClient, base_dataset): # admin = await client.get(kind="CoreAccount", name__value="admin") diff --git a/tests/unit/sdk/test_graphql.py b/tests/unit/sdk/test_graphql.py index de16fe02..06b23022 100644 --- a/tests/unit/sdk/test_graphql.py +++ b/tests/unit/sdk/test_graphql.py @@ -1,8 +1,20 @@ +from enum import Enum + import pytest from infrahub_sdk.graphql import Mutation, Query, render_input_block, render_query_block +class MyStrEnum(str, Enum): + VALUE1 = "value1" + VALUE2 = "value2" + + +class MyIntEnum(int, Enum): + VALUE1 = 12 + VALUE2 = 24 + + @pytest.fixture def query_data_no_filter(): data = { @@ -78,10 +90,10 @@ def query_data_filters_01(): def query_data_filters_02(): data = { "device": { - "@filters": {"name__value": "myname", "integer__value": 44}, + "@filters": {"name__value": "myname", "integer__value": 44, "enumstr__value": MyStrEnum.VALUE2}, "name": {"value": None}, "interfaces": { - "@filters": {"enabled__value": True}, + "@filters": {"enabled__value": True, "enumint__value": MyIntEnum.VALUE1}, "name": {"value": None}, }, } @@ -324,11 +336,11 @@ def test_query_rendering_with_filters(query_data_filters_02): expected_query = """ query { - device(name__value: "myname", integer__value: 44) { + device(name__value: "myname", integer__value: 44, enumstr__value: VALUE2) { name { value } - interfaces(enabled__value: true) { + interfaces(enabled__value: true, enumint__value: VALUE1) { name { value } @@ -339,6 +351,26 @@ def test_query_rendering_with_filters(query_data_filters_02): assert query.render() == expected_query +def test_query_rendering_with_filters_convert_enum(query_data_filters_02): + query = Query(query=query_data_filters_02) + + expected_query = """ +query { + device(name__value: "myname", integer__value: 44, enumstr__value: "value2") { + name { + value + } + interfaces(enabled__value: true, enumint__value: 12) { + name { + value + } + } + } +} +""" + assert query.render(convert_enum=True) == expected_query + + def test_mutation_rendering_no_vars(input_data_01): query_data = {"ok": None, "object": {"id": None}} @@ -425,6 +457,40 @@ def test_mutation_rendering_many_relationships(): assert query.render() == expected_query +def test_mutation_rendering_enum(): + query_data = {"ok": None, "object": {"id": None}} + input_data = { + "data": { + "description": {"value": MyStrEnum.VALUE1}, + "size": {"value": MyIntEnum.VALUE2}, + } + } + + query = Mutation(mutation="myobject", query=query_data, input_data=input_data) + + expected_query = """ +mutation { + myobject( + data: { + description: { + value: VALUE1 + } + size: { + value: VALUE2 + } + } + ){ + ok + object { + id + } + } +} +""" + assert query.render_first_line() == "mutation {" + assert query.render() == expected_query + + def test_mutation_rendering_with_vars(input_data_01): query_data = {"ok": None, "object": {"id": None}} variables = {"name": str, "description": str, "number": int} diff --git a/tests/unit/sdk/test_task.py b/tests/unit/sdk/test_task.py index aef369dd..8837d36f 100644 --- a/tests/unit/sdk/test_task.py +++ b/tests/unit/sdk/test_task.py @@ -3,6 +3,7 @@ import pytest from infrahub_sdk.task.exceptions import TaskNotFoundError, TooManyTasksError +from infrahub_sdk.task.manager import InfraHubTaskManagerBase from infrahub_sdk.task.models import Task, TaskFilter, TaskState client_types = ["standard", "sync"] @@ -30,6 +31,36 @@ async def test_method_all_full(clients, mock_query_tasks_01, client_type): assert isinstance(tasks[0], Task) +async def test_generate_count_query(): + query = InfraHubTaskManagerBase._generate_count_query() + assert query + assert ( + query.render() + == """ +query { + InfrahubTask { + count + } +} +""" + ) + + query2 = InfraHubTaskManagerBase._generate_count_query( + filters=TaskFilter(ids=["azerty", "qwerty"], state=[TaskState.COMPLETED]) + ) + assert query2 + assert ( + query2.render() + == """ +query { + InfrahubTask(ids: ["azerty", "qwerty"], state: [COMPLETED]) { + count + } +} +""" + ) + + @pytest.mark.parametrize("client_type", client_types) async def test_method_filters(clients, mock_query_tasks_02_main, client_type): if client_type == "standard":