diff --git a/backend/infrahub/cli/tasks.py b/backend/infrahub/cli/tasks.py
index 382beae372..27cdc0311f 100644
--- a/backend/infrahub/cli/tasks.py
+++ b/backend/infrahub/cli/tasks.py
@@ -3,9 +3,11 @@
 import typer
 from infrahub_sdk.async_typer import AsyncTyper
 from prefect.client.orchestration import get_client
+from prefect.client.schemas.objects import StateType
 
 from infrahub import config
 from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
+from infrahub.task_manager.task import PrefectTask
 from infrahub.tasks.dummy import DUMMY_FLOW, DummyInput
 from infrahub.workflows.initialization import setup_task_manager
 from infrahub.workflows.models import WorkerPoolDefinition
@@ -50,3 +52,47 @@ async def execute(
         workflow=DUMMY_FLOW, parameters={"data": DummyInput(firstname="John", lastname="Doe")}
     )  # type: ignore[var-annotated]
     print(result)
+
+
+flush_app = AsyncTyper()
+
+app.add_typer(flush_app, name="flush")
+
+
+@flush_app.command()
+async def flow_runs(
+    ctx: typer.Context,  # noqa: ARG001
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+    days_to_keep: int = 30,
+    batch_size: int = 100,
+) -> None:
+    """Delete old flow runs from the task manager."""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    await PrefectTask.delete_flow_runs(
+        days_to_keep=days_to_keep,
+        batch_size=batch_size,
+    )
+
+
+@flush_app.command()
+async def stale_runs(
+    ctx: typer.Context,  # noqa: ARG001
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+    days_to_keep: int = 2,
+    batch_size: int = 100,
+) -> None:
+    """Mark stale running flow runs as crashed."""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    await PrefectTask.delete_flow_runs(
+        states=[StateType.RUNNING], delete=False, days_to_keep=days_to_keep, batch_size=batch_size
+    )
diff --git a/backend/infrahub/services/adapters/http/__init__.py b/backend/infrahub/services/adapters/http/__init__.py
index f38552a11b..4e72ab43ae 100644
--- a/backend/infrahub/services/adapters/http/__init__.py
+++ b/backend/infrahub/services/adapters/http/__init__.py
@@ -3,10 +3,15 @@
 from typing import TYPE_CHECKING, Any
 
 if TYPE_CHECKING:
+    import ssl
+
     import httpx
 
 
 class InfrahubHTTP:
+    def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext:
+        raise NotImplementedError()
+
     async def get(
         self,
         url: str,
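The concrete `HttpxAdapter.verify_tls` implementation is not part of this diff. A minimal sketch of what a caching implementation could look like (the class name, the plain `ssl.create_default_context()` call, and the handling of the `verify` override are illustrative assumptions, not Infrahub's actual code):

    import ssl

    from infrahub.services.adapters.http import InfrahubHTTP


    class CachedTLSAdapter(InfrahubHTTP):
        """Sketch only: the real adapter is infrahub.services.adapters.http.httpx.HttpxAdapter."""

        _ssl_context: ssl.SSLContext | None = None

        def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext:
            # An explicit opt-out returns a plain bool: httpx and Prefect both
            # accept False to disable certificate verification entirely.
            if verify is False:
                return False
            # Build the SSLContext once and hand the same object to every
            # caller, so clients reuse it instead of re-reading the CA bundle.
            if CachedTLSAdapter._ssl_context is None:
                CachedTLSAdapter._ssl_context = ssl.create_default_context()
            return CachedTLSAdapter._ssl_context

Returning a pre-built `ssl.SSLContext` rather than `True` is the point of the cache: with `verify=True`, each new httpx client typically constructs its own context and re-reads the CA bundle, which the `WorkflowWorkerExecution` change below avoids by reusing a single context.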
diff --git a/backend/infrahub/services/adapters/workflow/worker.py b/backend/infrahub/services/adapters/workflow/worker.py
index aa3659a200..a2ece3d514 100644
--- a/backend/infrahub/services/adapters/workflow/worker.py
+++ b/backend/infrahub/services/adapters/workflow/worker.py
@@ -3,8 +3,10 @@
 from typing import TYPE_CHECKING, Any, overload
 
 from prefect.client.schemas.objects import StateType
+from prefect.context import AsyncClientContext
 from prefect.deployments import run_deployment
 
+from infrahub.services.adapters.http.httpx import HttpxAdapter
 from infrahub.workers.utils import inject_context_parameter
 from infrahub.workflows.initialization import setup_task_manager
 from infrahub.workflows.models import WorkflowInfo
@@ -19,6 +21,11 @@
 
 
 class WorkflowWorkerExecution(InfrahubWorkflow):
+    # This is required to grab a cached SSLContext from the HttpxAdapter.
+    # We cannot use the get_http() dependency since it introduces a circular dependency.
+    # We could remove this later on by introducing a cached SSLContext outside of this adapter.
+    _http_adapter = HttpxAdapter()
+
     @staticmethod
     async def initialize(component_is_primary_server: bool) -> None:
         if component_is_primary_server:
@@ -79,5 +86,6 @@ async def submit_workflow(
         parameters = dict(parameters) if parameters is not None else {}
         inject_context_parameter(func=flow_func, parameters=parameters, context=context)
 
-        flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}, tags=tags)  # type: ignore[return-value, misc]
+        async with AsyncClientContext(httpx_settings={"verify": self._http_adapter.verify_tls()}):
+            flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}, tags=tags)  # type: ignore[return-value, misc]
         return WorkflowInfo.from_flow(flow_run=flow_run)
diff --git a/backend/infrahub/task_manager/task.py b/backend/infrahub/task_manager/task.py
index 01a94b4606..0b9ff622e0 100644
--- a/backend/infrahub/task_manager/task.py
+++ b/backend/infrahub/task_manager/task.py
@@ -1,7 +1,10 @@
+import asyncio
 import uuid
+from datetime import datetime, timedelta, timezone
 from typing import Any
 from uuid import UUID
 
+from prefect import State
 from prefect.client.orchestration import PrefectClient, get_client
 from prefect.client.schemas.filters import (
     ArtifactFilter,
@@ -12,6 +15,7 @@
     FlowRunFilter,
     FlowRunFilterId,
     FlowRunFilterName,
+    FlowRunFilterStartTime,
     FlowRunFilterState,
     FlowRunFilterStateType,
     FlowRunFilterTags,
@@ -311,3 +315,72 @@ async def query(
         )
 
         return {"count": count or 0, "edges": nodes}
+
+    @classmethod
+    async def delete_flow_runs(
+        cls,
+        states: list[StateType] = [StateType.COMPLETED, StateType.FAILED, StateType.CANCELLED],  # noqa: B006
+        delete: bool = True,
+        days_to_keep: int = 2,
+        batch_size: int = 100,
+    ) -> None:
+        """Delete flow runs in the specified states that are older than the specified number of days."""
+
+        logger = get_logger()
+
+        async with get_client(sync_client=False) as client:
+            cutoff = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
+
+            flow_run_filter = FlowRunFilter(
+                start_time=FlowRunFilterStartTime(before_=cutoff),  # type: ignore[arg-type]
+                state=FlowRunFilterState(type=FlowRunFilterStateType(any_=states)),
+            )
+
+            # Get flow runs to delete
+            flow_runs = await client.read_flow_runs(flow_run_filter=flow_run_filter, limit=batch_size)
+
+            deleted_total = 0
+
+            while True:
+                batch_deleted = 0
+                failed_deletes = []
+
+                # Delete each flow run through the API
+                for flow_run in flow_runs:
+                    try:
+                        if delete:
+                            await client.delete_flow_run(flow_run_id=flow_run.id)
+                        else:
+                            await client.set_flow_run_state(
+                                flow_run_id=flow_run.id,
+                                state=State(type=StateType.CRASHED),
+                                force=True,
+                            )
+                        deleted_total += 1
+                        batch_deleted += 1
+                    except Exception as e:
+                        logger.warning(f"Failed to delete flow run {flow_run.id}: {e}")
+                        failed_deletes.append(flow_run.id)
+
+                    # Rate limiting
+                    if batch_deleted % 10 == 0:
+                        await asyncio.sleep(0.5)
+
+                logger.info(f"Deleted {batch_deleted}/{len(flow_runs)} flow runs (total: {deleted_total})")
+
+                # Get next batch
+                previous_flow_run_ids = [fr.id for fr in flow_runs]
+                flow_runs = await client.read_flow_runs(flow_run_filter=flow_run_filter, limit=batch_size)
+
+                if not flow_runs:
+                    logger.info("No more flow runs to delete")
+                    break
+
+                if previous_flow_run_ids == [fr.id for fr in flow_runs]:
+                    logger.info("Found the same flow runs to delete, aborting")
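The new `flush` CLI commands above are thin wrappers around this classmethod (Typer exposes them as `flow-runs` and `stale-runs`). It can also be driven directly from a maintenance script; a minimal sketch, assuming the ambient Prefect settings point at a reachable API (the parameter values here are arbitrary):

    import asyncio

    from prefect.client.schemas.objects import StateType

    from infrahub.task_manager.task import PrefectTask


    async def main() -> None:
        # Mirror `stale-runs` with a longer window: mark week-old RUNNING
        # flows as crashed instead of deleting them.
        await PrefectTask.delete_flow_runs(
            states=[StateType.RUNNING],
            delete=False,
            days_to_keep=7,
            batch_size=50,
        )


    if __name__ == "__main__":
        asyncio.run(main())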
+                    break
+
+                # Delay between batches to avoid overwhelming the API
+                await asyncio.sleep(1.0)
+
+            logger.info(f"Retention complete. Total flow runs processed: {deleted_total}")
diff --git a/backend/infrahub/workers/infrahub_async.py b/backend/infrahub/workers/infrahub_async.py
index 8e4d3411c4..0664fa764a 100644
--- a/backend/infrahub/workers/infrahub_async.py
+++ b/backend/infrahub/workers/infrahub_async.py
@@ -8,6 +8,7 @@
 from infrahub_sdk.exceptions import Error as SdkError
 from prefect import settings as prefect_settings
 from prefect.client.schemas.objects import FlowRun
+from prefect.context import AsyncClientContext
 from prefect.flow_engine import run_flow_async
 from prefect.logging.handlers import APILogHandler
 from prefect.workers.base import BaseJobConfiguration, BaseVariables, BaseWorker, BaseWorkerResult
@@ -27,6 +28,7 @@
     get_cache,
     get_component,
     get_database,
+    get_http,
     get_message_bus,
     get_workflow,
     set_component_type,
@@ -154,7 +156,9 @@ async def run(
         if task_status:
             task_status.started(True)
 
-        await run_flow_async(flow=flow_func, flow_run=flow_run, parameters=params, return_type="state")
+        async with AsyncClientContext(httpx_settings={"verify": get_http().verify_tls()}) as ctx:
+            ctx._httpx_settings = None  # Hack to make all child task/flow runs use the same client
+            await run_flow_async(flow=flow_func, flow_run=flow_run, parameters=params, return_type="state")
 
         return InfrahubWorkerAsyncResult(status_code=0, identifier=str(flow_run.id))
 
diff --git a/backend/infrahub/workflows/utils.py b/backend/infrahub/workflows/utils.py
index 9b434cae5d..62d58c747c 100644
--- a/backend/infrahub/workflows/utils.py
+++ b/backend/infrahub/workflows/utils.py
@@ -9,6 +9,7 @@
 from infrahub.core.constants import GLOBAL_BRANCH_NAME
 from infrahub.core.registry import registry
 from infrahub.tasks.registry import refresh_branches
+from infrahub.workers.dependencies import get_http
 
 from .constants import TAG_NAMESPACE, WorkflowTag
 
@@ -26,7 +27,7 @@ async def add_tags(
     namespace: bool = True,
     db_change: bool = False,
 ) -> None:
-    client = get_client(sync_client=False)
+    client = get_client(httpx_settings={"verify": get_http().verify_tls()}, sync_client=False)
     current_flow_run_id = flow_run.id
     current_tags: list[str] = flow_run.tags
     branch_tags = (
diff --git a/backend/tests/adapters/http.py b/backend/tests/adapters/http.py
index 066f59733b..c4c4c6de80 100644
--- a/backend/tests/adapters/http.py
+++ b/backend/tests/adapters/http.py
@@ -1,15 +1,25 @@
-from typing import Any
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
 
 import httpx
 
 from infrahub.services.adapters.http import InfrahubHTTP
 
+if TYPE_CHECKING:
+    import ssl
+
+    import httpx
+
 
 class MemoryHTTP(InfrahubHTTP):
     def __init__(self) -> None:
         self._get_response: dict[str, httpx.Response] = {}
         self._post_response: dict[str, httpx.Response] = {}
 
+    def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext:
+        return False
+
     async def get(
         self,
         url: str,
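Taken together, these changes thread one verify value through every Prefect client the worker creates: `submit_workflow` and `run` wrap their calls in `AsyncClientContext`, `add_tags` passes it to `get_client`, and the in-memory test adapter simply opts out of verification. The underlying mechanism is plain httpx behaviour; a standalone illustration (not Infrahub code):

    import ssl

    import httpx

    # httpx accepts either a bool or a pre-built SSLContext for `verify`,
    # which is what lets a cached context be shared across clients.
    context = ssl.create_default_context()
    insecure = httpx.AsyncClient(verify=False)
    pinned = httpx.AsyncClient(verify=context)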