Skip to content

Commit 7a0b12d

Browse files
authored
Merge pull request #5819 from opsmill/dga-20250222-diff-update-task
Update DiffUpdate mutation to return the id of the task when running asynchronously
2 parents 388e07e + 3146269 commit 7a0b12d

File tree

5 files changed

+114
-24
lines changed

5 files changed

+114
-24
lines changed

backend/infrahub/core/diff/tasks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from infrahub.log import get_logger
99
from infrahub.services import services
1010
from infrahub.workflows.catalogue import DIFF_REFRESH
11-
from infrahub.workflows.utils import add_branch_tag
11+
from infrahub.workflows.utils import add_tags
1212

1313
log = get_logger()
1414

1515

1616
@flow(name="diff-update", flow_run_name="Update diff for branch {model.branch_name}")
1717
async def update_diff(model: RequestDiffUpdate) -> None:
1818
service = services.service
19-
await add_branch_tag(branch_name=model.branch_name)
19+
await add_tags(branches=[model.branch_name])
2020

2121
async with service.database.start_session() as db:
2222
component_registry = get_component_registry()
@@ -37,7 +37,7 @@ async def update_diff(model: RequestDiffUpdate) -> None:
3737
@flow(name="diff-refresh", flow_run_name="Recreate diff for branch {branch_name}")
3838
async def refresh_diff(branch_name: str, diff_id: str) -> None:
3939
service = services.service
40-
await add_branch_tag(branch_name=branch_name)
40+
await add_tags(branches=[branch_name])
4141

4242
async with service.database.start_session() as db:
4343
component_registry = get_component_registry()
@@ -51,7 +51,7 @@ async def refresh_diff(branch_name: str, diff_id: str) -> None:
5151
@flow(name="diff-refresh-all", flow_run_name="Recreate all diffs for branch {branch_name}")
5252
async def refresh_diff_all(branch_name: str) -> None:
5353
service = services.service
54-
await add_branch_tag(branch_name=branch_name)
54+
await add_tags(branches=[branch_name])
5555

5656
async with service.database.start_session() as db:
5757
component_registry = get_component_registry()

backend/infrahub/graphql/mutations/diff.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TYPE_CHECKING
22

3-
from graphene import Boolean, DateTime, InputObjectType, Mutation, String
3+
from graphene import Boolean, DateTime, Field, InputObjectType, Mutation, String
44
from graphql import GraphQLResolveInfo
55

66
from infrahub.core import registry
@@ -14,6 +14,8 @@
1414
from infrahub.exceptions import ValidationError
1515
from infrahub.workflows.catalogue import DIFF_UPDATE
1616

17+
from ..types.task import TaskInfo
18+
1719
if TYPE_CHECKING:
1820
from ..initialization import GraphqlContext
1921

@@ -23,14 +25,16 @@ class DiffUpdateInput(InputObjectType):
2325
name = String(required=False)
2426
from_time = DateTime(required=False)
2527
to_time = DateTime(required=False)
26-
wait_for_completion = Boolean(required=False)
28+
wait_for_completion = Boolean(required=False, deprecation_reason="Please use `wait_until_completion` instead")
2729

2830

2931
class DiffUpdateMutation(Mutation):
3032
class Arguments:
3133
data = DiffUpdateInput(required=True)
34+
wait_until_completion = Boolean(required=False)
3235

3336
ok = Boolean()
37+
task = Field(TaskInfo, required=False)
3438

3539
@classmethod
3640
@retry_db_transaction(name="diff_update")
@@ -39,9 +43,13 @@ async def mutate(
3943
root: dict, # pylint: disable=unused-argument
4044
info: GraphQLResolveInfo,
4145
data: DiffUpdateInput,
42-
) -> dict[str, bool]:
46+
wait_until_completion: bool = False,
47+
) -> dict[str, bool | dict[str, str]]:
4348
context: GraphqlContext = info.context
4449

50+
if data.wait_for_completion is True:
51+
wait_until_completion = True
52+
4553
from_timestamp_str = DateTime.serialize(data.from_time) if data.from_time else None
4654
to_timestamp_str = DateTime.serialize(data.to_time) if data.to_time else None
4755
if (data.from_time or data.to_time) and not data.name:
@@ -53,11 +61,11 @@ async def mutate(
5361
diff_repository = await component_registry.get_component(DiffRepository, db=context.db, branch=diff_branch)
5462

5563
tracking_id = NameTrackingId(name=data.name)
56-
existing_diffs_metatdatas = await diff_repository.get_roots_metadata(
64+
existing_diffs_metadatas = await diff_repository.get_roots_metadata(
5765
diff_branch_names=[diff_branch.name], base_branch_names=[base_branch.name], tracking_id=tracking_id
5866
)
59-
if existing_diffs_metatdatas:
60-
metadata = existing_diffs_metatdatas[0]
67+
if existing_diffs_metadatas:
68+
metadata = existing_diffs_metadatas[0]
6169
from_time = Timestamp(from_timestamp_str) if from_timestamp_str else None
6270
to_time = Timestamp(to_timestamp_str) if to_timestamp_str else None
6371
branched_from_timestamp = Timestamp(diff_branch.get_branched_from())
@@ -68,7 +76,7 @@ async def mutate(
6876
if to_time and to_time < metadata.to_time:
6977
raise ValidationError(f"to_time must be null or greater than or equal to {metadata.to_time}")
7078

71-
if data.wait_for_completion is True:
79+
if wait_until_completion is True:
7280
diff_coordinator = await component_registry.get_component(
7381
DiffCoordinator, db=context.db, branch=diff_branch
7482
)
@@ -89,6 +97,7 @@ async def mutate(
8997
to_time=to_timestamp_str,
9098
)
9199
if context.service:
92-
await context.service.workflow.submit_workflow(workflow=DIFF_UPDATE, parameters={"model": model})
100+
workflow = await context.service.workflow.submit_workflow(workflow=DIFF_UPDATE, parameters={"model": model})
101+
return {"ok": True, "task": {"id": str(workflow.id)}}
93102

94103
return {"ok": True}

backend/tests/integration/diff/test_diff_update.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
BRANCH_NAME = "branch1"
3333
PROPOSED_CHANGE_NAME = "branch1-pc"
3434
DIFF_UPDATE_QUERY = """
35-
mutation DiffUpdate($branch_name: String!, $wait_for_completion: Boolean) {
36-
DiffUpdate(data: { branch: $branch_name, wait_for_completion: $wait_for_completion }) {
35+
mutation DiffUpdate($branch_name: String!, $wait_for_completion: Boolean = true) {
36+
DiffUpdate(data: { branch: $branch_name }, wait_until_completion: $wait_for_completion) {
3737
ok
3838
}
3939
}

backend/tests/unit/graphql/diff/test_diff_update_mutation.py

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,27 @@
88
from infrahub.core.timestamp import Timestamp
99
from infrahub.database import InfrahubDatabase
1010
from infrahub.graphql.initialization import prepare_graphql_params
11+
from infrahub.services import InfrahubServices, services
12+
from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
13+
from tests.adapters.cache import MemoryCache
14+
from tests.adapters.message_bus import BusRecorder
1115
from tests.helpers.graphql import graphql
1216

1317
DIFF_UPDATE_MUTATION = """
14-
mutation($branch: String!, $name: String, $from_time: DateTime, $to_time: DateTime) {
15-
DiffUpdate(data: {branch: $branch, name: $name, from_time: $from_time, to_time: $to_time, wait_for_completion: true}) {
18+
mutation($branch: String!, $name: String, $from_time: DateTime, $to_time: DateTime, $wait_until_completion: Boolean = true) {
19+
DiffUpdate(
20+
data: {
21+
branch: $branch,
22+
name: $name,
23+
from_time: $from_time,
24+
to_time: $to_time
25+
},
26+
wait_until_completion: $wait_until_completion
27+
) {
1628
ok
29+
task {
30+
id
31+
}
1732
}
1833
}
1934
"""
@@ -22,16 +37,34 @@
2237
class TestDiffUpdateMutation:
2338
diff_name = "CountDiffula"
2439

40+
@pytest.fixture
41+
def service_testing(self, db: InfrahubDatabase):
42+
original = services.service
43+
service = InfrahubServices(
44+
database=db, message_bus=BusRecorder(), workflow=WorkflowLocalExecution(), cache=MemoryCache()
45+
)
46+
services.service = service
47+
services.prepare(service=services.service)
48+
yield service
49+
services.service = original
50+
services.prepare(service=services.service)
51+
2552
@pytest.fixture
2653
async def diff_branch(self, db: InfrahubDatabase) -> Branch:
2754
return await create_branch(db=db, branch_name="branch")
2855

2956
@pytest.fixture
3057
async def named_diff(
31-
self, db: InfrahubDatabase, default_branch: Branch, criticality_schema, diff_branch: Branch
58+
self,
59+
db: InfrahubDatabase,
60+
default_branch: Branch,
61+
prefect_test_fixture,
62+
service_testing: InfrahubServices,
63+
criticality_schema,
64+
diff_branch: Branch,
3265
) -> EnrichedDiffRootMetadata:
3366
params = await prepare_graphql_params(
34-
db=db, include_mutation=True, include_subscription=False, branch=default_branch
67+
db=db, include_mutation=True, include_subscription=False, branch=default_branch, service=service_testing
3568
)
3669
result = await graphql(
3770
schema=params.schema,
@@ -53,11 +86,17 @@ async def named_diff(
5386
)[0]
5487

5588
async def test_create_diff_before_branched_from_fails(
56-
self, db: InfrahubDatabase, default_branch: Branch, criticality_schema, diff_branch: Branch
89+
self,
90+
db: InfrahubDatabase,
91+
default_branch: Branch,
92+
prefect_test_fixture,
93+
service_testing: InfrahubServices,
94+
criticality_schema,
95+
diff_branch: Branch,
5796
):
5897
branched_from_timestamp = Timestamp(diff_branch.get_branched_from())
5998
params = await prepare_graphql_params(
60-
db=db, include_mutation=True, include_subscription=False, branch=default_branch
99+
db=db, include_mutation=True, include_subscription=False, branch=default_branch, service=service_testing
61100
)
62101
result = await graphql(
63102
schema=params.schema,
@@ -74,11 +113,17 @@ async def test_create_diff_before_branched_from_fails(
74113
assert result.data["DiffUpdate"]["ok"] is True
75114

76115
async def test_create_time_range_diff_without_name_fails(
77-
self, db: InfrahubDatabase, default_branch: Branch, criticality_schema, diff_branch: Branch
116+
self,
117+
db: InfrahubDatabase,
118+
default_branch: Branch,
119+
prefect_test_fixture,
120+
service_testing: InfrahubServices,
121+
criticality_schema,
122+
diff_branch: Branch,
78123
):
79124
branched_from_timestamp = Timestamp(diff_branch.get_branched_from())
80125
params = await prepare_graphql_params(
81-
db=db, include_mutation=True, include_subscription=False, branch=default_branch
126+
db=db, include_mutation=True, include_subscription=False, branch=default_branch, service=service_testing
82127
)
83128
result = await graphql(
84129
schema=params.schema,
@@ -99,12 +144,14 @@ async def test_create_diff_with_illegal_times_fails(
99144
self,
100145
db: InfrahubDatabase,
101146
default_branch: Branch,
147+
prefect_test_fixture,
148+
service_testing: InfrahubServices,
102149
criticality_schema,
103150
diff_branch: Branch,
104151
named_diff: EnrichedDiffRootMetadata,
105152
):
106153
params = await prepare_graphql_params(
107-
db=db, include_mutation=True, include_subscription=False, branch=default_branch
154+
db=db, include_mutation=True, include_subscription=False, branch=default_branch, service=service_testing
108155
)
109156
result = await graphql(
110157
schema=params.schema,
@@ -140,13 +187,44 @@ async def test_create_named_diff_with_legal_times_succeeds(
140187
self,
141188
db: InfrahubDatabase,
142189
default_branch: Branch,
190+
prefect_test_fixture,
191+
service_testing: InfrahubServices,
192+
criticality_schema,
193+
diff_branch: Branch,
194+
named_diff: EnrichedDiffRootMetadata,
195+
):
196+
branched_from_timestamp = Timestamp(diff_branch.get_branched_from())
197+
params = await prepare_graphql_params(
198+
db=db, include_mutation=True, include_subscription=False, branch=default_branch, service=service_testing
199+
)
200+
result = await graphql(
201+
schema=params.schema,
202+
source=DIFF_UPDATE_MUTATION,
203+
context_value=params.context,
204+
root_value=None,
205+
variable_values={
206+
"branch": diff_branch.name,
207+
"from_time": branched_from_timestamp.to_string(),
208+
"to_time": Timestamp().to_string(),
209+
"name": self.diff_name,
210+
},
211+
)
212+
assert result.errors is None
213+
assert result.data["DiffUpdate"]["ok"] is True
214+
215+
async def test_retrieve_task_id(
216+
self,
217+
db: InfrahubDatabase,
218+
default_branch: Branch,
219+
prefect_test_fixture,
220+
service_testing: InfrahubServices,
143221
criticality_schema,
144222
diff_branch: Branch,
145223
named_diff: EnrichedDiffRootMetadata,
146224
):
147225
branched_from_timestamp = Timestamp(diff_branch.get_branched_from())
148226
params = await prepare_graphql_params(
149-
db=db, include_mutation=True, include_subscription=False, branch=default_branch
227+
db=db, include_mutation=True, include_subscription=False, branch=default_branch, service=service_testing
150228
)
151229
result = await graphql(
152230
schema=params.schema,
@@ -158,7 +236,9 @@ async def test_create_named_diff_with_legal_times_succeeds(
158236
"from_time": branched_from_timestamp.to_string(),
159237
"to_time": Timestamp().to_string(),
160238
"name": self.diff_name,
239+
"wait_until_completion": False,
161240
},
162241
)
163242
assert result.errors is None
164243
assert result.data["DiffUpdate"]["ok"] is True
244+
assert result.data["DiffUpdate"]["task"]["id"] is not None

changelog/+diff-update.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update DiffUpdate mutation to return the id of the task when `wait_until_completion` is False. Also the argument `wait_for_completion` under data is deprecated and it has been replaced with `wait_until_completion` at the root of the mutation instead to align with the format of the other mutations.

0 commit comments

Comments
 (0)