Skip to content

Commit 7989c6c

Browse files
authored
IFC-1886: Paginated branch graphql query (#7418)
* WIP * IFC-1886: Paginated branch graphql query * update repsonse format * fix mypy * refactor branch list and count logic * fix mypy * remove unused limit and offset on get list count * conditionally resolve fields * fix mypy, update schema * change response format * update status, add schema * update schema * fix mypy * use uuid for id * remove name and ids filter * update graphql schema * update graphql schema * update branch fields, validate limit and offset * update graphql schema * update test name
1 parent cd9dadc commit 7989c6c

File tree

11 files changed

+370
-13
lines changed

11 files changed

+370
-13
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
- id: check-toml
1111
- id: check-yaml
1212
- id: end-of-file-fixer
13-
exclude: ^schema/openapi\.json$
13+
exclude: ^(schema/schema\.graphql|schema/openapi\.json)$
1414

1515
- repo: https://github.com/astral-sh/ruff-pre-commit
1616
# Ruff version.

backend/infrahub/core/branch/models.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from __future__ import annotations
22

33
import re
4-
from typing import TYPE_CHECKING, Any, Optional, Self, Union
4+
from typing import TYPE_CHECKING, Any, Optional, Self, Union, cast
55

6+
from neo4j.graph import Node as Neo4jNode
67
from pydantic import Field, field_validator
78

89
from infrahub.core.branch.enums import BranchStatus
910
from infrahub.core.constants import GLOBAL_BRANCH_NAME
1011
from infrahub.core.graph import GRAPH_VERSION
1112
from infrahub.core.models import SchemaBranchHash # noqa: TC001
1213
from infrahub.core.node.standard import StandardNode
13-
from infrahub.core.query import QueryType
14+
from infrahub.core.query import Query, QueryType
1415
from infrahub.core.query.branch import (
16+
BranchNodeGetListQuery,
1517
DeleteBranchRelationshipsQuery,
1618
GetAllBranchInternalRelationshipQuery,
1719
RebaseBranchDeleteRelationshipQuery,
@@ -158,12 +160,28 @@ async def get_list(
158160
limit: int = 1000,
159161
ids: list[str] | None = None,
160162
name: str | None = None,
161-
**kwargs: dict[str, Any],
163+
**kwargs: Any,
162164
) -> list[Self]:
163-
branches = await super().get_list(db=db, limit=limit, ids=ids, name=name, **kwargs)
164-
branches = [branch for branch in branches if branch.status != BranchStatus.DELETING]
165+
query: Query = await BranchNodeGetListQuery.init(
166+
db=db, node_class=cls, ids=ids, node_name=name, limit=limit, **kwargs
167+
)
168+
await query.execute(db=db)
169+
170+
return [cls.from_db(node=cast(Neo4jNode, result.get("n"))) for result in query.get_results()]
165171

166-
return branches
172+
@classmethod
173+
async def get_list_count(
174+
cls,
175+
db: InfrahubDatabase,
176+
limit: int = 1000,
177+
ids: list[str] | None = None,
178+
name: str | None = None,
179+
**kwargs: Any,
180+
) -> int:
181+
query: Query = await BranchNodeGetListQuery.init(
182+
db=db, node_class=cls, ids=ids, node_name=name, limit=limit, **kwargs
183+
)
184+
return await query.count(db=db)
167185

168186
@classmethod
169187
def isinstance(cls, obj: Any) -> bool:

backend/infrahub/core/query/branch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from typing import TYPE_CHECKING, Any
44

55
from infrahub import config
6+
from infrahub.core.branch.enums import BranchStatus
67
from infrahub.core.constants import GLOBAL_BRANCH_NAME
78
from infrahub.core.query import Query, QueryType
9+
from infrahub.core.query.standard_node import StandardNodeGetListQuery
810

911
if TYPE_CHECKING:
1012
from infrahub.database import InfrahubDatabase
@@ -146,3 +148,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa
146148
self.add_to_query(query=query)
147149

148150
self.params["ids"] = [db.to_database_id(id) for id in self.ids]
151+
152+
153+
class BranchNodeGetListQuery(StandardNodeGetListQuery):
154+
raw_filter = f"n.status <> '{BranchStatus.DELETING.value}'"

backend/infrahub/core/query/standard_node.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa
132132
class StandardNodeGetListQuery(Query):
133133
name = "standard_node_list"
134134
type = QueryType.READ
135+
raw_filter: str | None = None
135136

136137
def __init__(
137138
self, node_class: StandardNode, ids: list[str] | None = None, node_name: str | None = None, **kwargs: Any
@@ -150,6 +151,8 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa
150151
if self.node_name:
151152
filters.append("n.name = $name")
152153
self.params["name"] = self.node_name
154+
if self.raw_filter:
155+
filters.append(self.raw_filter)
153156

154157
where = ""
155158
if filters:

backend/infrahub/graphql/queries/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .account import AccountPermissions, AccountToken
2-
from .branch import BranchQueryList
2+
from .branch import BranchQueryList, InfrahubBranchQueryList
33
from .internal import InfrahubInfo
44
from .ipam import (
55
DeprecatedIPAddressGetNextAvailable,
@@ -20,6 +20,7 @@
2020
"BranchQueryList",
2121
"DeprecatedIPAddressGetNextAvailable",
2222
"DeprecatedIPPrefixGetNextAvailable",
23+
"InfrahubBranchQueryList",
2324
"InfrahubIPAddressGetNextAvailable",
2425
"InfrahubIPPrefixGetNextAvailable",
2526
"InfrahubInfo",

backend/infrahub/graphql/queries/branch.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
from typing import TYPE_CHECKING, Any
44

5-
from graphene import ID, Field, List, NonNull, String
5+
from graphene import ID, Field, Int, List, NonNull, String
66

7+
from infrahub.exceptions import ValidationError
78
from infrahub.graphql.field_extractor import extract_graphql_fields
8-
from infrahub.graphql.types import BranchType
9+
from infrahub.graphql.types import BranchType, InfrahubBranch, InfrahubBranchType
910

1011
if TYPE_CHECKING:
1112
from graphql import GraphQLResolveInfo
@@ -28,3 +29,36 @@ async def branch_resolver(
2829
resolver=branch_resolver,
2930
required=True,
3031
)
32+
33+
34+
async def infrahub_branch_resolver(
35+
root: dict, # noqa: ARG001
36+
info: GraphQLResolveInfo,
37+
limit: int | None = None,
38+
offset: int | None = None,
39+
) -> dict[str, Any]:
40+
if isinstance(limit, int) and limit < 1:
41+
raise ValidationError("limit must be >= 1")
42+
if isinstance(offset, int) and offset < 0:
43+
raise ValidationError("offset must be >= 0")
44+
45+
fields = extract_graphql_fields(info)
46+
result: dict[str, Any] = {}
47+
if "edges" in fields:
48+
branches = await InfrahubBranch.get_list(
49+
graphql_context=info.context, fields=fields.get("edges", {}).get("node", {}), limit=limit, offset=offset
50+
)
51+
result["edges"] = [{"node": branch} for branch in branches]
52+
if "count" in fields:
53+
result["count"] = await InfrahubBranchType.get_list_count(graphql_context=info.context)
54+
return result
55+
56+
57+
InfrahubBranchQueryList = Field(
58+
InfrahubBranchType,
59+
offset=Int(),
60+
limit=Int(),
61+
description="Retrieve paginated information about active branches.",
62+
resolver=infrahub_branch_resolver,
63+
required=True,
64+
)

backend/infrahub/graphql/schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
BranchQueryList,
4040
DeprecatedIPAddressGetNextAvailable,
4141
DeprecatedIPPrefixGetNextAvailable,
42+
InfrahubBranchQueryList,
4243
InfrahubInfo,
4344
InfrahubIPAddressGetNextAvailable,
4445
InfrahubIPPrefixGetNextAvailable,
@@ -65,6 +66,7 @@ class InfrahubBaseQuery(ObjectType):
6566

6667
Relationship = Relationship
6768

69+
InfrahubBranch = InfrahubBranchQueryList
6870
InfrahubInfo = InfrahubInfo
6971
InfrahubStatus = InfrahubStatus
7072

backend/infrahub/graphql/types/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
StrAttributeType,
2222
TextAttributeType,
2323
)
24-
from .branch import BranchType
24+
from .branch import BranchType, InfrahubBranch, InfrahubBranchType
2525
from .interface import InfrahubInterface
2626
from .node import InfrahubObject
2727
from .permission import PaginatedObjectPermission
@@ -41,6 +41,8 @@
4141
"DropdownType",
4242
"IPHostType",
4343
"IPNetworkType",
44+
"InfrahubBranch",
45+
"InfrahubBranchType",
4446
"InfrahubInterface",
4547
"InfrahubObject",
4648
"InfrahubObjectType",

backend/infrahub/graphql/types/branch.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from typing import TYPE_CHECKING, Any
44

5-
from graphene import Boolean, Field, Int, String
5+
from graphene import Boolean, Field, Int, List, NonNull, String
66

77
from infrahub.core.branch import Branch
88
from infrahub.core.constants import GLOBAL_BRANCH_NAME
99

10+
from ...exceptions import BranchNotFoundError
1011
from .enums import InfrahubBranchStatus
1112
from .standard_node import InfrahubObjectType
1213

@@ -33,6 +34,10 @@ class Meta:
3334
name = "Branch"
3435
model = Branch
3536

37+
@staticmethod
38+
async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]:
39+
return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME]
40+
3641
@classmethod
3742
async def get_list(
3843
cls,
@@ -46,4 +51,82 @@ async def get_list(
4651
if not objs:
4752
return []
4853

49-
return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME]
54+
return await cls._map_fields_to_graphql(objs=objs, fields=fields)
55+
56+
57+
class RequiredStringValueField(InfrahubObjectType):
58+
value = String(required=True)
59+
60+
61+
class NonRequiredStringValueField(InfrahubObjectType):
62+
value = String(required=False)
63+
64+
65+
class NonRequiredIntValueField(InfrahubObjectType):
66+
value = Int(required=False)
67+
68+
69+
class NonRequiredBooleanValueField(InfrahubObjectType):
70+
value = Boolean(required=False)
71+
72+
73+
class StatusField(InfrahubObjectType):
74+
value = InfrahubBranchStatus(required=True)
75+
76+
77+
class InfrahubBranch(BranchType):
78+
name = Field(RequiredStringValueField, required=True)
79+
description = Field(NonRequiredStringValueField, required=False)
80+
origin_branch = Field(NonRequiredStringValueField, required=False)
81+
branched_from = Field(NonRequiredStringValueField, required=False)
82+
graph_version = Field(NonRequiredIntValueField, required=False)
83+
status = Field(StatusField, required=True)
84+
sync_with_git = Field(NonRequiredBooleanValueField, required=False)
85+
is_default = Field(NonRequiredBooleanValueField, required=False)
86+
is_isolated = Field(
87+
NonRequiredBooleanValueField, required=False, deprecation_reason="non isolated mode is not supported anymore"
88+
)
89+
has_schema_changes = Field(NonRequiredBooleanValueField, required=False)
90+
91+
class Meta:
92+
description = "InfrahubBranch"
93+
name = "InfrahubBranch"
94+
95+
@staticmethod
96+
async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]:
97+
field_keys = fields.keys()
98+
result: list[dict[str, Any]] = []
99+
for obj in objs:
100+
if obj.name == GLOBAL_BRANCH_NAME:
101+
continue
102+
data: dict[str, Any] = {}
103+
for field in field_keys:
104+
if field == "id":
105+
data["id"] = obj.uuid
106+
continue
107+
value = getattr(obj, field, None)
108+
if isinstance(fields.get(field), dict):
109+
data[field] = {"value": value}
110+
else:
111+
data[field] = value
112+
result.append(data)
113+
return result
114+
115+
116+
class InfrahubBranchEdge(InfrahubObjectType):
117+
node = Field(InfrahubBranch, required=True)
118+
119+
120+
class InfrahubBranchType(InfrahubObjectType):
121+
count = Field(Int, description="Total number of items")
122+
edges = Field(NonNull(List(of_type=NonNull(InfrahubBranchEdge))))
123+
124+
@classmethod
125+
async def get_list_count(cls, graphql_context: GraphqlContext, **kwargs: Any) -> int:
126+
async with graphql_context.db.start_session(read_only=True) as db:
127+
count = await Branch.get_list_count(db=db, **kwargs)
128+
try:
129+
await Branch.get_by_name(name=GLOBAL_BRANCH_NAME, db=db)
130+
return count - 1
131+
except BranchNotFoundError:
132+
return count

0 commit comments

Comments
 (0)