Skip to content

Commit 504b8f1

Browse files
authored
IFC-1053 Add mutation to trigger run of a generator (#5791)
1 parent 33b9c90 commit 504b8f1

File tree

7 files changed

+314
-5
lines changed

7 files changed

+314
-5
lines changed

backend/infrahub/generators/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class RequestGeneratorDefinitionRun(BaseModel):
3030

3131
generator_definition: ProposedChangeGeneratorDefinition = Field(..., description="The Generator Definition")
3232
branch: str = Field(..., description="The branch to target")
33+
target_members: list[str] = Field(default_factory=list, description="List of targets to run the generator for")
3334

3435

3536
class ProposedChangeGeneratorDefinition(BaseModel):

backend/infrahub/generators/tasks.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
from typing import TYPE_CHECKING, Any
5+
36
from infrahub_sdk.exceptions import ModuleImportError
47
from infrahub_sdk.node import InfrahubNode
58
from infrahub_sdk.protocols import CoreGeneratorInstance
69
from infrahub_sdk.schema.repository import InfrahubGeneratorDefinitionConfig
7-
from prefect import flow, task
10+
from prefect import State, flow, task
811
from prefect.cache_policies import NONE
12+
from prefect.states import Completed, Failed
913

1014
from infrahub import lock
1115
from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow
@@ -21,6 +25,9 @@
2125
from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN, REQUEST_GENERATOR_RUN
2226
from infrahub.workflows.utils import add_tags
2327

28+
if TYPE_CHECKING:
29+
from collections.abc import Coroutine
30+
2431

2532
@flow(
2633
name="generator-run",
@@ -154,7 +161,7 @@ async def run_generator_definition(branch: str, context: InfrahubContext, servic
154161
)
155162
async def request_generator_definition_run(
156163
model: RequestGeneratorDefinitionRun, context: InfrahubContext, service: InfrahubServices
157-
) -> None:
164+
) -> State[Any]:
158165
await add_tags(branches=[model.branch], nodes=[model.generator_definition.definition_id])
159166

160167
group = await service.client.get(
@@ -190,8 +197,13 @@ async def request_generator_definition_run(
190197
raise_when_missing=True,
191198
)
192199

200+
tasks: list[Coroutine[Any, Any, Any]] = []
193201
for relationship in group.members.peers:
194202
member = relationship.peer
203+
204+
if model.target_members and member.id not in model.target_members:
205+
continue
206+
195207
generator_instance = instance_by_member.get(member.id)
196208
request_generator_run_model = RequestGeneratorRun(
197209
generator_definition=model.generator_definition,
@@ -206,6 +218,14 @@ async def request_generator_definition_run(
206218
target_id=member.id,
207219
target_name=member.display_label,
208220
)
209-
await service.workflow.submit_workflow(
210-
workflow=REQUEST_GENERATOR_RUN, context=context, parameters={"model": request_generator_run_model}
221+
tasks.append(
222+
service.workflow.execute_workflow(
223+
workflow=REQUEST_GENERATOR_RUN, context=context, parameters={"model": request_generator_run_model}
224+
)
211225
)
226+
227+
try:
228+
await asyncio.gather(*tasks)
229+
return Completed(message=f"Successfully run {len(tasks)} generators")
230+
except Exception as exc:
231+
return Failed(message="One or more generators failed", error=exc)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from graphene import Boolean, Field, InputField, InputObjectType, List, Mutation, NonNull, String
6+
7+
from infrahub.core.manager import NodeManager
8+
from infrahub.generators.models import ProposedChangeGeneratorDefinition, RequestGeneratorDefinitionRun
9+
from infrahub.graphql.types.task import TaskInfo
10+
from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN
11+
12+
if TYPE_CHECKING:
13+
from graphql import GraphQLResolveInfo
14+
15+
from ..initialization import GraphqlContext
16+
17+
18+
class GeneratorDefinitionRequestRunInput(InputObjectType):
19+
id = InputField(String(required=True), description="ID of the generator definition to run")
20+
nodes = InputField(List(of_type=NonNull(String)), description="ID list of targets to run the generator for")
21+
22+
23+
class GeneratorDefinitionRequestRun(Mutation):
24+
class Arguments:
25+
data = GeneratorDefinitionRequestRunInput(required=True)
26+
wait_until_completion = Boolean(required=False)
27+
28+
ok = Boolean()
29+
task = Field(TaskInfo, required=False)
30+
31+
@classmethod
32+
async def mutate(
33+
cls,
34+
root: dict, # noqa: ARG003
35+
info: GraphQLResolveInfo,
36+
data: GeneratorDefinitionRequestRunInput,
37+
wait_until_completion: bool = True,
38+
) -> GeneratorDefinitionRequestRun:
39+
graphql_context: GraphqlContext = info.context
40+
db = graphql_context.db
41+
42+
generator_definition = await NodeManager.get_one(
43+
id=str(data.id), db=db, branch=graphql_context.branch, prefetch_relationships=True, raise_on_error=True
44+
)
45+
query = await generator_definition.query.get_peer(db=db)
46+
repository = await generator_definition.repository.get_peer(db=db)
47+
group = await generator_definition.targets.get_peer(db=db)
48+
49+
request_model = RequestGeneratorDefinitionRun(
50+
generator_definition=ProposedChangeGeneratorDefinition(
51+
definition_id=generator_definition.id,
52+
definition_name=generator_definition.name.value,
53+
class_name=generator_definition.class_name.value,
54+
file_path=generator_definition.file_path.value,
55+
query_name=query.name.value,
56+
query_models=query.models.value,
57+
repository_id=repository.id,
58+
parameters=generator_definition.parameters.value,
59+
group_id=group.id,
60+
convert_query_response=generator_definition.convert_query_response.value or False,
61+
),
62+
branch=graphql_context.branch.name,
63+
target_members=data.get("nodes", []),
64+
)
65+
66+
if not wait_until_completion:
67+
workflow = await graphql_context.active_service.workflow.submit_workflow(
68+
workflow=REQUEST_GENERATOR_DEFINITION_RUN,
69+
context=graphql_context.get_context(),
70+
parameters={"model": request_model},
71+
)
72+
return cls(ok=True, task={"id": workflow.id})
73+
74+
await graphql_context.active_service.workflow.execute_workflow(
75+
workflow=REQUEST_GENERATOR_DEFINITION_RUN,
76+
context=graphql_context.get_context(),
77+
parameters={"model": request_model},
78+
)
79+
return cls(ok=True)

backend/infrahub/graphql/schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .mutations.computed_attribute import UpdateComputedAttribute
1919
from .mutations.diff import DiffUpdateMutation
2020
from .mutations.diff_conflict import ResolveDiffConflict
21+
from .mutations.generator import GeneratorDefinitionRequestRun
2122
from .mutations.proposed_change import ProposedChangeMerge, ProposedChangeRequestRunCheck
2223
from .mutations.relationship import (
2324
RelationshipAdd,
@@ -83,6 +84,7 @@ class InfrahubBaseMutation(ObjectType):
8384
InfrahubAccountTokenDelete = InfrahubAccountTokenDelete.Field()
8485
CoreProposedChangeRunCheck = ProposedChangeRequestRunCheck.Field()
8586
CoreProposedChangeMerge = ProposedChangeMerge.Field()
87+
CoreGeneratorDefinitionRun = GeneratorDefinitionRequestRun.Field()
8688

8789
IPPrefixPoolGetResource = IPPrefixPoolGetResource.Field()
8890
IPAddressPoolGetResource = IPAddressPoolGetResource.Field()
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import pytest
6+
from infrahub_sdk.graphql import Mutation
7+
from infrahub_sdk.protocols import CoreGeneratorDefinition
8+
from tests.constants import TestKind
9+
from tests.helpers.file_repo import FileRepo
10+
from tests.helpers.schema import CAR_SCHEMA, load_schema
11+
from tests.helpers.test_app import TestInfrahubApp
12+
13+
from infrahub.core.constants import InfrahubKind
14+
from infrahub.core.node import Node
15+
from infrahub.services.adapters.cache.redis import RedisCache
16+
17+
if TYPE_CHECKING:
18+
from pathlib import Path
19+
20+
from infrahub_sdk import InfrahubClient
21+
from tests.adapters.message_bus import BusSimulator
22+
23+
from infrahub.database import InfrahubDatabase
24+
25+
26+
class TestMutationGenerator(TestInfrahubApp):
27+
@pytest.fixture(scope="class")
28+
async def initial_dataset(
29+
self,
30+
db: InfrahubDatabase,
31+
initialize_registry: None,
32+
git_repos_source_dir_module_scope: Path,
33+
client: InfrahubClient,
34+
bus_simulator: BusSimulator,
35+
prefect_test_fixture,
36+
) -> None:
37+
await load_schema(db, schema=CAR_SCHEMA)
38+
39+
bus_simulator.service._cache = RedisCache()
40+
41+
john = await Node.init(schema=TestKind.PERSON, db=db)
42+
await john.new(db=db, name="John", height=175, age=25, description="The famous Joe Doe")
43+
await john.save(db=db)
44+
koenigsegg = await Node.init(schema=TestKind.MANUFACTURER, db=db)
45+
await koenigsegg.new(db=db, name="Koenigsegg")
46+
await koenigsegg.save(db=db)
47+
people = await Node.init(schema=InfrahubKind.STANDARDGROUP, db=db)
48+
await people.new(db=db, name="people", members=[john])
49+
await people.save(db=db)
50+
51+
jesko = await Node.init(schema=TestKind.CAR, db=db)
52+
await jesko.new(
53+
db=db,
54+
name="Jesko",
55+
color="Red",
56+
description="A limited production mid-engine sports car",
57+
owner=john,
58+
manufacturer=koenigsegg,
59+
)
60+
await jesko.save(db=db)
61+
62+
branch1 = await client.branch.create(branch_name="branch1")
63+
64+
FileRepo(name="car-dealership", sources_directory=git_repos_source_dir_module_scope)
65+
client_repository = await client.create(
66+
kind=InfrahubKind.REPOSITORY,
67+
data={"name": "car-dealership", "location": f"{git_repos_source_dir_module_scope}/car-dealership"},
68+
branch=branch1.name,
69+
)
70+
await client_repository.save()
71+
72+
richard = await Node.init(schema=TestKind.PERSON, db=db, branch=branch1.name)
73+
await richard.new(db=db, name="Richard", height=180, description="The less famous Richard Doe")
74+
await richard.save(db=db)
75+
76+
async def test_execute_generator(self, db: InfrahubDatabase, initial_dataset: None, client: InfrahubClient) -> None:
77+
generator = await client.get(kind=CoreGeneratorDefinition, branch="branch1", name__value="cartags")
78+
mutation = Mutation(
79+
mutation="CoreGeneratorDefinitionRun", input_data={"data": {"id": generator.id}}, query={"ok": None}
80+
)
81+
response = await client.execute_graphql(query=mutation.render(), branch_name="branch1")
82+
assert response["CoreGeneratorDefinitionRun"]["ok"]
83+
84+
async def test_execute_generator_background(
85+
self, db: InfrahubDatabase, initial_dataset: None, client: InfrahubClient
86+
) -> None:
87+
generator = await client.get(kind=CoreGeneratorDefinition, branch="branch1", name__value="cartags")
88+
mutation = Mutation(
89+
mutation="CoreGeneratorDefinitionRun",
90+
input_data={"data": {"id": generator.id}, "wait_until_completion": False},
91+
query={"ok": None, "task": {"id": None}},
92+
)
93+
response = await client.execute_graphql(query=mutation.render(), branch_name="branch1")
94+
assert response["CoreGeneratorDefinitionRun"]["ok"]
95+
assert response["CoreGeneratorDefinitionRun"]["task"]["id"]

backend/tests/unit/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,7 @@ async def car_person_data_generic(db: InfrahubDatabase, register_core_models_sch
12051205
"""
12061206

12071207
q1 = await Node.init(db=db, schema=InfrahubKind.GRAPHQLQUERY)
1208-
await q1.new(db=db, name="query01", query=query)
1208+
await q1.new(db=db, name="query01", query=query, models=["TestPerson"])
12091209
await q1.save(db=db)
12101210

12111211
r1 = await Node.init(db=db, schema=InfrahubKind.REPOSITORY)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from unittest.mock import call, patch
2+
3+
import pytest
4+
5+
from infrahub.auth import AccountSession, AuthType
6+
from infrahub.context import InfrahubContext
7+
from infrahub.core.branch import Branch
8+
from infrahub.core.constants import InfrahubKind
9+
from infrahub.core.node import Node
10+
from infrahub.database import InfrahubDatabase
11+
from infrahub.generators.models import ProposedChangeGeneratorDefinition, RequestGeneratorDefinitionRun
12+
from infrahub.graphql.initialization import prepare_graphql_params
13+
from infrahub.services import InfrahubServices
14+
from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
15+
from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN
16+
from tests.adapters.message_bus import BusRecorder
17+
from tests.helpers.graphql import graphql
18+
19+
20+
@pytest.fixture
21+
async def group1(db: InfrahubDatabase, car_person_data_generic: dict[str, Node]) -> Node:
22+
g1 = await Node.init(db=db, schema=InfrahubKind.STANDARDGROUP)
23+
await g1.new(db=db, name="group1", members=[car_person_data_generic["c1"], car_person_data_generic["c2"]])
24+
await g1.save(db=db)
25+
return g1
26+
27+
28+
@pytest.fixture
29+
async def definition1(db: InfrahubDatabase, car_person_data_generic: dict[str, Node], group1: Node) -> Node:
30+
gd1 = await Node.init(db=db, schema=InfrahubKind.GENERATORDEFINITION)
31+
await gd1.new(
32+
db=db,
33+
name="generatordef01",
34+
query=str(car_person_data_generic["q1"].id),
35+
repository=str(car_person_data_generic["r1"].id),
36+
file_path="generator01.py",
37+
class_name="Generator01",
38+
targets=str(group1.id),
39+
parameters={"value": {"name": "name__value"}},
40+
)
41+
await gd1.save(db=db)
42+
return gd1
43+
44+
45+
async def test_run_generator_definition(
46+
db: InfrahubDatabase,
47+
default_branch: Branch,
48+
register_core_models_schema,
49+
car_person_data_generic,
50+
create_test_admin: Node,
51+
definition1: Node,
52+
):
53+
query = """
54+
mutation {
55+
CoreGeneratorDefinitionRun(data: { id: "%s" }, wait_until_completion: false) {
56+
ok
57+
}
58+
}
59+
""" % (definition1.id)
60+
recorder = BusRecorder()
61+
service = await InfrahubServices.new(message_bus=recorder, workflow=WorkflowLocalExecution())
62+
63+
account_session = AccountSession(
64+
authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API
65+
)
66+
gql_params = await prepare_graphql_params(
67+
db=db, include_subscription=False, branch=default_branch, service=service, account_session=account_session
68+
)
69+
70+
with patch(
71+
"infrahub.services.adapters.workflow.local.WorkflowLocalExecution.submit_workflow"
72+
) as mock_submit_workflow:
73+
result = await graphql(
74+
schema=gql_params.schema,
75+
source=query,
76+
context_value=gql_params.context,
77+
root_value=None,
78+
variable_values={},
79+
)
80+
81+
assert not result.errors
82+
assert result.data
83+
assert result.data["CoreGeneratorDefinitionRun"]["ok"]
84+
85+
context = InfrahubContext.init(branch=default_branch, account=account_session)
86+
query = await definition1.query.get_peer(db=db)
87+
repository = await definition1.repository.get_peer(db=db)
88+
group = await definition1.targets.get_peer(db=db)
89+
expected_calls = [
90+
call(
91+
workflow=REQUEST_GENERATOR_DEFINITION_RUN,
92+
parameters={
93+
"model": RequestGeneratorDefinitionRun(
94+
generator_definition=ProposedChangeGeneratorDefinition(
95+
definition_id=definition1.id,
96+
definition_name=definition1.name.value,
97+
class_name=definition1.class_name.value,
98+
file_path=definition1.file_path.value,
99+
query_name=query.name.value,
100+
query_models=query.models.value,
101+
repository_id=repository.id,
102+
parameters=definition1.parameters.value,
103+
group_id=group.id,
104+
convert_query_response=definition1.convert_query_response.value,
105+
),
106+
branch=context.branch.name,
107+
)
108+
},
109+
context=context,
110+
),
111+
]
112+
mock_submit_workflow.assert_has_calls(expected_calls)

0 commit comments

Comments
 (0)