| 
 | 1 | +from unittest.mock import call, patch  | 
 | 2 | + | 
 | 3 | +import pytest  | 
 | 4 | + | 
 | 5 | +from infrahub.auth import AccountSession, AuthType  | 
 | 6 | +from infrahub.context import InfrahubContext  | 
 | 7 | +from infrahub.core.branch import Branch  | 
 | 8 | +from infrahub.core.constants import InfrahubKind  | 
 | 9 | +from infrahub.core.node import Node  | 
 | 10 | +from infrahub.database import InfrahubDatabase  | 
 | 11 | +from infrahub.generators.models import ProposedChangeGeneratorDefinition, RequestGeneratorDefinitionRun  | 
 | 12 | +from infrahub.graphql.initialization import prepare_graphql_params  | 
 | 13 | +from infrahub.services import InfrahubServices  | 
 | 14 | +from infrahub.services.adapters.workflow.local import WorkflowLocalExecution  | 
 | 15 | +from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN  | 
 | 16 | +from tests.adapters.message_bus import BusRecorder  | 
 | 17 | +from tests.helpers.graphql import graphql  | 
 | 18 | + | 
 | 19 | + | 
 | 20 | +@pytest.fixture  | 
 | 21 | +async def group1(db: InfrahubDatabase, car_person_data_generic: dict[str, Node]) -> Node:  | 
 | 22 | +    g1 = await Node.init(db=db, schema=InfrahubKind.STANDARDGROUP)  | 
 | 23 | +    await g1.new(db=db, name="group1", members=[car_person_data_generic["c1"], car_person_data_generic["c2"]])  | 
 | 24 | +    await g1.save(db=db)  | 
 | 25 | +    return g1  | 
 | 26 | + | 
 | 27 | + | 
 | 28 | +@pytest.fixture  | 
 | 29 | +async def definition1(db: InfrahubDatabase, car_person_data_generic: dict[str, Node], group1: Node) -> Node:  | 
 | 30 | +    gd1 = await Node.init(db=db, schema=InfrahubKind.GENERATORDEFINITION)  | 
 | 31 | +    await gd1.new(  | 
 | 32 | +        db=db,  | 
 | 33 | +        name="generatordef01",  | 
 | 34 | +        query=str(car_person_data_generic["q1"].id),  | 
 | 35 | +        repository=str(car_person_data_generic["r1"].id),  | 
 | 36 | +        file_path="generator01.py",  | 
 | 37 | +        class_name="Generator01",  | 
 | 38 | +        targets=str(group1.id),  | 
 | 39 | +        parameters={"value": {"name": "name__value"}},  | 
 | 40 | +    )  | 
 | 41 | +    await gd1.save(db=db)  | 
 | 42 | +    return gd1  | 
 | 43 | + | 
 | 44 | + | 
 | 45 | +async def test_run_generator_definition(  | 
 | 46 | +    db: InfrahubDatabase,  | 
 | 47 | +    default_branch: Branch,  | 
 | 48 | +    register_core_models_schema,  | 
 | 49 | +    car_person_data_generic,  | 
 | 50 | +    create_test_admin: Node,  | 
 | 51 | +    definition1: Node,  | 
 | 52 | +):  | 
 | 53 | +    query = """  | 
 | 54 | +    mutation {  | 
 | 55 | +        CoreGeneratorDefinitionRun(data: { id: "%s" }, wait_until_completion: false) {  | 
 | 56 | +            ok  | 
 | 57 | +        }  | 
 | 58 | +    }  | 
 | 59 | +    """ % (definition1.id)  | 
 | 60 | +    recorder = BusRecorder()  | 
 | 61 | +    service = await InfrahubServices.new(message_bus=recorder, workflow=WorkflowLocalExecution())  | 
 | 62 | + | 
 | 63 | +    account_session = AccountSession(  | 
 | 64 | +        authenticated=True, account_id=create_test_admin.id, session_id=None, auth_type=AuthType.API  | 
 | 65 | +    )  | 
 | 66 | +    gql_params = await prepare_graphql_params(  | 
 | 67 | +        db=db, include_subscription=False, branch=default_branch, service=service, account_session=account_session  | 
 | 68 | +    )  | 
 | 69 | + | 
 | 70 | +    with patch(  | 
 | 71 | +        "infrahub.services.adapters.workflow.local.WorkflowLocalExecution.submit_workflow"  | 
 | 72 | +    ) as mock_submit_workflow:  | 
 | 73 | +        result = await graphql(  | 
 | 74 | +            schema=gql_params.schema,  | 
 | 75 | +            source=query,  | 
 | 76 | +            context_value=gql_params.context,  | 
 | 77 | +            root_value=None,  | 
 | 78 | +            variable_values={},  | 
 | 79 | +        )  | 
 | 80 | + | 
 | 81 | +        assert not result.errors  | 
 | 82 | +        assert result.data  | 
 | 83 | +        assert result.data["CoreGeneratorDefinitionRun"]["ok"]  | 
 | 84 | + | 
 | 85 | +        context = InfrahubContext.init(branch=default_branch, account=account_session)  | 
 | 86 | +        query = await definition1.query.get_peer(db=db)  | 
 | 87 | +        repository = await definition1.repository.get_peer(db=db)  | 
 | 88 | +        group = await definition1.targets.get_peer(db=db)  | 
 | 89 | +        expected_calls = [  | 
 | 90 | +            call(  | 
 | 91 | +                workflow=REQUEST_GENERATOR_DEFINITION_RUN,  | 
 | 92 | +                parameters={  | 
 | 93 | +                    "model": RequestGeneratorDefinitionRun(  | 
 | 94 | +                        generator_definition=ProposedChangeGeneratorDefinition(  | 
 | 95 | +                            definition_id=definition1.id,  | 
 | 96 | +                            definition_name=definition1.name.value,  | 
 | 97 | +                            class_name=definition1.class_name.value,  | 
 | 98 | +                            file_path=definition1.file_path.value,  | 
 | 99 | +                            query_name=query.name.value,  | 
 | 100 | +                            query_models=query.models.value,  | 
 | 101 | +                            repository_id=repository.id,  | 
 | 102 | +                            parameters=definition1.parameters.value,  | 
 | 103 | +                            group_id=group.id,  | 
 | 104 | +                            convert_query_response=definition1.convert_query_response.value,  | 
 | 105 | +                        ),  | 
 | 106 | +                        branch=context.branch.name,  | 
 | 107 | +                    )  | 
 | 108 | +                },  | 
 | 109 | +                context=context,  | 
 | 110 | +            ),  | 
 | 111 | +        ]  | 
 | 112 | +        mock_submit_workflow.assert_has_calls(expected_calls)  | 
0 commit comments