backend/infrahub/cli/tasks.py (46 additions, 0 deletions)
@@ -3,9 +3,11 @@
import typer
from infrahub_sdk.async_typer import AsyncTyper
from prefect.client.orchestration import get_client
from prefect.client.schemas.objects import StateType

from infrahub import config
from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
from infrahub.task_manager.task import PrefectTask
from infrahub.tasks.dummy import DUMMY_FLOW, DummyInput
from infrahub.workflows.initialization import setup_task_manager
from infrahub.workflows.models import WorkerPoolDefinition
@@ -50,3 +52,47 @@ async def execute(
workflow=DUMMY_FLOW, parameters={"data": DummyInput(firstname="John", lastname="Doe")}
) # type: ignore[var-annotated]
print(result)


flush_app = AsyncTyper()

app.add_typer(flush_app, name="flush")


@flush_app.command()
async def flow_runs(
ctx: typer.Context, # noqa: ARG001
config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
days_to_keep: int = 30,
batch_size: int = 100,
) -> None:
"""Flush old task runs"""
logging.getLogger("infrahub").setLevel(logging.WARNING)
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("prefect").setLevel(logging.ERROR)

config.load_and_exit(config_file_name=config_file)

await PrefectTask.delete_flow_runs(
days_to_keep=days_to_keep,
batch_size=batch_size,
)


@flush_app.command()
async def stale_runs(
ctx: typer.Context, # noqa: ARG001
config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
days_to_keep: int = 2,
batch_size: int = 100,
) -> None:
"""Flush stale task runs"""
logging.getLogger("infrahub").setLevel(logging.WARNING)
logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("prefect").setLevel(logging.ERROR)

config.load_and_exit(config_file_name=config_file)

await PrefectTask.delete_flow_runs(
states=[StateType.RUNNING], delete=False, days_to_keep=days_to_keep, batch_size=batch_size
)
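For reference, a sketch of how the new commands could be exercised from a test, assuming infrahub_sdk's AsyncTyper runs the async commands synchronously under CliRunner (Typer derives the hyphenated command names flow-runs and stale-runs from the function names):

from typer.testing import CliRunner

from infrahub.cli.tasks import app

runner = CliRunner()
# "flush flow-runs" deletes terminal runs older than the retention window.
# The config file is passed positionally; the retention knobs are options.
# Illustrative only: the command talks to a live Prefect API.
result = runner.invoke(app, ["flush", "flow-runs", "infrahub.toml", "--days-to-keep", "30", "--batch-size", "100"])
print(result.output)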
backend/infrahub/services/adapters/http/__init__.py (5 additions, 0 deletions)
@@ -3,10 +3,15 @@
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
import ssl

import httpx


class InfrahubHTTP:
def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext:
raise NotImplementedError()

async def get(
self,
url: str,
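The base adapter only declares the contract: verify_tls returns either a plain bool or a reusable ssl.SSLContext. A minimal sketch of a concrete implementation with a cached context (the class name ExampleHTTP and the caching strategy are illustrative, not the actual HttpxAdapter code):

import ssl

from infrahub.services.adapters.http import InfrahubHTTP


class ExampleHTTP(InfrahubHTTP):
    _ssl_context: ssl.SSLContext | None = None

    def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext:
        # An explicit False disables certificate verification entirely.
        if verify is False:
            return False
        # Otherwise build the SSLContext once and hand out the same instance,
        # so every httpx client created from this value shares the cached context.
        if self._ssl_context is None:
            self._ssl_context = ssl.create_default_context()
        return self._ssl_context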
backend/infrahub/services/adapters/workflow/worker.py (9 additions, 1 deletion)
@@ -3,8 +3,10 @@
from typing import TYPE_CHECKING, Any, overload

from prefect.client.schemas.objects import StateType
from prefect.context import AsyncClientContext
from prefect.deployments import run_deployment

from infrahub.services.adapters.http.httpx import HttpxAdapter
from infrahub.workers.utils import inject_context_parameter
from infrahub.workflows.initialization import setup_task_manager
from infrahub.workflows.models import WorkflowInfo
@@ -19,6 +21,11 @@


class WorkflowWorkerExecution(InfrahubWorkflow):
# This is required to grab a cached SSLContext from the HttpxAdapter.
# We cannot use the get_http() dependency since it introduces a circular dependency.
# We could remove this later on by introducing a cached SSLContext outside of this adapter.
_http_adapter = HttpxAdapter()

@staticmethod
async def initialize(component_is_primary_server: bool) -> None:
if component_is_primary_server:
@@ -79,5 +86,6 @@ async def submit_workflow(
parameters = dict(parameters) if parameters is not None else {}
inject_context_parameter(func=flow_func, parameters=parameters, context=context)

flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}, tags=tags) # type: ignore[return-value, misc]
async with AsyncClientContext(httpx_settings={"verify": self._http_adapter.verify_tls()}):
flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}, tags=tags) # type: ignore[return-value, misc]
return WorkflowInfo.from_flow(flow_run=flow_run)
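The comment above suggests that the module-level HttpxAdapter could later be replaced by a cached SSLContext living outside the adapter. One possible shape for that follow-up, sketched with functools.lru_cache (the function name and its ca_file parameter are assumptions for illustration):

import ssl
from functools import lru_cache


@lru_cache(maxsize=1)
def cached_ssl_context(ca_file: str | None = None) -> ssl.SSLContext:
    # Built once per process; httpx accepts an ssl.SSLContext as its "verify"
    # setting, so the result can be passed straight into
    # AsyncClientContext(httpx_settings={"verify": cached_ssl_context()}).
    return ssl.create_default_context(cafile=ca_file)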
backend/infrahub/task_manager/task.py (73 additions, 0 deletions)
@@ -1,7 +1,10 @@
import asyncio
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from uuid import UUID

from prefect import State
from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas.filters import (
ArtifactFilter,
@@ -12,6 +15,7 @@
FlowRunFilter,
FlowRunFilterId,
FlowRunFilterName,
FlowRunFilterStartTime,
FlowRunFilterState,
FlowRunFilterStateType,
FlowRunFilterTags,
@@ -311,3 +315,72 @@ async def query(
)

return {"count": count or 0, "edges": nodes}

@classmethod
async def delete_flow_runs(
cls,
states: list[StateType] = [StateType.COMPLETED, StateType.FAILED, StateType.CANCELLED], # noqa: B006
delete: bool = True,
days_to_keep: int = 2,
batch_size: int = 100,
) -> None:
"""Delete flow runs in the specified states and older than specified days."""

logger = get_logger()

async with get_client(sync_client=False) as client:
cutoff = datetime.now(timezone.utc) - timedelta(days=days_to_keep)

flow_run_filter = FlowRunFilter(
start_time=FlowRunFilterStartTime(before_=cutoff), # type: ignore[arg-type]
state=FlowRunFilterState(type=FlowRunFilterStateType(any_=states)),
)

# Get flow runs to delete
flow_runs = await client.read_flow_runs(flow_run_filter=flow_run_filter, limit=batch_size)

deleted_total = 0

while True:
batch_deleted = 0
failed_deletes = []

# Delete each flow run through the API
for flow_run in flow_runs:
try:
if delete:
await client.delete_flow_run(flow_run_id=flow_run.id)
else:
await client.set_flow_run_state(
flow_run_id=flow_run.id,
state=State(type=StateType.CRASHED),
force=True,
)
deleted_total += 1
batch_deleted += 1
except Exception as e:
logger.warning(f"Failed to delete flow run {flow_run.id}: {e}")
failed_deletes.append(flow_run.id)

# Rate limiting
if batch_deleted % 10 == 0:
await asyncio.sleep(0.5)

logger.info(f"Delete {batch_deleted}/{len(flow_runs)} flow runs (total: {deleted_total})")

# Get next batch
previous_flow_run_ids = [fr.id for fr in flow_runs]
flow_runs = await client.read_flow_runs(flow_run_filter=flow_run_filter, limit=batch_size)

if not flow_runs:
logger.info("No more flow runs to delete")
break

if previous_flow_run_ids == [fr.id for fr in flow_runs]:
logger.info("Found same flow runs to delete, aborting")
break

# Delay between batches to avoid overwhelming the API
await asyncio.sleep(1.0)

logger.info(f"Retention complete. Total deleted tasks: {deleted_total}")
backend/infrahub/workers/infrahub_async.py (5 additions, 1 deletion)
@@ -8,6 +8,7 @@
from infrahub_sdk.exceptions import Error as SdkError
from prefect import settings as prefect_settings
from prefect.client.schemas.objects import FlowRun
from prefect.context import AsyncClientContext
from prefect.flow_engine import run_flow_async
from prefect.logging.handlers import APILogHandler
from prefect.workers.base import BaseJobConfiguration, BaseVariables, BaseWorker, BaseWorkerResult
@@ -27,6 +28,7 @@
get_cache,
get_component,
get_database,
get_http,
get_message_bus,
get_workflow,
set_component_type,
@@ -154,7 +156,9 @@ async def run(
if task_status:
task_status.started(True)

await run_flow_async(flow=flow_func, flow_run=flow_run, parameters=params, return_type="state")
async with AsyncClientContext(httpx_settings={"verify": get_http().verify_tls()}) as ctx:
ctx._httpx_settings = None # Hack to make all child task/flow runs use the same client
await run_flow_async(flow=flow_func, flow_run=flow_run, parameters=params, return_type="state")

return InfrahubWorkerAsyncResult(status_code=0, identifier=str(flow_run.id))

backend/infrahub/workflows/utils.py (2 additions, 1 deletion)
@@ -9,6 +9,7 @@
from infrahub.core.constants import GLOBAL_BRANCH_NAME
from infrahub.core.registry import registry
from infrahub.tasks.registry import refresh_branches
from infrahub.workers.dependencies import get_http

from .constants import TAG_NAMESPACE, WorkflowTag

@@ -26,7 +27,7 @@ async def add_tags(
namespace: bool = True,
db_change: bool = False,
) -> None:
client = get_client(sync_client=False)
client = get_client(httpx_settings={"verify": get_http().verify_tls()}, sync_client=False)
current_flow_run_id = flow_run.id
current_tags: list[str] = flow_run.tags
branch_tags = (
backend/tests/adapters/http.py (11 additions, 1 deletion)
@@ -1,15 +1,25 @@
from typing import Any
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import httpx

from infrahub.services.adapters.http import InfrahubHTTP

if TYPE_CHECKING:
import ssl

import httpx


class MemoryHTTP(InfrahubHTTP):
def __init__(self) -> None:
self._get_response: dict[str, httpx.Response] = {}
self._post_response: dict[str, httpx.Response] = {}

def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext:
return False

async def get(
self,
url: str,