Skip to content

Commit 48ce545

Browse files
authored
Merge pull request #7530 from opsmill/stable
Merge stable into release-1.5
2 parents 0becd88 + 9a4ce0d commit 48ce545

File tree

7 files changed

+151
-4
lines changed

7 files changed

+151
-4
lines changed

backend/infrahub/cli/tasks.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import typer
44
from infrahub_sdk.async_typer import AsyncTyper
55
from prefect.client.orchestration import get_client
6+
from prefect.client.schemas.objects import StateType
67

78
from infrahub import config
89
from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
10+
from infrahub.task_manager.task import PrefectTask
911
from infrahub.tasks.dummy import DUMMY_FLOW, DummyInput
1012
from infrahub.workflows.initialization import setup_task_manager
1113
from infrahub.workflows.models import WorkerPoolDefinition
@@ -50,3 +52,47 @@ async def execute(
5052
workflow=DUMMY_FLOW, parameters={"data": DummyInput(firstname="John", lastname="Doe")}
5153
) # type: ignore[var-annotated]
5254
print(result)
55+
56+
57+
flush_app = AsyncTyper()
58+
59+
app.add_typer(flush_app, name="flush")
60+
61+
62+
@flush_app.command()
63+
async def flow_runs(
64+
ctx: typer.Context, # noqa: ARG001
65+
config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
66+
days_to_keep: int = 30,
67+
batch_size: int = 100,
68+
) -> None:
69+
"""Flush old task runs"""
70+
logging.getLogger("infrahub").setLevel(logging.WARNING)
71+
logging.getLogger("neo4j").setLevel(logging.ERROR)
72+
logging.getLogger("prefect").setLevel(logging.ERROR)
73+
74+
config.load_and_exit(config_file_name=config_file)
75+
76+
await PrefectTask.delete_flow_runs(
77+
days_to_keep=days_to_keep,
78+
batch_size=batch_size,
79+
)
80+
81+
82+
@flush_app.command()
83+
async def stale_runs(
84+
ctx: typer.Context, # noqa: ARG001
85+
config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
86+
days_to_keep: int = 2,
87+
batch_size: int = 100,
88+
) -> None:
89+
"""Flush stale task runs"""
90+
logging.getLogger("infrahub").setLevel(logging.WARNING)
91+
logging.getLogger("neo4j").setLevel(logging.ERROR)
92+
logging.getLogger("prefect").setLevel(logging.ERROR)
93+
94+
config.load_and_exit(config_file_name=config_file)
95+
96+
await PrefectTask.delete_flow_runs(
97+
states=[StateType.RUNNING], delete=False, days_to_keep=days_to_keep, batch_size=batch_size
98+
)

backend/infrahub/services/adapters/http/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
from typing import TYPE_CHECKING, Any
44

55
if TYPE_CHECKING:
6+
import ssl
7+
68
import httpx
79

810

911
class InfrahubHTTP:
12+
def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext:
13+
raise NotImplementedError()
14+
1015
async def get(
1116
self,
1217
url: str,

backend/infrahub/services/adapters/workflow/worker.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from typing import TYPE_CHECKING, Any, overload
44

55
from prefect.client.schemas.objects import StateType
6+
from prefect.context import AsyncClientContext
67
from prefect.deployments import run_deployment
78

9+
from infrahub.services.adapters.http.httpx import HttpxAdapter
810
from infrahub.workers.utils import inject_context_parameter
911
from infrahub.workflows.initialization import setup_task_manager, setup_task_manager_identifiers
1012
from infrahub.workflows.models import WorkflowInfo
@@ -19,6 +21,11 @@
1921

2022

2123
class WorkflowWorkerExecution(InfrahubWorkflow):
24+
# This is required to grab a cached SSLContext from the HttpAdapter.
25+
# We cannot use the get_http() dependency since it introduces a circular dependency.
26+
# We could remove this later on by introducing a cached SSLContext outside of this adapter.
27+
_http_adapter = HttpxAdapter()
28+
2229
@staticmethod
2330
async def initialize(component_is_primary_server: bool, is_initial_setup: bool = False) -> None:
2431
if component_is_primary_server:
@@ -82,5 +89,6 @@ async def submit_workflow(
8289
parameters = dict(parameters) if parameters is not None else {}
8390
inject_context_parameter(func=flow_func, parameters=parameters, context=context)
8491

85-
flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}, tags=tags) # type: ignore[return-value, misc]
92+
async with AsyncClientContext(httpx_settings={"verify": self._http_adapter.verify_tls()}):
93+
flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}, tags=tags) # type: ignore[return-value, misc]
8694
return WorkflowInfo.from_flow(flow_run=flow_run)

backend/infrahub/task_manager/task.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import asyncio
12
import uuid
3+
from datetime import datetime, timedelta, timezone
24
from typing import Any
35
from uuid import UUID
46

7+
from prefect import State
58
from prefect.client.orchestration import PrefectClient, get_client
69
from prefect.client.schemas.filters import (
710
ArtifactFilter,
@@ -12,6 +15,7 @@
1215
FlowRunFilter,
1316
FlowRunFilterId,
1417
FlowRunFilterName,
18+
FlowRunFilterStartTime,
1519
FlowRunFilterState,
1620
FlowRunFilterStateType,
1721
FlowRunFilterTags,
@@ -311,3 +315,72 @@ async def query(
311315
)
312316

313317
return {"count": count or 0, "edges": nodes}
318+
319+
@classmethod
320+
async def delete_flow_runs(
321+
cls,
322+
states: list[StateType] = [StateType.COMPLETED, StateType.FAILED, StateType.CANCELLED], # noqa: B006
323+
delete: bool = True,
324+
days_to_keep: int = 2,
325+
batch_size: int = 100,
326+
) -> None:
327+
"""Delete flow runs in the specified states and older than specified days."""
328+
329+
logger = get_logger()
330+
331+
async with get_client(sync_client=False) as client:
332+
cutoff = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
333+
334+
flow_run_filter = FlowRunFilter(
335+
start_time=FlowRunFilterStartTime(before_=cutoff), # type: ignore[arg-type]
336+
state=FlowRunFilterState(type=FlowRunFilterStateType(any_=states)),
337+
)
338+
339+
# Get flow runs to delete
340+
flow_runs = await client.read_flow_runs(flow_run_filter=flow_run_filter, limit=batch_size)
341+
342+
deleted_total = 0
343+
344+
while True:
345+
batch_deleted = 0
346+
failed_deletes = []
347+
348+
# Delete each flow run through the API
349+
for flow_run in flow_runs:
350+
try:
351+
if delete:
352+
await client.delete_flow_run(flow_run_id=flow_run.id)
353+
else:
354+
await client.set_flow_run_state(
355+
flow_run_id=flow_run.id,
356+
state=State(type=StateType.CRASHED),
357+
force=True,
358+
)
359+
deleted_total += 1
360+
batch_deleted += 1
361+
except Exception as e:
362+
logger.warning(f"Failed to delete flow run {flow_run.id}: {e}")
363+
failed_deletes.append(flow_run.id)
364+
365+
# Rate limiting
366+
if batch_deleted % 10 == 0:
367+
await asyncio.sleep(0.5)
368+
369+
logger.info(f"Delete {batch_deleted}/{len(flow_runs)} flow runs (total: {deleted_total})")
370+
371+
# Get next batch
372+
previous_flow_run_ids = [fr.id for fr in flow_runs]
373+
flow_runs = await client.read_flow_runs(flow_run_filter=flow_run_filter, limit=batch_size)
374+
375+
if not flow_runs:
376+
logger.info("No more flow runs to delete")
377+
break
378+
379+
if previous_flow_run_ids == [fr.id for fr in flow_runs]:
380+
logger.info("Found same flow runs to delete, aborting")
381+
break
382+
383+
# Delay between batches to avoid overwhelming the API
384+
await asyncio.sleep(1.0)
385+
386+
logger.info(f"Retention complete. Total deleted tasks: {deleted_total}")

backend/infrahub/workers/infrahub_async.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from infrahub_sdk.exceptions import Error as SdkError
99
from prefect import settings as prefect_settings
1010
from prefect.client.schemas.objects import FlowRun
11+
from prefect.context import AsyncClientContext
1112
from prefect.flow_engine import run_flow_async
1213
from prefect.logging.handlers import APILogHandler
1314
from prefect.workers.base import BaseJobConfiguration, BaseVariables, BaseWorker, BaseWorkerResult
@@ -28,6 +29,7 @@
2829
get_cache,
2930
get_component,
3031
get_database,
32+
get_http,
3133
get_message_bus,
3234
get_workflow,
3335
set_component_type,
@@ -158,7 +160,9 @@ async def run(
158160
if task_status:
159161
task_status.started(True)
160162

161-
await run_flow_async(flow=flow_func, flow_run=flow_run, parameters=params, return_type="state")
163+
async with AsyncClientContext(httpx_settings={"verify": get_http().verify_tls()}) as ctx:
164+
ctx._httpx_settings = None # Hack to make all child task/flow runs use the same client
165+
await run_flow_async(flow=flow_func, flow_run=flow_run, parameters=params, return_type="state")
162166

163167
return InfrahubWorkerAsyncResult(status_code=0, identifier=str(flow_run.id))
164168

backend/infrahub/workflows/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from infrahub.core.constants import GLOBAL_BRANCH_NAME
1010
from infrahub.core.registry import registry
1111
from infrahub.tasks.registry import refresh_branches
12+
from infrahub.workers.dependencies import get_http
1213

1314
from .constants import TAG_NAMESPACE, WorkflowTag
1415

@@ -26,7 +27,7 @@ async def add_tags(
2627
namespace: bool = True,
2728
db_change: bool = False,
2829
) -> None:
29-
client = get_client(sync_client=False)
30+
client = get_client(httpx_settings={"verify": get_http().verify_tls()}, sync_client=False)
3031
current_flow_run_id = flow_run.id
3132
current_tags: list[str] = flow_run.tags
3233
branch_tags = (

backend/tests/adapters/http.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
1-
from typing import Any
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
24

35
import httpx
46

57
from infrahub.services.adapters.http import InfrahubHTTP
68

9+
if TYPE_CHECKING:
10+
import ssl
11+
12+
import httpx
13+
714

815
class MemoryHTTP(InfrahubHTTP):
916
def __init__(self) -> None:
1017
self._get_response: dict[str, httpx.Response] = {}
1118
self._post_response: dict[str, httpx.Response] = {}
1219

20+
def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext:
21+
return False
22+
1323
async def get(
1424
self,
1525
url: str,

0 commit comments

Comments
 (0)