Commit bee748f

Merge pull request #4550 from opsmill/dga-20241006-storage-redis
Various enhancements around Prefect
2 parents ebe7fcf + 893849a

23 files changed: +296 -60 lines

backend/infrahub/cli/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,7 @@
 from infrahub.cli.events import app as events_app
 from infrahub.cli.git_agent import app as git_app
 from infrahub.cli.server import app as server_app
+from infrahub.cli.tasks import app as tasks_app
 from infrahub.core.initialization import initialization
 from infrahub.database import InfrahubDatabase, get_db

@@ -26,6 +27,7 @@ def common(ctx: typer.Context) -> None:
 app.add_typer(git_app, name="git-agent")
 app.add_typer(db_app, name="db")
 app.add_typer(events_app, name="events", help="Interact with the events system.")
+app.add_typer(tasks_app, name="tasks", hidden=True)


 async def _init_shell(config_file: str) -> None:

backend/infrahub/cli/tasks.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+import logging
+
+import typer
+from infrahub_sdk.async_typer import AsyncTyper
+from prefect.client.orchestration import get_client
+
+from infrahub import config
+from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
+from infrahub.tasks.dummy import DUMMY_FLOW, DummyInput
+from infrahub.workflows.initialization import setup_task_manager
+from infrahub.workflows.models import WorkerPoolDefinition
+
+app = AsyncTyper()
+
+# pylint: disable=unused-argument
+
+
+@app.command()
+async def init(
+    ctx: typer.Context,
+    debug: bool = typer.Option(False, help="Enable advanced logging and troubleshooting"),
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+) -> None:
+    """Initialize the task manager"""
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    await setup_task_manager()
+
+
+@app.command()
+async def execute(
+    ctx: typer.Context,
+    debug: bool = typer.Option(False, help="Enable advanced logging and troubleshooting"),
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+) -> None:
+    """Execute the dummy flow through the workflow worker"""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    async with get_client(sync_client=False) as client:
+        worker = WorkflowWorkerExecution()
+        await DUMMY_FLOW.save(client=client, work_pool=WorkerPoolDefinition(name="testing", worker_type="process"))
+
+        result = await worker.execute(workflow=DUMMY_FLOW, data=DummyInput(firstname="John", lastname="Doe"))  # type: ignore[var-annotated]
+        print(result)
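For a quick smoke test of the new subcommands (hidden from --help, since the group is registered with hidden=True), a minimal sketch using Typer's test runner; the CliRunner harness is an assumption on top of this commit, not part of it, and both commands expect an infrahub.toml in the working directory:

from typer.testing import CliRunner

from infrahub.cli import app

runner = CliRunner()

# "tasks init" loads the config and bootstraps the task manager
result = runner.invoke(app, ["tasks", "init"])
print(result.output)

# "tasks execute" registers DUMMY_FLOW against a "testing" work pool and runs it
result = runner.invoke(app, ["tasks", "execute"])
print(result.output)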

backend/infrahub/config.py

Lines changed: 5 additions & 2 deletions

@@ -269,8 +269,8 @@ class CacheSettings(BaseSettings):
     )
     database: int = Field(default=0, ge=0, le=15, description="Id of the database to use")
     driver: CacheDriver = CacheDriver.Redis
-    username: str = "infrahub"
-    password: str = "infrahub"
+    username: str = ""
+    password: str = ""
     tls_enabled: bool = Field(default=False, description="Indicates if TLS is enabled for the connection")
     tls_insecure: bool = Field(default=False, description="Indicates if TLS certificates are verified")
     tls_ca_file: Optional[str] = Field(default=None, description="File path to CA cert or bundle in PEM format")

@@ -290,6 +290,9 @@ class WorkflowSettings(BaseSettings):
     port: Optional[int] = Field(default=None, ge=1, le=65535, description="Specified if running on a non default port.")
     tls_enabled: bool = Field(default=False, description="Indicates if TLS is enabled for the connection")
     driver: WorkflowDriver = WorkflowDriver.WORKER
+    worker_polling_interval: int = Field(
+        default=2, ge=1, le=30, description="Specify how often the worker should poll the server for tasks (sec)"
+    )

     @property
     def api_endpoint(self) -> str:
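The cache credentials defaulting to empty strings presumably makes Redis AUTH opt-in rather than assumed. As for the new polling knob, a minimal sketch of how it surfaces at runtime; load_and_exit and SETTINGS are used the same way elsewhere in this diff, and the assertion just restates the Field bounds:

from infrahub import config

# Load infrahub.toml the same way the CLI commands in this commit do
config.load_and_exit(config_file_name="infrahub.toml")

# Defaults to 2; values outside 1..30 fail validation (Field(ge=1, le=30))
interval = config.SETTINGS.workflow.worker_polling_interval
assert 1 <= interval <= 30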

backend/infrahub/graphql/app.py

Lines changed: 4 additions & 1 deletion

@@ -38,7 +38,7 @@
 )
 from opentelemetry import trace
 from starlette.datastructures import UploadFile
-from starlette.requests import HTTPConnection, Request
+from starlette.requests import ClientDisconnect, HTTPConnection, Request
 from starlette.responses import JSONResponse, Response
 from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState

@@ -191,6 +191,9 @@ async def _handle_http_request(
             operations = await _get_operation_from_request(request)
         except ValueError as exc:
             return JSONResponse({"errors": [exc.args[0]]}, status_code=400)
+        except ClientDisconnect as exc:
+            self.logger.error("Exception ClientDisconnect in _handle_http_request")
+            return JSONResponse({"errors": [str(exc)]}, status_code=400)

         if isinstance(operations, list):
             return JSONResponse({"errors": ["This server does not support batching"]}, status_code=400)
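Starlette raises ClientDisconnect while the request body is being read if the client goes away mid-request, so _get_operation_from_request can surface it; catching it here presumably turns what used to propagate as an unhandled 500 into a logged error with a 400 response, even if the disconnected client never sees it.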

backend/infrahub/services/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -100,10 +100,10 @@ def git_report(

     async def initialize(self) -> None:
         """Initialize the Services"""
-        await self.component.initialize(service=self)
-        await self.http.initialize(service=self)
         await self.message_bus.initialize(service=self)
         await self.cache.initialize(service=self)
+        await self.http.initialize(service=self)
+        await self.component.initialize(service=self)
         await self.scheduler.initialize(service=self)
         await self.workflow.initialize(service=self)
         await self.event.initialize(service=self)
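The reordering is load-bearing: component.initialize() now calls refresh_heartbeat() (see component.py below), which presumably writes through the cache adapter, so the cache has to come up first. It also means is_primary_api() can give a meaningful answer by the time workflow.initialize() runs a few lines later.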

backend/infrahub/services/adapters/workflow/worker.py

Lines changed: 9 additions & 22 deletions

@@ -2,12 +2,10 @@

 from typing import TYPE_CHECKING, Any, Awaitable, Callable

-from prefect.client.orchestration import get_client
-from prefect.client.schemas.actions import WorkPoolCreate
+from prefect.client.schemas import StateType
 from prefect.deployments import run_deployment
-from prefect.exceptions import ObjectAlreadyExists

-from infrahub.workflows.catalogue import worker_pools, workflows
+from infrahub.workflows.initialization import setup_task_manager

 from . import InfrahubWorkflow, Return

@@ -22,23 +20,8 @@ class WorkflowWorkerExecution(InfrahubWorkflow):
     async def initialize(self, service: InfrahubServices) -> None:
         """Initialize the Workflow engine"""

-        async with get_client(sync_client=False) as client:
-            for worker in worker_pools:
-                wp = WorkPoolCreate(
-                    name=worker.name,
-                    type=worker.worker_type,
-                    description=worker.description,
-                )
-                try:
-                    await client.create_work_pool(work_pool=wp)
-                    service.log.info(f"work pool {worker} created successfully ... ")
-                except ObjectAlreadyExists:
-                    service.log.info(f"work pool {worker} already present ")
-
-            # Create deployment
-            for workflow in workflows:
-                flow_id = await client.create_flow_from_name(workflow.name)
-                await client.create_deployment(flow_id=flow_id, **workflow.to_deployment())
+        if await service.component.is_primary_api():
+            await setup_task_manager()

     async def execute(
         self,

@@ -50,7 +33,11 @@ async def execute(
             response: FlowRun = await run_deployment(name=workflow.full_name, parameters=kwargs or {})  # type: ignore[return-value, misc]
             if not response.state:
                 raise RuntimeError("Unable to read state from the response")
-            return response.state.result(raise_on_failure=True)
+
+            if response.state.type == StateType.CRASHED:
+                raise RuntimeError(response.state.message)
+
+            return await response.state.result(raise_on_failure=True, fetch=True)  # type: ignore[call-overload]

         if function:
             return await function(**kwargs)
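Two behavior changes worth noting: work-pool and deployment bootstrapping moves into the shared setup_task_manager() and now only runs on the primary API server, so a fleet of workers no longer races to issue the same create_work_pool calls; and CRASHED runs are surfaced through their state message, since a crashed flow run has no persisted result for state.result() to fetch. The result() call is also now awaited with fetch=True, so the persisted result is materialized rather than returned as a reference.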

backend/infrahub/services/component.py

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,8 @@ async def initialize(self, service: InfrahubServices) -> None:
         """Initialize the Message bus"""
         self._service = service

+        await self.refresh_heartbeat()
+
     async def is_primary_api(self) -> bool:
         primary_identity = await self.service.cache.get(PRIMARY_API_SERVER)
         return primary_identity == WORKER_IDENTITY
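Refreshing the heartbeat during initialize means a worker announces itself as soon as the component adapter comes up, so is_primary_api(), which the workflow adapter above now gates setup_task_manager() on, presumably has a registered identity in the cache to compare against rather than an empty key.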

backend/infrahub/tasks/dummy.py

Lines changed: 20 additions & 0 deletions

@@ -1,6 +1,20 @@
 from prefect import flow, task
 from pydantic import BaseModel

+from infrahub.workflows.models import WorkflowDefinition
+
+DUMMY_FLOW = WorkflowDefinition(
+    name="dummy_flow",
+    module="infrahub.tasks.dummy",
+    function="dummy_flow",
+)
+
+DUMMY_FLOW_BROKEN = WorkflowDefinition(
+    name="dummy_flow_broken",
+    module="infrahub.tasks.dummy",
+    function="dummy_flow_broken",
+)
+

 class DummyInput(BaseModel):
     firstname: str

@@ -19,3 +33,9 @@ async def aggregate_name(firstname: str, lastname: str) -> str:
 @flow(persist_result=True)
 async def dummy_flow(data: DummyInput) -> DummyOutput:
     return DummyOutput(full_name=await aggregate_name(firstname=data.firstname, lastname=data.lastname))
+
+
+@flow(persist_result=True)
+async def dummy_flow_broken(data: DummyInput) -> DummyOutput:
+    response = await aggregate_name(firstname=data.firstname, lastname=data.lastname)
+    return DummyOutput(not_valid=response)  # type: ignore[call-arg]
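dummy_flow_broken passes a field DummyOutput does not define, so it fails at runtime by construction, which presumably gives the test suite a deterministic failing flow for the new crash handling. Outside of deployments, the happy-path flow can also be exercised as a plain awaitable; a minimal sketch (the exact output depends on aggregate_name, whose body is not shown in this diff):

import asyncio

from infrahub.tasks.dummy import DummyInput, dummy_flow

# Calling a @flow-decorated coroutine directly runs it locally,
# without a deployment or a worker
output = asyncio.run(dummy_flow(data=DummyInput(firstname="John", lastname="Doe")))
print(output.full_name)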

backend/infrahub/workers/infrahub_async.py

Lines changed: 15 additions & 3 deletions

@@ -7,6 +7,7 @@
 from anyio.abc import TaskStatus
 from infrahub_sdk import Config, InfrahubClient
 from infrahub_sdk.exceptions import Error as SdkError
+from prefect import settings as prefect_settings
 from prefect.client.schemas.objects import FlowRun
 from prefect.flow_engine import run_flow_async
 from prefect.workers.base import BaseJobConfiguration, BaseVariables, BaseWorker, BaseWorkerResult

@@ -26,6 +27,7 @@
 from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
 from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
 from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
+from infrahub.workflows.models import TASK_RESULT_STORAGE_NAME


 class InfrahubWorkerAsyncConfiguration(BaseJobConfiguration):

@@ -49,8 +51,6 @@ class InfrahubWorkerAsync(BaseWorker):
     _description = "Infrahub worker designed to run the flow in the main async loop."

     async def setup(self, **kwargs: dict[str, Any]) -> None:
-        await super().setup(**kwargs)
-
         logging.getLogger("websockets").setLevel(logging.ERROR)
         logging.getLogger("httpx").setLevel(logging.ERROR)
         logging.getLogger("httpcore").setLevel(logging.ERROR)

@@ -67,7 +67,19 @@ async def setup(self, **kwargs: dict[str, Any]) -> None:
         self._logger.info(f"Starting metric endpoint on port {metric_port}")
         start_http_server(metric_port)

-        self._logger.info(f"Using Infrahub API at {config.SETTINGS.main.internal_address}")
+        self._exit_stack.enter_context(
+            prefect_settings.temporary_settings(
+                updates={  # type: ignore[arg-type]
+                    prefect_settings.PREFECT_WORKER_QUERY_SECONDS: config.SETTINGS.workflow.worker_polling_interval,
+                    prefect_settings.PREFECT_RESULTS_PERSIST_BY_DEFAULT: True,
+                    prefect_settings.PREFECT_DEFAULT_RESULT_STORAGE_BLOCK: f"redisstoragecontainer/{TASK_RESULT_STORAGE_NAME}",
+                }
+            )
+        )
+
+        await super().setup(**kwargs)
+
+        self._logger.debug(f"Using Infrahub API at {config.SETTINGS.main.internal_address}")
         client = InfrahubClient(
             config=Config(address=config.SETTINGS.main.internal_address, retry_on_failure=True, log=self._logger)
         )
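The ordering swap is the important part: BaseWorker.setup() is what starts syncing with the Prefect API, so the temporary_settings() context (shorter polling via PREFECT_WORKER_QUERY_SECONDS, result persistence on by default, and a default result-storage block) has to be entered before it runs. Registering the context on self._exit_stack presumably ties its lifetime to the worker's own teardown, and the redisstoragecontainer/... slug points at the Redis storage block that setup_task_manager() is expected to register under TASK_RESULT_STORAGE_NAME.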
backend/infrahub/workflows/catalogue.py

Lines changed: 9 additions & 16 deletions

@@ -1,59 +1,52 @@
+from .constants import WorkflowType
 from .models import WorkerPoolDefinition, WorkflowDefinition

-WORKER_POOL = WorkerPoolDefinition(
-    name="infrahub-internal", worker_type="infrahubasync", description="Pool for internal tasks"
+INFRAHUB_WORKER_POOL = WorkerPoolDefinition(
+    name="infrahub-worker", worker_type="infrahubasync", description="Default Pool for internal tasks"
 )

 WEBHOOK_SEND = WorkflowDefinition(
     name="webhook_send",
-    work_pool=WORKER_POOL,
+    type=WorkflowType.USER,
     module="infrahub.message_bus.operations.send.webhook",
     function="send_webhook",
 )

 TRANSFORM_JINJA2_RENDER = WorkflowDefinition(
     name="transform_render_jinja2_template",
-    work_pool=WORKER_POOL,
+    type=WorkflowType.USER,
     module="infrahub.message_bus.operations.transform.jinja",
     function="transform_render_jinja2_template",
 )

 ANONYMOUS_TELEMETRY_SEND = WorkflowDefinition(
     name="anonymous_telemetry_send",
-    work_pool=WORKER_POOL,
+    type=WorkflowType.INTERNAL,
     cron="0 2 * * *",
     module="infrahub.message_bus.operations.send.telemetry",
     function="send_telemetry_push",
 )

-DUMMY_FLOW = WorkflowDefinition(
-    name="dummy_flow",
-    work_pool=WORKER_POOL,
-    module="infrahub.tasks.dummy",
-    function="dummy_flow",
-)
-
 SCHEMA_APPLY_MIGRATION = WorkflowDefinition(
     name="schema_apply_migrations",
-    work_pool=WORKER_POOL,
+    type=WorkflowType.INTERNAL,
     module="infrahub.core.migrations.schema.tasks",
     function="schema_apply_migrations",
 )

 SCHEMA_VALIDATE_MIGRATION = WorkflowDefinition(
     name="schema_validate_migrations",
-    work_pool=WORKER_POOL,
+    type=WorkflowType.INTERNAL,
     module="infrahub.core.validators.tasks",
     function="schema_validate_migrations",
 )

-worker_pools = [WORKER_POOL]
+worker_pools = [INFRAHUB_WORKER_POOL]

 workflows = [
     WEBHOOK_SEND,
     TRANSFORM_JINJA2_RENDER,
     ANONYMOUS_TELEMETRY_SEND,
-    DUMMY_FLOW,
     SCHEMA_APPLY_MIGRATION,
     SCHEMA_VALIDATE_MIGRATION,
 ]
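The per-workflow work_pool field gives way to a WorkflowType enum (USER vs INTERNAL), presumably so deployment setup can map types to pools centrally instead of hard-wiring every definition to a single pool, and DUMMY_FLOW moves out of the public catalogue into infrahub.tasks.dummy, where only the tests and the hidden CLI command reach it.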
