Skip to content

Commit d83ca7f

Browse files
authored
Merge pull request #4874 from opsmill/dga-20241107-pc-merge
Convert proposed change merge to a prefect task
2 parents c05e713 + e9e17ce commit d83ca7f

File tree

18 files changed

+384
-194
lines changed

18 files changed

+384
-194
lines changed

backend/infrahub/core/branch/tasks.py

Lines changed: 62 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -148,68 +148,71 @@ async def merge_branch(branch: str) -> None:
148148
await add_branch_tag(branch_name=branch)
149149
await add_branch_tag(branch_name=registry.default_branch)
150150

151-
obj = await Branch.get_by_name(db=service.database, name=branch)
152-
component_registry = get_component_registry()
153-
154-
merger: BranchMerger | None = None
155-
async with lock.registry.global_graph_lock():
156-
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=service.database, branch=obj)
157-
diff_merger = await component_registry.get_component(DiffMerger, db=service.database, branch=obj)
158-
merger = BranchMerger(
159-
db=service.database,
160-
diff_coordinator=diff_coordinator,
161-
diff_merger=diff_merger,
162-
source_branch=obj,
163-
service=service,
164-
)
165-
try:
166-
await merger.merge()
167-
except Exception as exc:
168-
await merger.rollback()
169-
raise MergeFailedError(branch_name=branch) from exc
170-
await merger.update_schema()
171-
172-
if merger and merger.migrations:
173-
errors = await schema_apply_migrations(
174-
message=SchemaApplyMigrationData(
175-
branch=merger.destination_branch,
176-
new_schema=merger.destination_schema,
177-
previous_schema=merger.initial_source_schema,
178-
migrations=merger.migrations,
151+
async with service.database.start_session() as db:
152+
obj = await Branch.get_by_name(db=db, name=branch)
153+
component_registry = get_component_registry()
154+
155+
merger: BranchMerger | None = None
156+
async with lock.registry.global_graph_lock():
157+
# await update_diff(model=RequestDiffUpdate(branch_name=obj.name))
158+
159+
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=obj)
160+
diff_merger = await component_registry.get_component(DiffMerger, db=db, branch=obj)
161+
merger = BranchMerger(
162+
db=db,
163+
diff_coordinator=diff_coordinator,
164+
diff_merger=diff_merger,
165+
source_branch=obj,
166+
service=service,
179167
)
168+
try:
169+
await merger.merge()
170+
except Exception as exc:
171+
await merger.rollback()
172+
raise MergeFailedError(branch_name=branch) from exc
173+
await merger.update_schema()
174+
175+
if merger and merger.migrations:
176+
errors = await schema_apply_migrations(
177+
message=SchemaApplyMigrationData(
178+
branch=merger.destination_branch,
179+
new_schema=merger.destination_schema,
180+
previous_schema=merger.initial_source_schema,
181+
migrations=merger.migrations,
182+
)
183+
)
184+
for error in errors:
185+
log.error(error)
186+
187+
# -------------------------------------------------------------
188+
# Trigger the reconciliation of IPAM data after the merge
189+
# -------------------------------------------------------------
190+
differ = await merger.get_graph_diff()
191+
diff_parser = IpamDiffParser(
192+
db=db,
193+
differ=differ,
194+
source_branch_name=obj.name,
195+
target_branch_name=registry.default_branch,
196+
)
197+
ipam_node_details = await diff_parser.get_changed_ipam_node_details()
198+
await service.workflow.submit_workflow(
199+
workflow=IPAM_RECONCILIATION,
200+
parameters={"branch": registry.default_branch, "ipam_node_details": ipam_node_details},
180201
)
181-
for error in errors:
182-
log.error(error)
183-
184-
# -------------------------------------------------------------
185-
# Trigger the reconciliation of IPAM data after the merge
186-
# -------------------------------------------------------------
187-
differ = await merger.get_graph_diff()
188-
diff_parser = IpamDiffParser(
189-
db=service.database,
190-
differ=differ,
191-
source_branch_name=obj.name,
192-
target_branch_name=registry.default_branch,
193-
)
194-
ipam_node_details = await diff_parser.get_changed_ipam_node_details()
195-
await service.workflow.submit_workflow(
196-
workflow=IPAM_RECONCILIATION,
197-
parameters={"branch": registry.default_branch, "ipam_node_details": ipam_node_details},
198-
)
199202

200-
# -------------------------------------------------------------
201-
# Generate an event to indicate that a branch has been merged
202-
# NOTE: we still need to convert this event and potentially pull
203-
# some tasks currently executed based on the event into this workflow
204-
# -------------------------------------------------------------
205-
log_data = get_log_data()
206-
request_id = log_data.get("request_id", "")
207-
message = messages.EventBranchMerge(
208-
source_branch=obj.name,
209-
target_branch=registry.default_branch,
210-
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=request_id),
211-
)
212-
await service.send(message=message)
203+
# -------------------------------------------------------------
204+
# Generate an event to indicate that a branch has been merged
205+
# NOTE: we still need to convert this event and potentially pull
206+
# some tasks currently executed based on the event into this workflow
207+
# -------------------------------------------------------------
208+
log_data = get_log_data()
209+
request_id = log_data.get("request_id", "")
210+
message = messages.EventBranchMerge(
211+
source_branch=obj.name,
212+
target_branch=registry.default_branch,
213+
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=request_id),
214+
)
215+
await service.send(message=message)
213216

214217

215218
@flow(name="branch-delete", flow_run_name="Delete branch {branch}")

backend/infrahub/core/constants/__init__.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"SchemaElementPathType",
1616
"UpdateSupport",
1717
"UpdateValidationErrorType",
18+
"ValidationError",
1819
]
1920

2021

@@ -207,37 +208,6 @@ def from_relationship(cls, relationship: RelationshipCardinality) -> PathType:
207208
return cls("relationship_many")
208209

209210

210-
class ProposedChangeState(InfrahubStringEnum):
211-
OPEN = "open"
212-
MERGED = "merged"
213-
CLOSED = "closed"
214-
CANCELED = "canceled"
215-
216-
def validate_state_check_run(self) -> None:
217-
if self == ProposedChangeState.OPEN:
218-
return
219-
220-
raise ValidationError(input_value="Unable to trigger check on proposed changes that aren't in the open state")
221-
222-
def validate_editability(self) -> None:
223-
if self in [ProposedChangeState.CANCELED, ProposedChangeState.MERGED]:
224-
raise ValidationError(
225-
input_value=f"A proposed change in the {self.value} state is not allowed to be updated"
226-
)
227-
228-
def validate_state_transition(self, updated_state: ProposedChangeState) -> None:
229-
if self == ProposedChangeState.OPEN:
230-
return
231-
232-
if self == ProposedChangeState.CLOSED and updated_state not in [
233-
ProposedChangeState.CANCELED,
234-
ProposedChangeState.OPEN,
235-
]:
236-
raise ValidationError(
237-
input_value="A closed proposed change is only allowed to transition to the open state"
238-
)
239-
240-
241211
class HashableModelState(InfrahubStringEnum):
242212
PRESENT = "present"
243213
ABSENT = "absent"

backend/infrahub/core/diff/data_check_synchronizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from enum import Enum
22

3-
from infrahub.core.constants import BranchConflictKeep, InfrahubKind, ProposedChangeState
3+
from infrahub.core.constants import BranchConflictKeep, InfrahubKind
44
from infrahub.core.integrity.object_conflict.conflict_recorder import ObjectConflictValidatorRecorder
55
from infrahub.core.manager import NodeManager
66
from infrahub.core.node import Node
77
from infrahub.database import InfrahubDatabase
88
from infrahub.exceptions import SchemaNotFoundError
9+
from infrahub.proposed_change.constants import ProposedChangeState
910

1011
from .conflicts_extractor import DiffConflictsExtractor
1112
from .model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot

backend/infrahub/core/diff/tasks.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66
from infrahub.dependencies.registry import get_component_registry
77
from infrahub.log import get_logger
88
from infrahub.services import services
9+
from infrahub.workflows.utils import add_branch_tag
910

1011
log = get_logger()
1112

1213

13-
@flow(name="diff-update")
14+
@flow(name="diff-update", flow_run_name="Update diff for branch {model.branch_name}")
1415
async def update_diff(model: RequestDiffUpdate) -> None:
1516
service = services.service
17+
await add_branch_tag(branch_name=model.branch_name)
18+
1619
component_registry = get_component_registry()
1720
base_branch = await registry.get_branch(db=service.database, branch=registry.default_branch)
1821
diff_branch = await registry.get_branch(db=service.database, branch=model.branch_name)
@@ -28,9 +31,10 @@ async def update_diff(model: RequestDiffUpdate) -> None:
2831
)
2932

3033

31-
@flow(name="diff-refresh")
34+
@flow(name="diff-refresh", flow_run_name="Recreate diff for branch {model.branch_name}")
3235
async def refresh_diff(model: RequestDiffRefresh) -> None:
3336
service = services.service
37+
await add_branch_tag(branch_name=model.branch_name)
3438

3539
component_registry = get_component_registry()
3640
base_branch = await registry.get_branch(db=service.database, branch=registry.default_branch)

backend/infrahub/core/relationship/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -843,23 +843,23 @@ async def get_peer(
843843
async def get_peers(
844844
self,
845845
db: InfrahubDatabase,
846-
branch_agnostic: bool,
847846
peer_type: type[PeerType],
847+
branch_agnostic: bool = ...,
848848
) -> Mapping[str, PeerType]: ...
849849

850850
@overload
851851
async def get_peers(
852852
self,
853853
db: InfrahubDatabase,
854-
branch_agnostic: bool,
855854
peer_type: Literal[None] = None,
855+
branch_agnostic: bool = ...,
856856
) -> Mapping[str, Node]: ...
857857

858858
async def get_peers(
859859
self,
860860
db: InfrahubDatabase,
861-
branch_agnostic: bool = False,
862861
peer_type: type[PeerType] | None = None, # pylint: disable=unused-argument
862+
branch_agnostic: bool = False,
863863
) -> Mapping[str, Node | PeerType]:
864864
rels = await self.get_relationships(db=db, branch_agnostic=branch_agnostic)
865865
peer_ids = [rel.peer_id for rel in rels if rel.peer_id]

backend/infrahub/core/schema/definitions/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
InfrahubKind,
1818
PermissionAction,
1919
PermissionDecision,
20-
ProposedChangeState,
2120
RelationshipDeleteBehavior,
2221
RepositoryInternalStatus,
2322
RepositoryOperationalStatus,
@@ -26,6 +25,7 @@
2625
ValidatorConclusion,
2726
ValidatorState,
2827
)
28+
from infrahub.proposed_change.constants import ProposedChangeState
2929

3030
# pylint: disable=too-many-lines
3131

backend/infrahub/database/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ async def execute_query_with_metadata(
349349

350350
with QUERY_EXECUTION_METRICS.labels(**labels).time():
351351
response = await self.run_query(query=query, params=params, name=name)
352+
if response is None:
353+
return [], {}
352354
results = [item async for item in response]
353355
return results, response._metadata or {}
354356

backend/infrahub/graphql/mutations/branch.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING, Any
44

55
import pydantic
6-
from graphene import Boolean, Field, InputField, InputObjectType, Mutation, ObjectType, String
6+
from graphene import Boolean, Field, InputField, InputObjectType, Mutation, String
77
from infrahub_sdk.utils import extract_fields, extract_fields_first_node
88
from opentelemetry import trace
99
from typing_extensions import Self
@@ -25,6 +25,7 @@
2525
)
2626

2727
from ..types import BranchType
28+
from ..types.task import TaskInfo
2829

2930
if TYPE_CHECKING:
3031
from graphql import GraphQLResolveInfo
@@ -37,10 +38,6 @@
3738
log = get_logger()
3839

3940

40-
class TaskInfo(ObjectType):
41-
id = Field(String)
42-
43-
4441
class BranchCreateInput(InputObjectType):
4542
id = String(required=False)
4643
name = String(required=True)

0 commit comments

Comments
 (0)