Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1102ae7
WIP
solababs Oct 13, 2025
4e16ec7
Merge branch 'develop' into sb-20251013-infrahub-branch-query-ifc-1886
solababs Oct 15, 2025
c0b3b31
IFC-1886: Paginated branch graphql query
solababs Oct 16, 2025
9a07e23
update response format
solababs Oct 20, 2025
41b1849
fix mypy
solababs Oct 20, 2025
c8fbca1
Merge branch 'develop' into sb-20251013-infrahub-branch-query-ifc-1886
solababs Oct 20, 2025
dd242d9
Merge branch 'develop' into sb-20251013-infrahub-branch-query-ifc-1886
solababs Oct 21, 2025
9f1dc4a
refactor branch list and count logic
solababs Oct 21, 2025
3a7ff90
fix mypy
solababs Oct 21, 2025
e659340
remove unused limit and offset on get list count
solababs Oct 21, 2025
2ad4b2d
conditionally resolve fields
solababs Oct 28, 2025
32c72ad
fix mypy, update schema
solababs Oct 29, 2025
a38cec2
Merge branch 'develop' into sb-20251013-infrahub-branch-query-ifc-1886
solababs Oct 29, 2025
84fac0a
change response format
solababs Oct 29, 2025
f7d4828
Merge branch 'develop' into sb-20251013-infrahub-branch-query-ifc-1886
solababs Oct 29, 2025
491ad04
update status, add schema
solababs Oct 29, 2025
86dbefa
update schema
solababs Oct 29, 2025
9eba0b8
fix mypy
solababs Oct 29, 2025
9fe2763
Merge branch 'develop' into sb-20251013-infrahub-branch-query-ifc-1886
solababs Oct 30, 2025
8e4dca9
use uuid for id
solababs Oct 30, 2025
f0aa561
Merge branch 'sb-20251013-infrahub-branch-query-ifc-1886' of https://…
solababs Oct 30, 2025
089b015
remove name and ids filter
solababs Oct 31, 2025
fb7d0a6
Merge branch 'develop' into sb-20251013-infrahub-branch-query-ifc-1886
solababs Oct 31, 2025
1157cc4
update graphql schema
solababs Oct 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions backend/infrahub/core/branch/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from typing import TYPE_CHECKING, Any, Optional, Self, Union, cast

from neo4j.graph import Node as Neo4jNode
from pydantic import Field, field_validator

from infrahub.core.branch.enums import BranchStatus
Expand All @@ -11,8 +12,9 @@
)
from infrahub.core.models import SchemaBranchHash # noqa: TC001
from infrahub.core.node.standard import StandardNode
from infrahub.core.query import QueryType
from infrahub.core.query import Query, QueryType
from infrahub.core.query.branch import (
BranchNodeGetListQuery,
DeleteBranchRelationshipsQuery,
GetAllBranchInternalRelationshipQuery,
RebaseBranchDeleteRelationshipQuery,
Expand Down Expand Up @@ -158,12 +160,28 @@ async def get_list(
limit: int = 1000,
ids: list[str] | None = None,
name: str | None = None,
**kwargs: dict[str, Any],
**kwargs: Any,
) -> list[Self]:
branches = await super().get_list(db=db, limit=limit, ids=ids, name=name, **kwargs)
branches = [branch for branch in branches if branch.status != BranchStatus.DELETING]
query: Query = await BranchNodeGetListQuery.init(
db=db, node_class=cls, ids=ids, node_name=name, limit=limit, **kwargs
)
await query.execute(db=db)

return [cls.from_db(node=cast(Neo4jNode, result.get("n"))) for result in query.get_results()]

return branches
@classmethod
async def get_list_count(
cls,
db: InfrahubDatabase,
limit: int = 1000,
ids: list[str] | None = None,
name: str | None = None,
**kwargs: Any,
) -> int:
query: Query = await BranchNodeGetListQuery.init(
db=db, node_class=cls, ids=ids, node_name=name, limit=limit, **kwargs
)
return await query.count(db=db)

@classmethod
def isinstance(cls, obj: Any) -> bool:
Expand Down
6 changes: 6 additions & 0 deletions backend/infrahub/core/query/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import TYPE_CHECKING, Any

from infrahub import config
from infrahub.core.branch.enums import BranchStatus
from infrahub.core.constants import GLOBAL_BRANCH_NAME
from infrahub.core.query import Query, QueryType
from infrahub.core.query.standard_node import StandardNodeGetListQuery

if TYPE_CHECKING:
from infrahub.database import InfrahubDatabase
Expand Down Expand Up @@ -146,3 +148,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa
self.add_to_query(query=query)

self.params["ids"] = [db.to_database_id(id) for id in self.ids]


class BranchNodeGetListQuery(StandardNodeGetListQuery):
    """Query retrieving branch nodes while excluding branches in DELETING status.

    Extends StandardNodeGetListQuery with a raw Cypher filter so that branches
    currently being deleted never appear in list or count results.
    """

    # Extra WHERE clause consumed by the parent query builder.
    raw_filter: str = f"n.status <> '{BranchStatus.DELETING.value}'"
Comment on lines +153 to +154
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add docstring and type annotation.

The class is missing a docstring and type annotation, which are required by the coding guidelines.

Apply this diff to add the missing docstring and type annotation:

 class BranchNodeGetListQuery(StandardNodeGetListQuery):
+    """Query to retrieve a list of branches, excluding those with DELETING status.
+    
+    This query extends StandardNodeGetListQuery with a filter to exclude branches
+    that are currently being deleted from the results.
+    """
+    
-    raw_filter = f"n.status <> '{BranchStatus.DELETING.value}'"
+    raw_filter: str = f"n.status <> '{BranchStatus.DELETING.value}'"

As per coding guidelines.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
class BranchNodeGetListQuery(StandardNodeGetListQuery):
raw_filter = f"n.status <> '{BranchStatus.DELETING.value}'"
class BranchNodeGetListQuery(StandardNodeGetListQuery):
"""Query to retrieve a list of branches, excluding those with DELETING status.
This query extends StandardNodeGetListQuery with a filter to exclude branches
that are currently being deleted from the results.
"""
raw_filter: str = f"n.status <> '{BranchStatus.DELETING.value}'"
🤖 Prompt for AI Agents
In backend/infrahub/core/query/branch.py around lines 153-154, the class
BranchNodeGetListQuery is missing a docstring and a type annotation for its
raw_filter attribute; add a concise class docstring that explains this Query
filters out deleting Branch nodes and add an explicit type annotation on
raw_filter (raw_filter: str = f"n.status <> '{BranchStatus.DELETING.value}'") to
satisfy the coding guidelines.

3 changes: 3 additions & 0 deletions backend/infrahub/core/query/standard_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa
class StandardNodeGetListQuery(Query):
name = "standard_node_list"
type = QueryType.READ
raw_filter: str | None = None

def __init__(
self, node_class: StandardNode, ids: list[str] | None = None, node_name: str | None = None, **kwargs: Any
Expand All @@ -150,6 +151,8 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa
if self.node_name:
filters.append("n.name = $name")
self.params["name"] = self.node_name
if self.raw_filter:
filters.append(self.raw_filter)

where = ""
if filters:
Expand Down
3 changes: 2 additions & 1 deletion backend/infrahub/graphql/queries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .account import AccountPermissions, AccountToken
from .branch import BranchQueryList
from .branch import BranchQueryList, InfrahubBranchQueryList
from .internal import InfrahubInfo
from .ipam import (
DeprecatedIPAddressGetNextAvailable,
Expand All @@ -20,6 +20,7 @@
"BranchQueryList",
"DeprecatedIPAddressGetNextAvailable",
"DeprecatedIPPrefixGetNextAvailable",
"InfrahubBranchQueryList",
"InfrahubIPAddressGetNextAvailable",
"InfrahubIPPrefixGetNextAvailable",
"InfrahubInfo",
Expand Down
34 changes: 32 additions & 2 deletions backend/infrahub/graphql/queries/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from typing import TYPE_CHECKING, Any

from graphene import ID, Field, List, NonNull, String
from graphene import ID, Field, Int, List, NonNull, String

from infrahub.graphql.field_extractor import extract_graphql_fields
from infrahub.graphql.types import BranchType
from infrahub.graphql.types import BranchType, InfrahubBranch, InfrahubBranchType

if TYPE_CHECKING:
from graphql import GraphQLResolveInfo
Expand All @@ -28,3 +28,33 @@ async def branch_resolver(
resolver=branch_resolver,
required=True,
)


async def infrahub_branch_resolver(
root: dict, # noqa: ARG001
info: GraphQLResolveInfo,
limit: int | None = None,
offset: int | None = None,
) -> dict[str, Any]:
Comment on lines +33 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add validation for pagination parameters.

The limit and offset parameters lack validation, which could lead to issues with negative values or excessively large limits. Per past review feedback, parameter defaults were removed from the field definition, but the resolver signature still has None defaults and no validation.

Apply this diff to add validation:

 async def infrahub_branch_resolver(
     root: dict,  # noqa: ARG001
     info: GraphQLResolveInfo,
     limit: int | None = None,
     offset: int | None = None,
 ) -> dict[str, Any]:
+    """..."""  # Add docstring as per previous comment
+    
+    # Validate and set defaults for pagination parameters
+    if limit is None:
+        limit = 100
+    if offset is None:
+        offset = 0
+    
+    if limit < 1:
+        raise ValueError("limit must be at least 1")
+    if offset < 0:
+        raise ValueError("offset must be non-negative")
+    
     fields = extract_graphql_fields(info)
🤖 Prompt for AI Agents
In backend/infrahub/graphql/queries/branch.py around lines 33 to 38, the
resolver accepts limit and offset but does not validate them; add input
validation that enforces offset >= 0, limit >= 0 and limit <= MAX_LIMIT (choose
a sensible cap like 100 or use a configured constant), treat None as not
provided (or apply a safe default if your schema requires one), and raise a
GraphQLError with a clear message for invalid values. Ensure the validation runs
at the start of infrahub_branch_resolver and that any error raised is
descriptive (e.g., "limit must be between 0 and 100" or "offset must be >= 0"),
avoiding silent truncation or database queries with bad parameters.

fields = extract_graphql_fields(info)
result: dict[str, Any] = {}
if "edges" in fields:
branches = await InfrahubBranch.get_list(
graphql_context=info.context, fields=fields.get("edges", {}).get("node", {}), limit=limit, offset=offset
)
result["edges"] = [{"node": branch} for branch in branches]
if "count" in fields:
result["count"] = await InfrahubBranchType.get_list_count(graphql_context=info.context)
return result
Comment on lines +33 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Resolver ignores ids/name and forwards None pagination; fix signature, validation, and forwarding.

Currently ids/name are accepted by the field but dropped by the resolver; passing limit/offset=None can also override model defaults. Add proper params, validate, and only pass provided args.

As per coding guidelines, add a docstring too. Apply:

-async def infrahub_branch_resolver(
-    root: dict,  # noqa: ARG001
-    info: GraphQLResolveInfo,
-    limit: int | None = None,
-    offset: int | None = None,
-) -> dict[str, Any]:
-    fields = extract_graphql_fields(info)
-    result: dict[str, Any] = {}
-    if "edges" in fields:
-        branches = await InfrahubBranch.get_list(
-            graphql_context=info.context, fields=fields.get("edges", {}).get("node", {}), limit=limit, offset=offset
-        )
-        result["edges"] = [{"node": branch} for branch in branches]
-    if "count" in fields:
-        result["count"] = await InfrahubBranchType.get_list_count(graphql_context=info.context)
-    return result
+async def infrahub_branch_resolver(
+    root: dict,  # noqa: ARG001
+    info: GraphQLResolveInfo,
+    ids: list[str] | None = None,
+    name: str | None = None,
+    limit: int | None = None,
+    offset: int | None = None,
+) -> dict[str, Any]:
+    """Resolve paginated InfrahubBranch query.
+    
+    Args:
+        root: Unused GraphQL root value.
+        info: Resolve info with context and selection set.
+        ids: Optional list of branch IDs to filter.
+        name: Optional branch name filter.
+        limit: Max items to return; must be >= 1 if provided.
+        offset: Number of items to skip; must be >= 0 if provided.
+    Returns:
+        Dict containing requested fields (edges and/or count).
+    """
+    fields = extract_graphql_fields(info)
+    result: dict[str, Any] = {}
+
+    # Normalize pagination; only pass if provided and valid
+    query_kwargs: dict[str, int] = {}
+    if isinstance(limit, int):
+        if limit < 1:
+            raise ValueError("limit must be >= 1")
+        query_kwargs["limit"] = limit
+    if isinstance(offset, int):
+        if offset < 0:
+            raise ValueError("offset must be >= 0")
+        query_kwargs["offset"] = offset
+
+    if "edges" in fields:
+        node_fields = fields.get("edges", {}).get("node", {})
+        branches = await InfrahubBranch.get_list(
+            graphql_context=info.context,
+            fields=node_fields,
+            ids=ids,
+            name=name,
+            **query_kwargs,
+        )
+        result["edges"] = [{"node": branch} for branch in branches]
+
+    if "count" in fields:
+        result["count"] = await InfrahubBranchType.get_list_count(
+            graphql_context=info.context, ids=ids, name=name
+        )
+    return result

As per coding guidelines.

🤖 Prompt for AI Agents
In backend/infrahub/graphql/queries/branch.py around lines 33 to 48, the
resolver currently drops incoming ids/name arguments, and unconditionally
forwards limit/offset even when None (overriding model defaults) and lacks a
docstring; update the function signature to accept optional ids: list[str] |
None and name: str | None (with proper typing), add a short docstring describing
purpose and params, validate inputs (ensure ids is a non-empty list if provided
and name is a non-empty string), build kwargs for InfrahubBranch.get_list and
InfrahubBranchType.get_list_count that include only provided arguments (i.e.,
include limit/offset only when not None, include ids/name when provided), and
pass fields extracted from info for edges as before so the resolver respects and
forwards only explicit arguments to the model calls.



InfrahubBranchQueryList = Field(
InfrahubBranchType,
ids=List(of_type=NonNull(ID)),
name=String(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These parameters don't seem to work now.

query MyQuery {
  InfrahubBranch(name: "branch2") {
    count
    edges {
      node {
        id
        name {
          value
        }
      }
    }
  }
}

I.e. infrahub_branch_resolver() doesn't accept the parameters ids or name, so the above query would fail. I think perhaps we can just remove both of them in this first iteration and then consider which filter options we want and need. We will want to add some of them for sure. But it could potentially be quite verbose if we start to add all of the variations that we dynamically generate for other node types, so I think that we can have a discussion around that after this PR is merged.

I.e. if we were to follow the current approach we'd have:

  • name__value
  • name__values
  • description__value
  • description__values
  • etc for all attributes.

offset=Int(),
limit=Int(),
description="Retrieve paginated information about active branches.",
resolver=infrahub_branch_resolver,
required=True,
)
2 changes: 2 additions & 0 deletions backend/infrahub/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
BranchQueryList,
DeprecatedIPAddressGetNextAvailable,
DeprecatedIPPrefixGetNextAvailable,
InfrahubBranchQueryList,
InfrahubInfo,
InfrahubIPAddressGetNextAvailable,
InfrahubIPPrefixGetNextAvailable,
Expand All @@ -65,6 +66,7 @@ class InfrahubBaseQuery(ObjectType):

Relationship = Relationship

InfrahubBranch = InfrahubBranchQueryList
InfrahubInfo = InfrahubInfo
InfrahubStatus = InfrahubStatus

Expand Down
4 changes: 3 additions & 1 deletion backend/infrahub/graphql/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
StrAttributeType,
TextAttributeType,
)
from .branch import BranchType
from .branch import BranchType, InfrahubBranch, InfrahubBranchType
from .interface import InfrahubInterface
from .node import InfrahubObject
from .permission import PaginatedObjectPermission
Expand All @@ -41,6 +41,8 @@
"DropdownType",
"IPHostType",
"IPNetworkType",
"InfrahubBranch",
"InfrahubBranchType",
"InfrahubInterface",
"InfrahubObject",
"InfrahubObjectType",
Expand Down
82 changes: 80 additions & 2 deletions backend/infrahub/graphql/types/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from typing import TYPE_CHECKING, Any

from graphene import Boolean, Field, String
from graphene import Boolean, Field, Int, List, NonNull, String

from infrahub.core.branch import Branch
from infrahub.core.constants import GLOBAL_BRANCH_NAME

from ...exceptions import BranchNotFoundError
from .enums import InfrahubBranchStatus
from .standard_node import InfrahubObjectType

Expand All @@ -32,6 +33,10 @@ class Meta:
name = "Branch"
model = Branch

@staticmethod
async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]:
    """Serialize branches to GraphQL dicts, skipping the global branch.

    Args:
        objs: Branch objects to serialize.
        fields: Field selection forwarded to each branch's ``to_graphql``.

    Returns:
        One dict per branch whose name is not ``GLOBAL_BRANCH_NAME``.
    """
    return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME]
Comment on lines +36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add method docstring.

The _map_fields_to_graphql static method is missing a docstring, which violates the coding guidelines requiring Google-style docstrings for all Python methods.

As per coding guidelines, apply this diff:

     @staticmethod
     async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]:
+        """Map Branch objects to GraphQL field dictionaries, excluding the global branch.
+        
+        Args:
+            objs: List of Branch objects to map.
+            fields: Dictionary of field names to include in the output.
+            
+        Returns:
+            List of dictionaries containing GraphQL field data for each branch.
+        """
         return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME]
🤖 Prompt for AI Agents
In backend/infrahub/graphql/types/branch.py around lines 36 to 38, the async
static method _map_fields_to_graphql is missing a Google-style docstring; add a
docstring above the method that briefly describes its purpose (map a list of
Branch objects to GraphQL dicts while excluding branches named
GLOBAL_BRANCH_NAME), documents parameters (objs: list[Branch], fields: dict) and
their types, notes that the method is asynchronous and returns list[dict[str,
Any]], and describes the return value and behavior (awaits each obj.to_graphql
and filters out GLOBAL_BRANCH_NAME).


@classmethod
async def get_list(
cls,
Expand All @@ -45,4 +50,77 @@ async def get_list(
if not objs:
return []

return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME]
return await cls._map_fields_to_graphql(objs=objs, fields=fields)


class RequiredStringValueField(InfrahubObjectType):
    """GraphQL wrapper exposing a mandatory string under a ``value`` key."""

    value = String(required=True)


class NonRequiredStringValueField(InfrahubObjectType):
    """GraphQL wrapper exposing an optional (nullable) string under a ``value`` key."""

    value = String(required=False)


class NonRequiredBooleanValueField(InfrahubObjectType):
    """GraphQL wrapper exposing an optional (nullable) boolean under a ``value`` key."""

    value = Boolean(required=False)


class StatusField(InfrahubObjectType):
    """GraphQL wrapper exposing a required branch-status enum under a ``value`` key."""

    value = InfrahubBranchStatus(required=True)

Comment on lines +56 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add docstrings and clarify the purpose of value wrapper classes.

All four value wrapper classes (RequiredStringValueField, NonRequiredStringValueField, NonRequiredBooleanValueField, StatusField) are missing docstrings, which violates the coding guidelines. Additionally, the purpose of these wrapper types is not immediately clear—they add a layer of indirection without explanation.

As per coding guidelines, add docstrings to each class explaining their purpose:

 class RequiredStringValueField(InfrahubObjectType):
+    """Wrapper type for required string values in GraphQL responses.
+    
+    Provides a consistent structure for string fields that must always have a value.
+    """
     value = String(required=True)


 class NonRequiredStringValueField(InfrahubObjectType):
+    """Wrapper type for optional string values in GraphQL responses.
+    
+    Provides a consistent structure for string fields that may be null.
+    """
     value = String(required=False)


 class NonRequiredBooleanValueField(InfrahubObjectType):
+    """Wrapper type for optional boolean values in GraphQL responses.
+    
+    Provides a consistent structure for boolean fields that may be null.
+    """
     value = Boolean(required=False)


 class StatusField(InfrahubObjectType):
+    """Wrapper type for branch status enum values in GraphQL responses.
+    
+    Provides a consistent structure for status fields.
+    """
     value = InfrahubBranchStatus(required=True)
🤖 Prompt for AI Agents
In backend/infrahub/graphql/types/branch.py around lines 56 to 70, the four
small wrapper classes lack docstrings and do not explain their purpose; add
brief docstrings to each class (RequiredStringValueField,
NonRequiredStringValueField, NonRequiredBooleanValueField, StatusField) that
state they are simple GraphQL value-wrapper types used to standardize field
payloads/typing (e.g., wrapping a single value for consistency across
mutations/queries), note whether the inner value is required or optional and the
expected type (String/Boolean/InfrahubBranchStatus), and include one-line
examples or usage intent if helpful — keep them concise and follow project
docstring style.


class InfrahubBranch(BranchType):
    """Branch GraphQL type whose attributes resolve to ``{value: ...}`` objects.

    Used by the paginated ``InfrahubBranch`` query: each branch attribute is
    declared through a *ValueField wrapper type so clients select sub-fields
    (e.g. ``name { value }``). ``id`` and ``created_at`` remain plain strings.
    """

    id = String(required=True)
    created_at = String(required=False)

    name = Field(RequiredStringValueField, required=True)
    description = Field(NonRequiredStringValueField, required=False)
    origin_branch = Field(NonRequiredStringValueField, required=False)
    branched_from = Field(NonRequiredStringValueField, required=False)
    status = Field(StatusField, required=True)
    sync_with_git = Field(NonRequiredBooleanValueField, required=False)
    is_default = Field(NonRequiredBooleanValueField, required=False)
    is_isolated = Field(
        NonRequiredBooleanValueField, required=False, deprecation_reason="non isolated mode is not supported anymore"
    )
    has_schema_changes = Field(NonRequiredBooleanValueField, required=False)

    class Meta:
        description = "InfrahubBranch"
        name = "InfrahubBranch"

Comment on lines +72 to +91
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add class docstring and clarify field definition pattern.

The InfrahubBranch class is missing a docstring, which violates the coding guidelines. Additionally, there's an inconsistency in field definitions: id and created_at are defined as plain String fields, while other fields use Field() wrappers with value wrapper types. This inconsistency is confusing.

As per coding guidelines, apply this diff:

 class InfrahubBranch(BranchType):
+    """Extended branch type with wrapped field values for GraphQL pagination queries.
+    
+    Provides an alternative representation where most fields are wrapped in value objects
+    to support consistent field resolution in paginated responses.
+    """
     id = String(required=True)
     created_at = String(required=False)

Consider documenting why id and created_at remain unwrapped while other fields use the value wrapper pattern.

@staticmethod
async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]:
    """Serialize branches into dicts matching the wrapped-value GraphQL shape.

    Branches named ``GLOBAL_BRANCH_NAME`` are skipped. For every requested
    field, a dict in ``fields`` means the client selected sub-fields of a
    value-wrapper type, so the attribute is nested as ``{"value": ...}``;
    otherwise the raw attribute is emitted as-is.

    Args:
        objs: Branch objects to serialize.
        fields: Mapping of requested field names to their GraphQL sub-selections.

    Returns:
        One dict per included branch, keyed by the requested field names.
    """
    field_keys = fields.keys()
    result: list[dict[str, Any]] = []
    for obj in objs:
        if obj.name == GLOBAL_BRANCH_NAME:
            continue
        data = {}
        for field in field_keys:
            # Missing attributes fall back to None rather than raising.
            value = getattr(obj, field, None)
            if isinstance(fields.get(field), dict):
                # Nested selection -> wrap for the *ValueField types.
                data[field] = {"value": value}
            else:
                data[field] = value
        result.append(data)
    return result
Comment on lines 92 to 110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add method docstring and clarify field wrapping logic.

The _map_fields_to_graphql method is missing a docstring, which violates the coding guidelines. The field wrapping logic (lines 102-105) is also complex and not immediately clear—it wraps values in {"value": ...} dictionaries based on whether fields.get(field) is a dict, but the reasoning isn't documented.

As per coding guidelines, apply this diff:

     @staticmethod
     async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]:
+        """Map Branch objects to GraphQL field dictionaries with wrapped field values.
+        
+        Fields that have nested field selections in the GraphQL query (indicated by
+        dict values in the fields parameter) are wrapped in {"value": ...} structures
+        to match the value wrapper types. Other fields are returned as-is.
+        
+        Args:
+            objs: List of Branch objects to map.
+            fields: Dictionary mapping field names to their selection sets.
+            
+        Returns:
+            List of dictionaries with field data, excluding GLOBAL_BRANCH_NAME.
+        """
         field_keys = fields.keys()
         result: list[dict[str, Any]] = []
         for obj in objs:
             if obj.name == GLOBAL_BRANCH_NAME:
                 continue
             data = {}
             for field in field_keys:
                 value = getattr(obj, field, None)
+                # Wrap value if the field has nested selections (value wrapper types)
                 if isinstance(fields.get(field), dict):
                     data[field] = {"value": value}
                 else:
                     data[field] = value
             result.append(data)
         return result
🤖 Prompt for AI Agents
In backend/infrahub/graphql/types/branch.py around lines 92 to 107, add a
concise docstring to _map_fields_to_graphql describing the method’s purpose,
parameters (objs: list[Branch], fields: dict) and return type (list[dict[str,
Any]]), and clarify the field-wrapping logic: when fields.get(field) is a dict
it indicates GraphQL requested subfields so the code should wrap the raw
attribute value in {"value": <attr>} to match the expected nested GraphQL shape;
update the docstring to explain this behavior and its rationale so the wrapping
on lines 102-105 is clear to readers and maintainers.



class InfrahubBranchEdge(InfrahubObjectType):
    """Edge wrapper holding a single InfrahubBranch node in paginated results."""

    node = Field(InfrahubBranch, required=True)


class InfrahubBranchType(InfrahubObjectType):
    """Paginated branch container: a total count plus a list of edges."""

    count = Field(Int, description="Total number of items")
    edges = Field(NonNull(List(of_type=NonNull(InfrahubBranchEdge))))

    @classmethod
    async def get_list_count(cls, graphql_context: GraphqlContext, **kwargs: Any) -> int:
        """Return the number of branches, excluding the global branch.

        Args:
            graphql_context: GraphQL context providing the database connection.
            **kwargs: Extra filters forwarded to ``Branch.get_list_count``.

        Returns:
            The branch count, reduced by one when the global branch exists.
        """
        async with graphql_context.db.start_session(read_only=True) as db:
            count = await Branch.get_list_count(db=db, **kwargs)
            try:
                # The global branch is hidden from API consumers, so subtract
                # it from the total when it is present in the database.
                await Branch.get_by_name(name=GLOBAL_BRANCH_NAME, db=db)
                return count - 1
            except BranchNotFoundError:
                return count
Comment on lines +121 to +129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add method docstring (regression).

This method's docstring was marked as addressed in a previous commit (32c72ad) but is missing in the current code. Per coding guidelines, add a Google-style docstring documenting the purpose, parameters, and return value.

Apply this diff:

     @classmethod
     async def get_list_count(cls, graphql_context: GraphqlContext, **kwargs: Any) -> int:
+        """Get the total count of branches excluding the global branch.
+        
+        Args:
+            graphql_context: The GraphQL context containing database connection.
+            **kwargs: Additional filter parameters passed to Branch.get_list_count.
+            
+        Returns:
+            The count of branches, excluding GLOBAL_BRANCH_NAME if it exists.
+        """
         async with graphql_context.db.start_session(read_only=True) as db:
🤖 Prompt for AI Agents
In backend/infrahub/graphql/types/branch.py around lines 121 to 129, the async
classmethod get_list_count is missing its Google-style docstring; add a
docstring directly under the method definition that briefly describes what the
method does (returns the number of Branch records accessible via the GraphQL
context, excluding the GLOBAL_BRANCH_NAME when present), documents parameters
(graphql_context: GraphqlContext, **kwargs: Any) and their roles, and specifies
the return type (int) and behavior (subtracts one if GLOBAL_BRANCH_NAME exists,
otherwise returns full count); keep wording concise and follow Google docstring
format for purpose, Args, and Returns.

92 changes: 92 additions & 0 deletions backend/tests/unit/graphql/queries/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,95 @@ async def test_branch_query(
assert id_response.data
assert id_response.data["Branch"][0]["name"] == "branch3"
assert len(id_response.data["Branch"]) == 1

async def test_paginated_branch_query(
    self,
    db: InfrahubDatabase,
    default_branch: Branch,
    register_core_models_schema,
    session_admin,
    client,
    service,
):
    """Verify pagination of the ``InfrahubBranch`` GraphQL query.

    Creates 10 branches via the ``BranchCreate`` mutation, then checks that:
      * ``count`` reports the total number of branches (the API excludes the
        internal global branch from the count), and
      * ``offset``/``limit`` return exactly the matching slice of the
        server's ordering.
    """
    for i in range(10):
        create_branch_query = """
        mutation {
            BranchCreate(data: { name: "%s", description: "%s" }) {
                ok
                object {
                    id
                    name
                }
            }
        }
        """ % (
            f"sample-branch-{i}",
            f"sample description {i}",
        )

        gql_params = await prepare_graphql_params(
            db=db,
            branch=default_branch,
            account_session=session_admin,
            service=service,
        )
        branch_result = await graphql(
            schema=gql_params.schema,
            source=create_branch_query,
            context_value=gql_params.context,
            root_value=None,
            variable_values={},
        )
        assert branch_result.errors is None
        assert branch_result.data

    query = """
    query {
        InfrahubBranch(offset: 2, limit: 5) {
            count
            edges {
                node {
                    name {
                        value
                    }
                    description {
                        value
                    }
                }
            }
        }
    }
    """
    gql_params = await prepare_graphql_params(db=db, branch=default_branch, service=service)
    paginated_result = await graphql(
        schema=gql_params.schema,
        source=query,
        context_value=gql_params.context,
        root_value=None,
        variable_values={},
    )
    assert paginated_result.errors is None
    assert paginated_result.data
    # 10 branches created here + branch3 (created by the earlier test) + main.
    # NOTE(review): this assumes the db fixture is shared with the previous
    # test in this class -- confirm; if fixtures reset the database, the
    # expected count is 11.
    assert paginated_result.data["InfrahubBranch"]["count"] == 12

    # Establish the server's full ordering with an unpaginated baseline query,
    # then validate the paginated call against it. Comparing a 5-item page to
    # a hand-built full expected set can never succeed; the previous version
    # only "passed" because `list.sort()` returns None on both sides.
    baseline_query = """
    query {
        InfrahubBranch(offset: 0, limit: 1000) {
            edges {
                node {
                    name {
                        value
                    }
                    description {
                        value
                    }
                }
            }
        }
    }
    """
    baseline_result = await graphql(
        schema=gql_params.schema,
        source=baseline_query,
        context_value=gql_params.context,
        root_value=None,
        variable_values={},
    )
    assert baseline_result.errors is None
    assert baseline_result.data
    all_nodes = [edge["node"] for edge in baseline_result.data["InfrahubBranch"]["edges"]]

    # The requested page (offset=2, limit=5) must be exactly that slice of
    # the server ordering.
    page_nodes = [edge["node"] for edge in paginated_result.data["InfrahubBranch"]["edges"]]
    assert len(page_nodes) == 5
    assert page_nodes == all_nodes[2:7]
Comment on lines +232 to +235
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Replace `list.sort()` with `sorted(...)`; the current assertion always compares None to None.

`list.sort()` sorts in place and returns `None`, so the assertion as written is vacuous. Use `sorted(...)` (or call `.sort()` on each list first and then compare). The refactor above removes this pitfall entirely and validates the actual page slice deterministically.

🤖 Prompt for AI Agents
In backend/tests/unit/graphql/queries/test_branch.py around lines 232 to 235,
the test uses list.sort() inside the assertion which returns None so the
assertion compares None to None; replace list.sort() with sorted(...) (or call
.sort() on each list before the assertion) so you compare the actual sorted
lists deterministically. Ensure both all_branches_data_only and
expected_branches are sorted by the same key (lambda x: x["name"]["value"]) and
then assert equality.

Loading
Loading