- 
                Notifications
    
You must be signed in to change notification settings  - Fork 35
 
IFC-1886: Paginated branch graphql query #7418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 14 commits
1102ae7
              4e16ec7
              c0b3b31
              9a07e23
              41b1849
              c8fbca1
              dd242d9
              9f1dc4a
              3a7ff90
              e659340
              2ad4b2d
              32c72ad
              a38cec2
              84fac0a
              f7d4828
              491ad04
              86dbefa
              9eba0b8
              9fe2763
              8e4dca9
              f0aa561
              089b015
              fb7d0a6
              1157cc4
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -2,10 +2,10 @@ | |
| 
     | 
||
| from typing import TYPE_CHECKING, Any | ||
| 
     | 
||
| from graphene import ID, Field, List, NonNull, String | ||
| from graphene import ID, Field, Int, List, NonNull, String | ||
| 
     | 
||
| from infrahub.graphql.field_extractor import extract_graphql_fields | ||
| from infrahub.graphql.types import BranchType | ||
| from infrahub.graphql.types import BranchType, InfrahubBranch, InfrahubBranchType | ||
| 
     | 
||
| if TYPE_CHECKING: | ||
| from graphql import GraphQLResolveInfo | ||
| 
        
          
        
         | 
    @@ -28,3 +28,33 @@ async def branch_resolver( | |
| resolver=branch_resolver, | ||
| required=True, | ||
| ) | ||
| 
     | 
||
| 
     | 
||
| async def infrahub_branch_resolver( | ||
| root: dict, # noqa: ARG001 | ||
| info: GraphQLResolveInfo, | ||
| limit: int | None = None, | ||
| offset: int | None = None, | ||
| ) -> dict[str, Any]: | ||
| 
         
      Comment on lines
    
      +33
     to 
      +38
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add validation for pagination parameters. The  Apply this diff to add validation:  async def infrahub_branch_resolver(
     root: dict,  # noqa: ARG001
     info: GraphQLResolveInfo,
     limit: int | None = None,
     offset: int | None = None,
 ) -> dict[str, Any]:
+    """..."""  # Add docstring as per previous comment
+    
+    # Validate and set defaults for pagination parameters
+    if limit is None:
+        limit = 100
+    if offset is None:
+        offset = 0
+    
+    if limit < 1:
+        raise ValueError("limit must be at least 1")
+    if offset < 0:
+        raise ValueError("offset must be non-negative")
+    
     fields = extract_graphql_fields(info)🤖 Prompt for AI Agents | 
||
| fields = extract_graphql_fields(info) | ||
| result: dict[str, Any] = {} | ||
| if "edges" in fields: | ||
| branches = await InfrahubBranch.get_list( | ||
| graphql_context=info.context, fields=fields.get("edges", {}).get("node", {}), limit=limit, offset=offset | ||
| ) | ||
| result["edges"] = [{"node": branch} for branch in branches] | ||
| if "count" in fields: | ||
| result["count"] = await InfrahubBranchType.get_list_count(graphql_context=info.context) | ||
| return result | ||
| 
         
      Comment on lines
    
      +33
     to 
      +48
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Resolver ignores ids/name and forwards None pagination; fix signature, validation, and forwarding. Currently ids/name are accepted by the field but dropped by the resolver; passing limit/offset=None can also override model defaults. Add proper params, validate, and only pass provided args. As per coding guidelines, add a docstring too. Apply: -async def infrahub_branch_resolver(
-    root: dict,  # noqa: ARG001
-    info: GraphQLResolveInfo,
-    limit: int | None = None,
-    offset: int | None = None,
-) -> dict[str, Any]:
-    fields = extract_graphql_fields(info)
-    result: dict[str, Any] = {}
-    if "edges" in fields:
-        branches = await InfrahubBranch.get_list(
-            graphql_context=info.context, fields=fields.get("edges", {}).get("node", {}), limit=limit, offset=offset
-        )
-        result["edges"] = [{"node": branch} for branch in branches]
-    if "count" in fields:
-        result["count"] = await InfrahubBranchType.get_list_count(graphql_context=info.context)
-    return result
+async def infrahub_branch_resolver(
+    root: dict,  # noqa: ARG001
+    info: GraphQLResolveInfo,
+    ids: list[str] | None = None,
+    name: str | None = None,
+    limit: int | None = None,
+    offset: int | None = None,
+) -> dict[str, Any]:
+    """Resolve paginated InfrahubBranch query.
+    
+    Args:
+        root: Unused GraphQL root value.
+        info: Resolve info with context and selection set.
+        ids: Optional list of branch IDs to filter.
+        name: Optional branch name filter.
+        limit: Max items to return; must be >= 1 if provided.
+        offset: Number of items to skip; must be >= 0 if provided.
+    Returns:
+        Dict containing requested fields (edges and/or count).
+    """
+    fields = extract_graphql_fields(info)
+    result: dict[str, Any] = {}
+
+    # Normalize pagination; only pass if provided and valid
+    query_kwargs: dict[str, int] = {}
+    if isinstance(limit, int):
+        if limit < 1:
+            raise ValueError("limit must be >= 1")
+        query_kwargs["limit"] = limit
+    if isinstance(offset, int):
+        if offset < 0:
+            raise ValueError("offset must be >= 0")
+        query_kwargs["offset"] = offset
+
+    if "edges" in fields:
+        node_fields = fields.get("edges", {}).get("node", {})
+        branches = await InfrahubBranch.get_list(
+            graphql_context=info.context,
+            fields=node_fields,
+            ids=ids,
+            name=name,
+            **query_kwargs,
+        )
+        result["edges"] = [{"node": branch} for branch in branches]
+
+    if "count" in fields:
+        result["count"] = await InfrahubBranchType.get_list_count(
+            graphql_context=info.context, ids=ids, name=name
+        )
+    return resultAs per coding guidelines. 🤖 Prompt for AI Agents | 
||
| 
     | 
||
| 
     | 
||
| InfrahubBranchQueryList = Field( | ||
| InfrahubBranchType, | ||
| ids=List(of_type=NonNull(ID)), | ||
| name=String(), | ||
                
       | 
||
| offset=Int(), | ||
| limit=Int(), | ||
| description="Retrieve paginated information about active branches.", | ||
| resolver=infrahub_branch_resolver, | ||
| required=True, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -2,11 +2,12 @@ | |
| 
     | 
||
| from typing import TYPE_CHECKING, Any | ||
| 
     | 
||
| from graphene import Boolean, Field, String | ||
| from graphene import Boolean, Field, Int, List, NonNull, String | ||
| 
     | 
||
| from infrahub.core.branch import Branch | ||
| from infrahub.core.constants import GLOBAL_BRANCH_NAME | ||
| 
     | 
||
| from ...exceptions import BranchNotFoundError | ||
| from .standard_node import InfrahubObjectType | ||
| 
     | 
||
| if TYPE_CHECKING: | ||
| 
        
          
        
         | 
    @@ -30,6 +31,10 @@ class Meta: | |
| name = "Branch" | ||
| model = Branch | ||
| 
     | 
||
| @staticmethod | ||
| async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]: | ||
| return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME] | ||
| 
         
      Comment on lines
    
      +36
     to 
      +38
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Add method docstring. The  As per coding guidelines, apply this diff:      @staticmethod
     async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]:
+        """Map Branch objects to GraphQL field dictionaries, excluding the global branch.
+        
+        Args:
+            objs: List of Branch objects to map.
+            fields: Dictionary of field names to include in the output.
+            
+        Returns:
+            List of dictionaries containing GraphQL field data for each branch.
+        """
         return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME]🤖 Prompt for AI Agents | 
||
| 
     | 
||
| @classmethod | ||
| async def get_list( | ||
| cls, | ||
| 
        
          
        
         | 
    @@ -43,4 +48,72 @@ async def get_list( | |
| if not objs: | ||
| return [] | ||
| 
     | 
||
| return [await obj.to_graphql(fields=fields) for obj in objs if obj.name != GLOBAL_BRANCH_NAME] | ||
| return await cls._map_fields_to_graphql(objs=objs, fields=fields) | ||
| 
     | 
||
| 
     | 
||
| class RequiredStringValueField(InfrahubObjectType): | ||
| value = String(required=True) | ||
| 
     | 
||
| 
     | 
||
| class NonRequiredStringValueField(InfrahubObjectType): | ||
| value = String(required=False) | ||
| 
     | 
||
| 
     | 
||
| class NonRequiredBooleanValueField(InfrahubObjectType): | ||
| value = Boolean(required=False) | ||
| 
     | 
||
| 
     | 
||
| class InfrahubBranch(BranchType): | ||
| id = String(required=True) | ||
| created_at = String(required=False) | ||
| 
     | 
||
| name = Field(RequiredStringValueField, required=True) | ||
| description = Field(NonRequiredStringValueField, required=False) | ||
| origin_branch = Field(NonRequiredStringValueField, required=False) | ||
| branched_from = Field(NonRequiredStringValueField, required=False) | ||
| sync_with_git = Field(NonRequiredBooleanValueField, required=False) | ||
| is_default = Field(NonRequiredBooleanValueField, required=False) | ||
| is_isolated = Field( | ||
| NonRequiredBooleanValueField, required=False, deprecation_reason="non isolated mode is not supported anymore" | ||
| ) | ||
| has_schema_changes = Field(NonRequiredBooleanValueField, required=False) | ||
| 
     | 
||
| class Meta: | ||
| description = "InfrahubBranch" | ||
| name = "InfrahubBranch" | ||
| 
     | 
||
| @staticmethod | ||
| async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]: | ||
| field_keys = fields.keys() | ||
| result: list[dict[str, Any]] = [] | ||
| for obj in objs: | ||
| if obj.name == GLOBAL_BRANCH_NAME: | ||
| continue | ||
| data = {} | ||
| for field in field_keys: | ||
| value = getattr(obj, field, None) | ||
| if isinstance(fields.get(field), dict): | ||
| data[field] = {"value": value} | ||
| else: | ||
| data[field] = value | ||
| result.append(data) | ||
| return result | ||
| 
         
      Comment on lines
    
      92
     to 
      110
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Add method docstring and clarify field wrapping logic. The  As per coding guidelines, apply this diff:      @staticmethod
     async def _map_fields_to_graphql(objs: list[Branch], fields: dict) -> list[dict[str, Any]]:
+        """Map Branch objects to GraphQL field dictionaries with wrapped field values.
+        
+        Fields that have nested field selections in the GraphQL query (indicated by
+        dict values in the fields parameter) are wrapped in {"value": ...} structures
+        to match the value wrapper types. Other fields are returned as-is.
+        
+        Args:
+            objs: List of Branch objects to map.
+            fields: Dictionary mapping field names to their selection sets.
+            
+        Returns:
+            List of dictionaries with field data, excluding GLOBAL_BRANCH_NAME.
+        """
         field_keys = fields.keys()
         result: list[dict[str, Any]] = []
         for obj in objs:
             if obj.name == GLOBAL_BRANCH_NAME:
                 continue
             data = {}
             for field in field_keys:
                 value = getattr(obj, field, None)
+                # Wrap value if the field has nested selections (value wrapper types)
                 if isinstance(fields.get(field), dict):
                     data[field] = {"value": value}
                 else:
                     data[field] = value
             result.append(data)
         return result🤖 Prompt for AI Agents | 
||
| 
     | 
||
| 
     | 
||
| class InfrahubBranchEdge(InfrahubObjectType): | ||
| node = Field(InfrahubBranch, required=True) | ||
| 
     | 
||
| 
     | 
||
| class InfrahubBranchType(InfrahubObjectType): | ||
| count = Field(Int, description="Total number of items") | ||
| edges = Field(NonNull(List(of_type=NonNull(InfrahubBranchEdge)))) | ||
| 
     | 
||
| @classmethod | ||
| async def get_list_count(cls, graphql_context: GraphqlContext, **kwargs: Any) -> int: | ||
| async with graphql_context.db.start_session(read_only=True) as db: | ||
| count = await Branch.get_list_count(db=db, **kwargs) | ||
| try: | ||
| await Branch.get_by_name(name=GLOBAL_BRANCH_NAME, db=db) | ||
| return count - 1 | ||
| except BranchNotFoundError: | ||
| return count | ||
| 
         
      Comment on lines
    
      +121
     to 
      +129
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Add method docstring (regression). This method's docstring was marked as addressed in a previous commit (32c72ad) but is missing in the current code. Per coding guidelines, add a Google-style docstring documenting the purpose, parameters, and return value. Apply this diff:      @classmethod
     async def get_list_count(cls, graphql_context: GraphqlContext, **kwargs: Any) -> int:
+        """Get the total count of branches excluding the global branch.
+        
+        Args:
+            graphql_context: The GraphQL context containing database connection.
+            **kwargs: Additional filter parameters passed to Branch.get_list_count.
+            
+        Returns:
+            The count of branches, excluding GLOBAL_BRANCH_NAME if it exists.
+        """
         async with graphql_context.db.start_session(read_only=True) as db:🤖 Prompt for AI Agents | 
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
Add docstring and type annotation.
The class is missing a docstring and type annotation, which are required by the coding guidelines.
Apply this diff to add the missing docstring and type annotation:
As per coding guidelines.
📝 Committable suggestion
🤖 Prompt for AI Agents