diff --git a/infrahub_sdk/branch.py b/infrahub_sdk/branch.py index 713c0063..18d32020 100644 --- a/infrahub_sdk/branch.py +++ b/infrahub_sdk/branch.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Union +import warnings +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload from urllib.parse import urlencode from pydantic import BaseModel @@ -72,14 +73,44 @@ class InfrahubBranchManager(InfraHubBranchManagerBase): def __init__(self, client: InfrahubClient): self.client = client + @overload async def create( self, branch_name: str, sync_with_git: bool = True, description: str = "", - background_execution: bool = False, - ) -> BranchData: + wait_until_completion: Literal[True] = True, + background_execution: Optional[bool] = False, + ) -> BranchData: ... + + @overload + async def create( + self, + branch_name: str, + sync_with_git: bool = True, + description: str = "", + wait_until_completion: Literal[False] = False, + background_execution: Optional[bool] = False, + ) -> str: ... + + async def create( + self, + branch_name: str, + sync_with_git: bool = True, + description: str = "", + wait_until_completion: bool = True, + background_execution: Optional[bool] = False, + ) -> Union[BranchData, str]: + if background_execution is not None: + warnings.warn( + "`background_execution` is deprecated, please use `wait_until_completion` instead.", + DeprecationWarning, + stacklevel=1, + ) + + background_execution = background_execution or not wait_until_completion input_data = { + # Should be switched to `wait_until_completion` once `background_execution` is removed server side. "background_execution": background_execution, "data": { "name": branch_name, @@ -91,6 +122,10 @@ async def create( query = Mutation(mutation="BranchCreate", input_data=input_data, query=MUTATION_QUERY_DATA) response = await self.client.execute_graphql(query=query.render(), tracker="mutation-branch-create") + # Make sure server version is recent enough to support background execution, as previously + # using background_execution=True had no effect. + if background_execution and "task" in response["BranchCreate"]: + return BranchData(**response["BranchCreate"]["task"]["id"]) return BranchData(**response["BranchCreate"]["object"]) async def delete(self, branch_name: str) -> bool: @@ -209,14 +244,44 @@ def get(self, branch_name: str) -> BranchData: raise BranchNotFoundError(identifier=branch_name) return BranchData(**data["Branch"][0]) + @overload + def create( + self, + branch_name: str, + sync_with_git: bool = True, + description: str = "", + wait_until_completion: Literal[True] = True, + background_execution: Optional[bool] = False, + ) -> BranchData: ... + + @overload + def create( + self, + branch_name: str, + sync_with_git: bool = True, + description: str = "", + wait_until_completion: Literal[False] = False, + background_execution: Optional[bool] = False, + ) -> str: ... + def create( self, branch_name: str, sync_with_git: bool = True, description: str = "", - background_execution: bool = False, - ) -> BranchData: + wait_until_completion: bool = True, + background_execution: Optional[bool] = False, + ) -> Union[BranchData, str]: + if background_execution is not None: + warnings.warn( + "`background_execution` is deprecated, please use `wait_until_completion` instead.", + DeprecationWarning, + stacklevel=1, + ) + + background_execution = background_execution or not wait_until_completion input_data = { + # Should be switched to `wait_until_completion` once `background_execution` is removed server side. "background_execution": background_execution, "data": { "name": branch_name, @@ -228,6 +293,10 @@ def create( query = Mutation(mutation="BranchCreate", input_data=input_data, query=MUTATION_QUERY_DATA) response = self.client.execute_graphql(query=query.render(), tracker="mutation-branch-create") + # Make sure server version is recent enough to support background execution, as previously + # using background_execution=True had no effect. + if background_execution and "task" in response["BranchCreate"]: + return BranchData(**response["BranchCreate"]["task"]["id"]) return BranchData(**response["BranchCreate"]["object"]) def delete(self, branch_name: str) -> bool: diff --git a/tests/integration/test_infrahub_client.py b/tests/integration/test_infrahub_client.py index b07d79e9..34703ab2 100644 --- a/tests/integration/test_infrahub_client.py +++ b/tests/integration/test_infrahub_client.py @@ -9,6 +9,7 @@ from infrahub.server import app from infrahub_sdk import Config, InfrahubClient +from infrahub_sdk.branch import BranchData from infrahub_sdk.constants import InfrahubClientMode from infrahub_sdk.exceptions import BranchNotFoundError from infrahub_sdk.node import InfrahubNode @@ -283,3 +284,12 @@ async def test_profile(self, client: InfrahubClient, db: InfrahubDatabase, init_ obj1 = await client.get(kind="BuiltinStatus", id=obj.id) assert obj1.description.value == "description in profile" + + async def test_create_branch(self, client: InfrahubClient, db: InfrahubDatabase, init_db_base, base_dataset): + branch = await client.branch.create(branch_name="new-branch-1") + assert isinstance(branch, BranchData) + assert branch.id is not None + + async def test_create_branch_async(self, client: InfrahubClient, db: InfrahubDatabase, init_db_base, base_dataset): + task_id = await client.branch.create(branch_name="new-branch-2", wait_until_completion=False) + assert isinstance(task_id, str) diff --git a/tests/integration/test_infrahub_client_sync.py b/tests/integration/test_infrahub_client_sync.py index 5be96418..75429d4e 100644 --- a/tests/integration/test_infrahub_client_sync.py +++ b/tests/integration/test_infrahub_client_sync.py @@ -8,6 +8,7 @@ from infrahub.server import app from infrahub_sdk import Config, InfrahubClientSync +from infrahub_sdk.branch import BranchData from infrahub_sdk.constants import InfrahubClientMode from infrahub_sdk.exceptions import BranchNotFoundError from infrahub_sdk.node import InfrahubNodeSync @@ -285,3 +286,12 @@ def test_profile(self, client: InfrahubClientSync, db: InfrahubDatabase, init_db obj1 = client.get(kind="BuiltinStatus", id=obj.id) assert obj1.description.value == "description in profile" + + def test_create_branch(self, client: InfrahubClientSync, db: InfrahubDatabase, init_db_base, base_dataset): + branch = client.branch.create(branch_name="new-branch-1") + assert isinstance(branch, BranchData) + assert branch.id is not None + + def test_create_branch_async(self, client: InfrahubClientSync, db: InfrahubDatabase, init_db_base, base_dataset): + task_id = client.branch.create(branch_name="new-branch-2", wait_until_completion=False) + assert isinstance(task_id, str)