Skip to content

Commit da2235a

Browse files
committed
Add integration tests for Prefect worker
1 parent 9be1d2f commit da2235a

File tree

19 files changed

+526
-120
lines changed

19 files changed

+526
-120
lines changed

backend/infrahub/api/schema.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -279,11 +279,13 @@ async def load_schema(
279279
schema_branch=candidate_schema,
280280
constraints=result.constraints,
281281
)
282-
error_messages = await service.workflow.execute( # type: ignore[var-annotated]
283-
workflow=SCHEMA_VALIDATE_MIGRATION, message=validate_migration_data
282+
error_messages = await service.workflow.execute_workflow(
283+
workflow=SCHEMA_VALIDATE_MIGRATION,
284+
expected_return=list[str],
285+
parameters={"message": validate_migration_data},
284286
)
285-
if error_messages: # type: ignore[has-type]
286-
raise SchemaNotValidError(message=",\n".join(error_messages)) # type: ignore[has-type]
287+
if error_messages:
288+
raise SchemaNotValidError(message=",\n".join(error_messages))
287289

288290
# ----------------------------------------------------------
289291
# Update the schema
@@ -320,8 +322,8 @@ async def load_schema(
320322
previous_schema=origin_schema,
321323
migrations=result.migrations,
322324
)
323-
migration_error_msgs = await service.workflow.execute( # type: ignore[var-annotated]
324-
workflow=SCHEMA_APPLY_MIGRATION, message=apply_migration_data
325+
migration_error_msgs = await service.workflow.execute_workflow(
326+
workflow=SCHEMA_APPLY_MIGRATION, expected_return=list[str], parameters={"message": apply_migration_data}
325327
)
326328

327329
if migration_error_msgs:
@@ -368,8 +370,8 @@ async def check_schema(
368370
schema_branch=candidate_schema,
369371
constraints=result.constraints,
370372
)
371-
error_messages = await service.workflow.execute( # type: ignore[var-annotated]
372-
workflow=SCHEMA_VALIDATE_MIGRATION, message=validate_migration_data
373+
error_messages = await service.workflow.execute_workflow(
374+
workflow=SCHEMA_VALIDATE_MIGRATION, expected_return=list[str], parameters={"message": validate_migration_data}
373375
)
374376
if error_messages:
375377
raise SchemaNotValidError(message=",\n".join(error_messages))

backend/infrahub/api/transformation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ async def transform_jinja2(
146146

147147
service: InfrahubServices = request.app.state.service
148148

149-
response: str = await service.workflow.execute(workflow=TRANSFORM_JINJA2_RENDER, message=message) # type: ignore[arg-type]
150-
149+
response = await service.workflow.execute_workflow(
150+
workflow=TRANSFORM_JINJA2_RENDER, expected_return=str, parameters={"message": message}
151+
)
151152
return PlainTextResponse(content=response)

backend/infrahub/cli/db.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,10 @@ async def update_core_schema( # pylint: disable=too-many-statements
243243
schema_branch=candidate_schema,
244244
constraints=result.constraints,
245245
)
246-
error_messages = await service.workflow.execute( # type: ignore[var-annotated]
247-
workflow=SCHEMA_VALIDATE_MIGRATION, message=validate_migration_data
246+
error_messages = await service.workflow.execute_workflow(
247+
workflow=SCHEMA_VALIDATE_MIGRATION,
248+
expected_return=list[str],
249+
parameters={"message": validate_migration_data},
248250
)
249251
if error_messages:
250252
rprint(f"{error_badge} | Unable to update the schema, due to failed validations")
@@ -286,8 +288,8 @@ async def update_core_schema( # pylint: disable=too-many-statements
286288
previous_schema=origin_schema,
287289
migrations=result.migrations,
288290
)
289-
migration_error_msgs = await service.workflow.execute( # type: ignore[var-annotated]
290-
workflow=SCHEMA_APPLY_MIGRATION, message=apply_migration_data
291+
migration_error_msgs = await service.workflow.execute_workflow(
292+
workflow=SCHEMA_APPLY_MIGRATION, expected_return=list[str], parameters={"message": apply_migration_data}
291293
)
292294

293295
if migration_error_msgs:

backend/infrahub/cli/tasks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,7 @@ async def execute(
4646
worker = WorkflowWorkerExecution()
4747
await DUMMY_FLOW.save(client=client, work_pool=WorkerPoolDefinition(name="testing", worker_type="process"))
4848

49-
result = await worker.execute(workflow=DUMMY_FLOW, data=DummyInput(firstname="John", lastname="Doe")) # type: ignore[var-annotated]
49+
result = await worker.execute_workflow(
50+
workflow=DUMMY_FLOW, parameters={"data": DummyInput(firstname="John", lastname="Doe")}
51+
) # type: ignore[var-annotated]
5052
print(result)

backend/infrahub/message_bus/operations/refresh/webhook.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ async def configuration(
1111
message: messages.RefreshWebhookConfiguration, # pylint: disable=unused-argument
1212
service: InfrahubServices,
1313
) -> None:
14+
if not service._client:
15+
service.log.error("Client hasn't been initialized, can't refresh webhook")
16+
return
17+
1418
service.log.debug("Refreshing webhook configuration")
1519
standard_webhooks = await service.client.all(kind=InfrahubKind.STANDARDWEBHOOK)
1620
custom_webhooks = await service.client.all(kind=InfrahubKind.CUSTOMWEBHOOK)
Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, ParamSpec, TypeVar
3+
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar, overload
44

55
if TYPE_CHECKING:
66
from infrahub.services import InfrahubServices
7-
from infrahub.workflows.models import WorkflowDefinition
7+
from infrahub.workflows.models import WorkflowDefinition, WorkflowInfo
88

99
Return = TypeVar("Return")
1010
Params = ParamSpec("Params")
@@ -16,10 +16,33 @@ class InfrahubWorkflow:
1616
async def initialize(self, service: InfrahubServices) -> None:
1717
"""Initialize the Workflow engine"""
1818

19-
async def execute(
19+
@overload
20+
async def execute_workflow(
2021
self,
21-
workflow: WorkflowDefinition | None = None,
22-
function: Callable[..., Awaitable[Return]] | None = None,
23-
**kwargs: Any,
24-
) -> Return:
22+
workflow: WorkflowDefinition,
23+
expected_return: type[Return],
24+
parameters: dict[str, Any] | None = ...,
25+
) -> Return: ...
26+
27+
@overload
28+
async def execute_workflow(
29+
self,
30+
workflow: WorkflowDefinition,
31+
expected_return: None = ...,
32+
parameters: dict[str, Any] | None = ...,
33+
) -> Any: ...
34+
35+
async def execute_workflow(
36+
self,
37+
workflow: WorkflowDefinition,
38+
expected_return: type[Return] | None = None,
39+
parameters: dict[str, Any] | None = None,
40+
) -> Any:
41+
raise NotImplementedError()
42+
43+
async def submit_workflow(
44+
self,
45+
workflow: WorkflowDefinition,
46+
parameters: dict[str, Any] | None = None,
47+
) -> WorkflowInfo:
2548
raise NotImplementedError()
Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
1-
from typing import Any, Awaitable, Callable
1+
import uuid
2+
from typing import Any
23

3-
from infrahub.workflows.models import WorkflowDefinition
4+
from infrahub.workflows.models import WorkflowDefinition, WorkflowInfo
45

56
from . import InfrahubWorkflow, Return
67

78

89
class WorkflowLocalExecution(InfrahubWorkflow):
9-
async def execute(
10+
async def execute_workflow(
1011
self,
11-
workflow: WorkflowDefinition | None = None,
12-
function: Callable[..., Awaitable[Return]] | None = None,
13-
**kwargs: Any,
14-
) -> Return:
15-
if workflow:
16-
fn = workflow.get_function()
17-
return await fn(**kwargs)
18-
if function:
19-
return await function(**kwargs)
20-
raise ValueError("either a workflow definition or a flow must be provided")
12+
workflow: WorkflowDefinition,
13+
expected_return: type[Return] | None = None,
14+
parameters: dict[str, Any] | None = None,
15+
) -> Any:
16+
fn = workflow.get_function()
17+
return await fn(**parameters or {})
18+
19+
async def submit_workflow(
20+
self,
21+
workflow: WorkflowDefinition,
22+
parameters: dict[str, Any] | None = None,
23+
) -> WorkflowInfo:
24+
workflow.get_function()
25+
return WorkflowInfo(id=uuid.uuid4())
Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Awaitable, Callable
3+
from typing import TYPE_CHECKING, Any, overload
44

55
from prefect.client.schemas import StateType
66
from prefect.deployments import run_deployment
77

88
from infrahub.workflows.initialization import setup_task_manager
9+
from infrahub.workflows.models import WorkflowInfo
910

1011
from . import InfrahubWorkflow, Return
1112

@@ -23,23 +24,41 @@ async def initialize(self, service: InfrahubServices) -> None:
2324
if await service.component.is_primary_api():
2425
await setup_task_manager()
2526

26-
async def execute(
27+
@overload
28+
async def execute_workflow(
2729
self,
28-
workflow: WorkflowDefinition | None = None,
29-
function: Callable[..., Awaitable[Return]] | None = None,
30-
**kwargs: Any,
31-
) -> Return:
32-
if workflow:
33-
response: FlowRun = await run_deployment(name=workflow.full_name, parameters=kwargs or {}) # type: ignore[return-value, misc]
34-
if not response.state:
35-
raise RuntimeError("Unable to read state from the response")
30+
workflow: WorkflowDefinition,
31+
expected_return: type[Return],
32+
parameters: dict[str, Any] | None = ...,
33+
) -> Return: ...
3634

37-
if response.state.type == StateType.CRASHED:
38-
raise RuntimeError(response.state.message)
35+
@overload
36+
async def execute_workflow(
37+
self,
38+
workflow: WorkflowDefinition,
39+
expected_return: None = ...,
40+
parameters: dict[str, Any] | None = ...,
41+
) -> Any: ...
42+
43+
async def execute_workflow(
44+
self,
45+
workflow: WorkflowDefinition,
46+
expected_return: type[Return] | None = None,
47+
parameters: dict[str, Any] | None = None,
48+
) -> Any:
49+
response: FlowRun = await run_deployment(name=workflow.full_name, poll_interval=1, parameters=parameters or {}) # type: ignore[return-value, misc]
50+
if not response.state:
51+
raise RuntimeError("Unable to read state from the response")
3952

40-
return await response.state.result(raise_on_failure=True, fetch=True) # type: ignore[call-overload]
53+
if response.state.type == StateType.CRASHED:
54+
raise RuntimeError(response.state.message)
4155

42-
if function:
43-
return await function(**kwargs)
56+
return await response.state.result(raise_on_failure=True, fetch=True) # type: ignore[call-overload]
4457

45-
raise ValueError("either a workflow definition or a flow must be provided")
58+
async def submit_workflow(
59+
self,
60+
workflow: WorkflowDefinition,
61+
parameters: dict[str, Any] | None = None,
62+
) -> WorkflowInfo:
63+
flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}) # type: ignore[return-value, misc]
64+
return WorkflowInfo.from_flow(flow_run=flow_run)

backend/infrahub/tasks/dummy.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
1+
from __future__ import annotations
2+
13
from prefect import flow, task
24
from pydantic import BaseModel
35

46
from infrahub.workflows.models import WorkflowDefinition
57

8+
9+
class DummyInput(BaseModel):
10+
firstname: str
11+
lastname: str
12+
13+
14+
class DummyOutput(BaseModel):
15+
full_name: str
16+
17+
618
DUMMY_FLOW = WorkflowDefinition(
719
name="dummy_flow",
820
module="infrahub.tasks.dummy",
@@ -16,15 +28,6 @@
1628
)
1729

1830

19-
class DummyInput(BaseModel):
20-
firstname: str
21-
lastname: str
22-
23-
24-
class DummyOutput(BaseModel):
25-
full_name: str
26-
27-
2831
@task
2932
async def aggregate_name(firstname: str, lastname: str) -> str:
3033
return f"{firstname}, {lastname}"

0 commit comments

Comments
 (0)