diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d07badc6b..bde22f1922 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: check-toml - id: check-yaml - id: end-of-file-fixer - exclude: ^schema/openapi\.json$ + exclude: ^(schema/schema\.graphql|schema/openapi\.json)$ - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. diff --git a/backend/infrahub/core/branch/models.py b/backend/infrahub/core/branch/models.py index 86cfb341b5..4dc9380128 100644 --- a/backend/infrahub/core/branch/models.py +++ b/backend/infrahub/core/branch/models.py @@ -1,8 +1,9 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Any, Optional, Self, Union +from typing import TYPE_CHECKING, Any, Optional, Self, Union, cast +from neo4j.graph import Node as Neo4jNode from pydantic import Field, field_validator from infrahub.core.branch.enums import BranchStatus @@ -10,8 +11,9 @@ from infrahub.core.graph import GRAPH_VERSION from infrahub.core.models import SchemaBranchHash # noqa: TC001 from infrahub.core.node.standard import StandardNode -from infrahub.core.query import QueryType +from infrahub.core.query import Query, QueryType from infrahub.core.query.branch import ( + BranchNodeGetListQuery, DeleteBranchRelationshipsQuery, GetAllBranchInternalRelationshipQuery, RebaseBranchDeleteRelationshipQuery, @@ -158,12 +160,28 @@ async def get_list( limit: int = 1000, ids: list[str] | None = None, name: str | None = None, - **kwargs: dict[str, Any], + **kwargs: Any, ) -> list[Self]: - branches = await super().get_list(db=db, limit=limit, ids=ids, name=name, **kwargs) - branches = [branch for branch in branches if branch.status != BranchStatus.DELETING] + query: Query = await BranchNodeGetListQuery.init( + db=db, node_class=cls, ids=ids, node_name=name, limit=limit, **kwargs + ) + await query.execute(db=db) + + return [cls.from_db(node=cast(Neo4jNode, result.get("n"))) for result in query.get_results()] - return branches + @classmethod + async def get_list_count( + cls, + db: InfrahubDatabase, + limit: int = 1000, + ids: list[str] | None = None, + name: str | None = None, + **kwargs: Any, + ) -> int: + query: Query = await BranchNodeGetListQuery.init( + db=db, node_class=cls, ids=ids, node_name=name, limit=limit, **kwargs + ) + return await query.count(db=db) @classmethod def isinstance(cls, obj: Any) -> bool: diff --git a/backend/infrahub/core/query/branch.py b/backend/infrahub/core/query/branch.py index 37f08e617f..656d5c5f80 100644 --- a/backend/infrahub/core/query/branch.py +++ b/backend/infrahub/core/query/branch.py @@ -3,8 +3,10 @@ from typing import TYPE_CHECKING, Any from infrahub import config +from infrahub.core.branch.enums import BranchStatus from infrahub.core.constants import GLOBAL_BRANCH_NAME from infrahub.core.query import Query, QueryType +from infrahub.core.query.standard_node import StandardNodeGetListQuery if TYPE_CHECKING: from infrahub.database import InfrahubDatabase @@ -146,3 +148,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa self.add_to_query(query=query) self.params["ids"] = [db.to_database_id(id) for id in self.ids] + + +class BranchNodeGetListQuery(StandardNodeGetListQuery): + raw_filter = f"n.status <> '{BranchStatus.DELETING.value}'" diff --git a/backend/infrahub/core/query/standard_node.py b/backend/infrahub/core/query/standard_node.py index 1fc212edc9..92ae1ec45e 100644 --- a/backend/infrahub/core/query/standard_node.py +++ b/backend/infrahub/core/query/standard_node.py @@ -132,6 +132,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa class StandardNodeGetListQuery(Query): name = "standard_node_list" type = QueryType.READ + raw_filter: str | None = None def __init__( self, node_class: StandardNode, ids: list[str] | None = None, node_name: str | None = None, **kwargs: Any @@ -150,6 +151,8 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa if self.node_name: filters.append("n.name = $name") self.params["name"] = self.node_name + if self.raw_filter: + filters.append(self.raw_filter) where = "" if filters: diff --git a/backend/infrahub/graphql/queries/__init__.py b/backend/infrahub/graphql/queries/__init__.py index 1082887a40..e1da8ca9af 100644 --- a/backend/infrahub/graphql/queries/__init__.py +++ b/backend/infrahub/graphql/queries/__init__.py @@ -1,5 +1,5 @@ from .account import AccountPermissions, AccountToken -from .branch import BranchQueryList +from .branch import BranchQueryList, InfrahubBranchQueryList from .internal import InfrahubInfo from .ipam import ( DeprecatedIPAddressGetNextAvailable, @@ -20,6 +20,7 @@ "BranchQueryList", "DeprecatedIPAddressGetNextAvailable", "DeprecatedIPPrefixGetNextAvailable", + "InfrahubBranchQueryList", "InfrahubIPAddressGetNextAvailable", "InfrahubIPPrefixGetNextAvailable", "InfrahubInfo", diff --git a/backend/infrahub/graphql/queries/branch.py b/backend/infrahub/graphql/queries/branch.py index 4c1402fcbb..0b63f7ba27 100644 --- a/backend/infrahub/graphql/queries/branch.py +++ b/backend/infrahub/graphql/queries/branch.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING, Any -from graphene import ID, Field, List, NonNull, String +from graphene import ID, Field, Int, List, NonNull, String from infrahub.graphql.field_extractor import extract_graphql_fields -from infrahub.graphql.types import BranchType +from infrahub.graphql.types import BranchType, InfrahubBranch, InfrahubBranchType if TYPE_CHECKING: from graphql import GraphQLResolveInfo @@ -28,3 +28,31 @@ async def branch_resolver( resolver=branch_resolver, required=True, ) + + +async def infrahub_branch_resolver( + root: dict, # noqa: ARG001 + info: GraphQLResolveInfo, + limit: int | None = None, + offset: int | None = None, +) -> dict[str, Any]: + fields = extract_graphql_fields(info) + result: dict[str, Any] = {} + if "edges" in fields: + branches = await InfrahubBranch.get_list( + graphql_context=info.context, fields=fields.get("edges", {}).get("node", {}), limit=limit, offset=offset + ) + result["edges"] = [{"node": branch} for branch in branches] + if "count" in fields: + result["count"] = await InfrahubBranchType.get_list_count(graphql_context=info.context) + return result + + +InfrahubBranchQueryList = Field( + InfrahubBranchType, + offset=Int(), + limit=Int(), + description="Retrieve paginated information about active branches.", + resolver=infrahub_branch_resolver, + required=True, +) diff --git a/backend/infrahub/graphql/schema.py b/backend/infrahub/graphql/schema.py index 30c5155516..1db404b913 100644 --- a/backend/infrahub/graphql/schema.py +++ b/backend/infrahub/graphql/schema.py @@ -39,6 +39,7 @@ BranchQueryList, DeprecatedIPAddressGetNextAvailable, DeprecatedIPPrefixGetNextAvailable, + InfrahubBranchQueryList, InfrahubInfo, InfrahubIPAddressGetNextAvailable, InfrahubIPPrefixGetNextAvailable, @@ -65,6 +66,7 @@ class InfrahubBaseQuery(ObjectType): Relationship = Relationship + InfrahubBranch = InfrahubBranchQueryList InfrahubInfo = InfrahubInfo InfrahubStatus = InfrahubStatus diff --git a/backend/infrahub/graphql/types/__init__.py b/backend/infrahub/graphql/types/__init__.py index 45aabd62e3..63280557f6 100644 --- a/backend/infrahub/graphql/types/__init__.py +++ b/backend/infrahub/graphql/types/__init__.py @@ -21,7 +21,7 @@ StrAttributeType, TextAttributeType, ) -from .branch import BranchType +from .branch import BranchType, InfrahubBranch, InfrahubBranchType from .interface import InfrahubInterface from .node import InfrahubObject from .permission import PaginatedObjectPermission @@ -41,6 +41,8 @@ "DropdownType", "IPHostType", "IPNetworkType", + "InfrahubBranch", + "InfrahubBranchType", "InfrahubInterface", "InfrahubObject", "InfrahubObjectType", diff --git a/backend/infrahub/graphql/types/branch.py b/backend/infrahub/graphql/types/branch.py index 285e02b026..da2b9babf1 100644 --- a/backend/infrahub/graphql/types/branch.py +++ b/backend/infrahub/graphql/types/branch.py @@ -2,11 +2,12 @@ from typing import TYPE_CHECKING, Any -from graphene import Boolean, Field, Int, String +from graphene import Boolean, Field, Int, List, NonNull, String from infrahub.core.branch import Branch from infrahub.core.constants import GLOBAL_BRANCH_NAME +from ...exceptions import BranchNotFoundError from .enums import InfrahubBranchStatus from .standard_node import InfrahubObjectType @@ -33,6 +34,10 @@ class Meta: name = "Branch" model = Branch + @staticmethod + async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]: + return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME] + @classmethod async def get_list( cls, @@ -46,4 +51,80 @@ async def get_list( if not objs: return [] - return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME] + return await cls._map_fields_to_graphql(objs=objs, fields=fields) + + +class RequiredStringValueField(InfrahubObjectType): + value = String(required=True) + + +class NonRequiredStringValueField(InfrahubObjectType): + value = String(required=False) + + +class NonRequiredBooleanValueField(InfrahubObjectType): + value = Boolean(required=False) + + +class StatusField(InfrahubObjectType): + value = InfrahubBranchStatus(required=True) + + +class InfrahubBranch(BranchType): + id = String(required=True) + created_at = String(required=False) + + name = Field(RequiredStringValueField, required=True) + description = Field(NonRequiredStringValueField, required=False) + origin_branch = Field(NonRequiredStringValueField, required=False) + branched_from = Field(NonRequiredStringValueField, required=False) + status = Field(StatusField, required=True) + sync_with_git = Field(NonRequiredBooleanValueField, required=False) + is_default = Field(NonRequiredBooleanValueField, required=False) + is_isolated = Field( + NonRequiredBooleanValueField, required=False, deprecation_reason="non isolated mode is not supported anymore" + ) + has_schema_changes = Field(NonRequiredBooleanValueField, required=False) + + class Meta: + description = "InfrahubBranch" + name = "InfrahubBranch" + + @staticmethod + async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]: + field_keys = fields.keys() + result: list[dict[str, Any]] = [] + for obj in objs: + if obj.name == GLOBAL_BRANCH_NAME: + continue + data: dict[str, Any] = {} + for field in field_keys: + if field == "id": + data["id"] = obj.uuid + continue + value = getattr(obj, field, None) + if isinstance(fields.get(field), dict): + data[field] = {"value": value} + else: + data[field] = value + result.append(data) + return result + + +class InfrahubBranchEdge(InfrahubObjectType): + node = Field(InfrahubBranch, required=True) + + +class InfrahubBranchType(InfrahubObjectType): + count = Field(Int, description="Total number of items") + edges = Field(NonNull(List(of_type=NonNull(InfrahubBranchEdge)))) + + @classmethod + async def get_list_count(cls, graphql_context: GraphqlContext, **kwargs: Any) -> int: + async with graphql_context.db.start_session(read_only=True) as db: + count = await Branch.get_list_count(db=db, **kwargs) + try: + await Branch.get_by_name(name=GLOBAL_BRANCH_NAME, db=db) + return count - 1 + except BranchNotFoundError: + return count diff --git a/backend/tests/unit/graphql/queries/test_branch.py b/backend/tests/unit/graphql/queries/test_branch.py index 7263e06b74..b1bf35a010 100644 --- a/backend/tests/unit/graphql/queries/test_branch.py +++ b/backend/tests/unit/graphql/queries/test_branch.py @@ -141,3 +141,95 @@ async def test_branch_query( assert id_response.data assert id_response.data["Branch"][0]["name"] == "branch3" assert len(id_response.data["Branch"]) == 1 + + async def test_paginated_branch_query( + self, + db: InfrahubDatabase, + default_branch: Branch, + register_core_models_schema, + session_admin, + client, + service, + ): + for i in range(10): + create_branch_query = """ + mutation { + BranchCreate(data: { name: "%s", description: "%s" }) { + ok + object { + id + name + } + } + } + """ % ( + f"sample-branch-{i}", + f"sample description {i}", + ) + + gql_params = await prepare_graphql_params( + db=db, + branch=default_branch, + account_session=session_admin, + service=service, + ) + branch_result = await graphql( + schema=gql_params.schema, + source=create_branch_query, + context_value=gql_params.context, + root_value=None, + variable_values={}, + ) + assert branch_result.errors is None + assert branch_result.data + + query = """ + query { + InfrahubBranch(offset: 2, limit: 5) { + count + edges { + node { + name { + value + } + description { + value + } + } + } + } + } + """ + gql_params = await prepare_graphql_params(db=db, branch=default_branch, service=service) + all_branches = await graphql( + schema=gql_params.schema, + source=query, + context_value=gql_params.context, + root_value=None, + variable_values={}, + ) + assert all_branches.errors is None + assert all_branches.data + assert all_branches.data["InfrahubBranch"]["count"] == 12 # 10 created here + 1 created above + main branch + + expected_branches = [ + { + "description": {"value": "Default Branch"}, + "name": {"value": "main"}, + }, + { + "description": {"value": "my description"}, + "name": {"value": "branch3"}, + }, + *[ + { + "description": {"value": f"sample description {i}"}, + "name": {"value": f"sample-branch-{i}"}, + } + for i in range(10) + ], + ] + all_branches_data_only = [branch.get("node") for branch in all_branches.data["InfrahubBranch"]["edges"]] + assert all_branches_data_only.sort(key=lambda x: x["name"]["value"]) == expected_branches.sort( + key=lambda x: x["name"]["value"] + ) diff --git a/schema/schema.graphql b/schema/schema.graphql index 36812404d0..30d1b0c408 100644 --- a/schema/schema.graphql +++ b/schema/schema.graphql @@ -7242,6 +7242,32 @@ input InfrahubAccountUpdateSelfInput { password: String } +"""InfrahubBranch""" +type InfrahubBranch { + branched_from: NonRequiredStringValueField + created_at: String + description: NonRequiredStringValueField + graph_version: Int + has_schema_changes: NonRequiredBooleanValueField + id: String! + is_default: NonRequiredBooleanValueField + is_isolated: NonRequiredBooleanValueField @deprecated(reason: "non isolated mode is not supported anymore") + name: RequiredStringValueField! + origin_branch: NonRequiredStringValueField + status: StatusField! + sync_with_git: NonRequiredBooleanValueField +} + +type InfrahubBranchEdge { + node: InfrahubBranch! +} + +type InfrahubBranchType { + """Total number of items""" + count: Int + edges: [InfrahubBranchEdge!]! +} + input InfrahubComputedAttributeRecomputeInput { """Name of the computed attribute that must be recomputed""" attribute: String! @@ -9380,6 +9406,14 @@ type NodeMutatedEvent implements EventNodeInterface { relationships: [InfrahubMutatedRelationship!]! } +type NonRequiredBooleanValueField { + value: Boolean +} + +type NonRequiredStringValueField { + value: String +} + """Attribute of type Number""" type NumberAttribute implements AttributeInterface { id: String @@ -10755,6 +10789,8 @@ type Query { IPAddressGetNextAvailable(prefix_id: String!, prefix_length: Int): IPAddressGetNextAvailable! @deprecated(reason: "This query has been renamed to 'InfrahubIPAddressGetNextAvailable'. It will be removed in the next version of Infrahub.") IPPrefixGetNextAvailable(prefix_id: String!, prefix_length: Int): IPPrefixGetNextAvailable! @deprecated(reason: "This query has been renamed to 'InfrahubIPPrefixGetNextAvailable'. It will be removed in the next version of Infrahub.") InfrahubAccountToken(limit: Int, offset: Int): AccountTokenEdges! + """Retrieve paginated information about active branches.""" + InfrahubBranch(limit: Int, offset: Int): InfrahubBranchType! InfrahubEvent( """Filter the query to specific accounts""" account__ids: [String!] @@ -10902,6 +10938,10 @@ type Relationships { edges: [RelationshipNode!]! } +type RequiredStringValueField { + value: String! +} + type ResolveDiffConflict { ok: Boolean } @@ -10995,6 +11035,10 @@ type Status { workers: StatusWorkerEdges! } +type StatusField { + value: BranchStatus! +} + type StatusSummary { """Indicates if the schema hash is in sync on all active workers""" schema_hash_synced: Boolean!