Skip to content

Commit e95caeb

Browse files
committed
Add wait_until_completion to InfrahubBranchManager.create
1 parent a0afaa0 commit e95caeb

File tree

3 files changed

+92
-5
lines changed

3 files changed

+92
-5
lines changed

infrahub_sdk/branch.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Optional, Union
3+
import warnings
4+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload
45
from urllib.parse import urlencode
56

67
from pydantic import BaseModel
@@ -72,13 +73,42 @@ class InfrahubBranchManager(InfraHubBranchManagerBase):
7273
def __init__(self, client: InfrahubClient):
7374
self.client = client
7475

76+
@overload
7577
async def create(
7678
self,
7779
branch_name: str,
7880
sync_with_git: bool = True,
7981
description: str = "",
80-
background_execution: bool = False,
81-
) -> BranchData:
82+
wait_until_completion: Literal[True] = True,
83+
background_execution: Optional[bool] = False,
84+
) -> BranchData: ...
85+
86+
@overload
87+
async def create(
88+
self,
89+
branch_name: str,
90+
sync_with_git: bool = True,
91+
description: str = "",
92+
wait_until_completion: Literal[False] = False,
93+
background_execution: Optional[bool] = False,
94+
) -> str: ...
95+
96+
async def create(
97+
self,
98+
branch_name: str,
99+
sync_with_git: bool = True,
100+
description: str = "",
101+
wait_until_completion: bool = True,
102+
background_execution: Optional[bool] = False,
103+
) -> Union[BranchData, str]:
104+
if background_execution is not None:
105+
warnings.warn(
106+
"`background_execution` is deprecated, please use `wait_until_completion` instead.",
107+
DeprecationWarning,
108+
stacklevel=1,
109+
)
110+
111+
background_execution = background_execution or not wait_until_completion
82112
input_data = {
83113
"background_execution": background_execution,
84114
"data": {
@@ -91,6 +121,10 @@ async def create(
91121
query = Mutation(mutation="BranchCreate", input_data=input_data, query=MUTATION_QUERY_DATA)
92122
response = await self.client.execute_graphql(query=query.render(), tracker="mutation-branch-create")
93123

124+
# Make sure server version is recent enough to support background execution, as previously
125+
# using background_execution=True had no effect.
126+
if not wait_until_completion and "task" in response["BranchCreate"]:
127+
return BranchData(**response["BranchCreate"]["task"]["id"])
94128
return BranchData(**response["BranchCreate"]["object"])
95129

96130
async def delete(self, branch_name: str) -> bool:
@@ -209,13 +243,42 @@ def get(self, branch_name: str) -> BranchData:
209243
raise BranchNotFoundError(identifier=branch_name)
210244
return BranchData(**data["Branch"][0])
211245

246+
@overload
247+
def create(
248+
self,
249+
branch_name: str,
250+
sync_with_git: bool = True,
251+
description: str = "",
252+
wait_until_completion: Literal[True] = True,
253+
background_execution: Optional[bool] = False,
254+
) -> BranchData: ...
255+
256+
@overload
257+
def create(
258+
self,
259+
branch_name: str,
260+
sync_with_git: bool = True,
261+
description: str = "",
262+
wait_until_completion: Literal[False] = False,
263+
background_execution: Optional[bool] = False,
264+
) -> str: ...
265+
212266
def create(
213267
self,
214268
branch_name: str,
215269
sync_with_git: bool = True,
216270
description: str = "",
217-
background_execution: bool = False,
218-
) -> BranchData:
271+
wait_until_completion: bool = True,
272+
background_execution: Optional[bool] = False,
273+
) -> Union[BranchData, str]:
274+
if background_execution is not None:
275+
warnings.warn(
276+
"`background_execution` is deprecated, please use `wait_until_completion` instead.",
277+
DeprecationWarning,
278+
stacklevel=1,
279+
)
280+
281+
background_execution = background_execution or not wait_until_completion
219282
input_data = {
220283
"background_execution": background_execution,
221284
"data": {
@@ -228,6 +291,10 @@ def create(
228291
query = Mutation(mutation="BranchCreate", input_data=input_data, query=MUTATION_QUERY_DATA)
229292
response = self.client.execute_graphql(query=query.render(), tracker="mutation-branch-create")
230293

294+
# Make sure server version is recent enough to support background execution, as previously
295+
# using background_execution=True had no effect.
296+
if not wait_until_completion and "task" in response["BranchCreate"]:
297+
return BranchData(**response["BranchCreate"]["task"]["id"])
231298
return BranchData(**response["BranchCreate"]["object"])
232299

233300
def delete(self, branch_name: str) -> bool:

tests/integration/test_infrahub_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from infrahub.server import app
1010

1111
from infrahub_sdk import Config, InfrahubClient
12+
from infrahub_sdk.branch import BranchData
1213
from infrahub_sdk.constants import InfrahubClientMode
1314
from infrahub_sdk.exceptions import BranchNotFoundError
1415
from infrahub_sdk.node import InfrahubNode
@@ -283,3 +284,12 @@ async def test_profile(self, client: InfrahubClient, db: InfrahubDatabase, init_
283284

284285
obj1 = await client.get(kind="BuiltinStatus", id=obj.id)
285286
assert obj1.description.value == "description in profile"
287+
288+
async def test_create_branch(self, client: InfrahubClient, db: InfrahubDatabase, init_db_base, base_dataset):
289+
branch = await client.branch.create(branch_name="new-branch-1")
290+
assert isinstance(branch, BranchData)
291+
assert branch.id is not None
292+
293+
async def test_create_branch_async(self, client: InfrahubClient, db: InfrahubDatabase, init_db_base, base_dataset):
294+
task_id = await client.branch.create(branch_name="new-branch-2", wait_until_completion=False)
295+
assert isinstance(task_id, str)

tests/integration/test_infrahub_client_sync.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from infrahub.server import app
99

1010
from infrahub_sdk import Config, InfrahubClientSync
11+
from infrahub_sdk.branch import BranchData
1112
from infrahub_sdk.constants import InfrahubClientMode
1213
from infrahub_sdk.exceptions import BranchNotFoundError
1314
from infrahub_sdk.node import InfrahubNodeSync
@@ -285,3 +286,12 @@ def test_profile(self, client: InfrahubClientSync, db: InfrahubDatabase, init_db
285286

286287
obj1 = client.get(kind="BuiltinStatus", id=obj.id)
287288
assert obj1.description.value == "description in profile"
289+
290+
def test_create_branch(self, client: InfrahubClientSync, db: InfrahubDatabase, init_db_base, base_dataset):
291+
branch = client.branch.create(branch_name="new-branch-1")
292+
assert isinstance(branch, BranchData)
293+
assert branch.id is not None
294+
295+
def test_create_branch_async(self, client: InfrahubClientSync, db: InfrahubDatabase, init_db_base, base_dataset):
296+
task_id = client.branch.create(branch_name="new-branch-2", wait_until_completion=False)
297+
assert isinstance(task_id, str)

0 commit comments

Comments
 (0)