diff --git a/backend/infrahub/cli/tasks.py b/backend/infrahub/cli/tasks.py
index 382beae372..27cdc0311f 100644
--- a/backend/infrahub/cli/tasks.py
+++ b/backend/infrahub/cli/tasks.py
@@ -3,9 +3,11 @@
 import typer
 from infrahub_sdk.async_typer import AsyncTyper
 from prefect.client.orchestration import get_client
+from prefect.client.schemas.objects import StateType
 
 from infrahub import config
 from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
+from infrahub.task_manager.task import PrefectTask
 from infrahub.tasks.dummy import DUMMY_FLOW, DummyInput
 from infrahub.workflows.initialization import setup_task_manager
 from infrahub.workflows.models import WorkerPoolDefinition
@@ -50,3 +52,47 @@ async def execute(
         workflow=DUMMY_FLOW, parameters={"data": DummyInput(firstname="John", lastname="Doe")}
     )  # type: ignore[var-annotated]
     print(result)
+
+
+flush_app = AsyncTyper()
+
+app.add_typer(flush_app, name="flush")
+
+
+@flush_app.command()
+async def flow_runs(
+    ctx: typer.Context,  # noqa: ARG001
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+    days_to_keep: int = 30,
+    batch_size: int = 100,
+) -> None:
+    """Flush old task runs"""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    await PrefectTask.delete_flow_runs(
+        days_to_keep=days_to_keep,
+        batch_size=batch_size,
+    )
+
+
+@flush_app.command()
+async def stale_runs(
+    ctx: typer.Context,  # noqa: ARG001
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+    days_to_keep: int = 2,
+    batch_size: int = 100,
+) -> None:
+    """Flush stale task runs"""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    await PrefectTask.delete_flow_runs(
+        states=[StateType.RUNNING], delete=False, days_to_keep=days_to_keep, batch_size=batch_size
+    )
diff --git a/backend/infrahub/task_manager/task.py b/backend/infrahub/task_manager/task.py
index 01a94b4606..3e3145bd3f 100644
--- a/backend/infrahub/task_manager/task.py
+++ b/backend/infrahub/task_manager/task.py
@@ -1,4 +1,6 @@
+import asyncio
 import uuid
+from datetime import datetime, timedelta, timezone
 from typing import Any
 from uuid import UUID
 
@@ -12,13 +14,14 @@
     FlowRunFilter,
     FlowRunFilterId,
     FlowRunFilterName,
+    FlowRunFilterStartTime,
     FlowRunFilterState,
     FlowRunFilterStateType,
     FlowRunFilterTags,
     LogFilter,
     LogFilterFlowRunId,
 )
-from prefect.client.schemas.objects import Flow, FlowRun, StateType
+from prefect.client.schemas.objects import Flow, FlowRun, State, StateType
 from prefect.client.schemas.sorting import (
     FlowRunSort,
 )
@@ -311,3 +314,76 @@ async def query(
         )
 
         return {"count": count or 0, "edges": nodes}
+
+    @classmethod
+    async def delete_flow_runs(
+        cls,
+        states: list[StateType] | None = None,
+        delete: bool = True,
+        days_to_keep: int = 2,
+        batch_size: int = 100,
+    ) -> None:
+        """Delete flow runs in the specified states that are older than ``days_to_keep`` days.
+
+        When ``delete`` is False, matching flow runs are marked as CRASHED instead of being deleted.
+        """
+
+        if states is None:
+            states = [StateType.COMPLETED, StateType.FAILED, StateType.CANCELLED]
+
+        logger = get_logger()
+
+        async with get_client(sync_client=False) as client:
+            cutoff = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
+
+            flow_run_filter = FlowRunFilter(
+                start_time=FlowRunFilterStartTime(before_=cutoff),  # type: ignore[arg-type]
+                state=FlowRunFilterState(type=FlowRunFilterStateType(any_=states)),
+            )
+
+            # Get the first batch of flow runs matching the cutoff and states
+            flow_runs = await client.read_flow_runs(flow_run_filter=flow_run_filter, limit=batch_size)
+
+            deleted_total = 0
+
+            while flow_runs:
+                batch_deleted = 0
+                failed_deletes = []
+
+                # Delete (or mark as crashed) each flow run through the API
+                for flow_run in flow_runs:
+                    try:
+                        if delete:
+                            await client.delete_flow_run(flow_run_id=flow_run.id)
+                        else:
+                            await client.set_flow_run_state(
+                                flow_run_id=flow_run.id,
+                                state=State(type=StateType.CRASHED),
+                                force=True,
+                            )
+                        deleted_total += 1
+                        batch_deleted += 1
+                    except Exception as e:
+                        logger.warning(f"Failed to process flow run {flow_run.id}: {e}")
+                        failed_deletes.append(flow_run.id)
+
+                    # Rate limiting
+                    if batch_deleted % 10 == 0:
+                        await asyncio.sleep(0.5)
+
+                logger.info(f"Processed {batch_deleted}/{len(flow_runs)} flow runs (total: {deleted_total})")
+                if failed_deletes:
+                    logger.warning(f"Failed to process {len(failed_deletes)} flow runs")
+
+                # Stop if nothing in this batch could be processed; the same runs
+                # would otherwise be fetched again and the loop would never terminate
+                if batch_deleted == 0:
+                    break
+
+                # Get next batch
+                flow_runs = await client.read_flow_runs(flow_run_filter=flow_run_filter, limit=batch_size)
+
+                # Delay between batches to avoid overwhelming the API
+                await asyncio.sleep(1.0)
+
+            logger.info(f"Retention complete. Total flow runs processed: {deleted_total}")
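
Usage sketch for the new commands, assuming the `tasks` Typer app is mounted on the main `infrahub` CLI (the mount point is not shown in this diff; Typer renders the function names `flow_runs` and `stale_runs` as `flow-runs` and `stale-runs`):

    # Delete COMPLETED/FAILED/CANCELLED flow runs older than 30 days, in batches of 100
    infrahub tasks flush flow-runs --days-to-keep 30 --batch-size 100

    # Mark RUNNING flow runs older than 2 days as CRASHED instead of deleting them
    infrahub tasks flush stale-runs --days-to-keep 2

The config file is a positional argument that defaults to infrahub.toml (or the INFRAHUB_CONFIG environment variable), so a path such as `infrahub tasks flush flow-runs /path/to/infrahub.toml` also works.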