diff --git a/backend/infrahub/services/adapters/http/__init__.py b/backend/infrahub/services/adapters/http/__init__.py index f38552a11b..4e72ab43ae 100644 --- a/backend/infrahub/services/adapters/http/__init__.py +++ b/backend/infrahub/services/adapters/http/__init__.py @@ -3,10 +3,15 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + import ssl + import httpx class InfrahubHTTP: + def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext: + raise NotImplementedError() + async def get( self, url: str, diff --git a/backend/infrahub/services/adapters/workflow/worker.py b/backend/infrahub/services/adapters/workflow/worker.py index aa3659a200..a2ece3d514 100644 --- a/backend/infrahub/services/adapters/workflow/worker.py +++ b/backend/infrahub/services/adapters/workflow/worker.py @@ -3,8 +3,10 @@ from typing import TYPE_CHECKING, Any, overload from prefect.client.schemas.objects import StateType +from prefect.context import AsyncClientContext from prefect.deployments import run_deployment +from infrahub.services.adapters.http.httpx import HttpxAdapter from infrahub.workers.utils import inject_context_parameter from infrahub.workflows.initialization import setup_task_manager from infrahub.workflows.models import WorkflowInfo @@ -19,6 +21,11 @@ class WorkflowWorkerExecution(InfrahubWorkflow): + # This is required to grab a cached SSLContext from the HttpAdapter. + # We cannot use the get_http() dependency since it introduces a circular dependency. + # We could remove this later on by introducing a cached SSLContext outside of this adapter. + _http_adapter = HttpxAdapter() + @staticmethod async def initialize(component_is_primary_server: bool) -> None: if component_is_primary_server: @@ -79,5 +86,6 @@ async def submit_workflow( parameters = dict(parameters) if parameters is not None else {} inject_context_parameter(func=flow_func, parameters=parameters, context=context) - flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}, tags=tags) # type: ignore[return-value, misc] + async with AsyncClientContext(httpx_settings={"verify": self._http_adapter.verify_tls()}): + flow_run = await run_deployment(name=workflow.full_name, timeout=0, parameters=parameters or {}, tags=tags) # type: ignore[return-value, misc] return WorkflowInfo.from_flow(flow_run=flow_run) diff --git a/backend/infrahub/workers/infrahub_async.py b/backend/infrahub/workers/infrahub_async.py index 8e4d3411c4..0664fa764a 100644 --- a/backend/infrahub/workers/infrahub_async.py +++ b/backend/infrahub/workers/infrahub_async.py @@ -8,6 +8,7 @@ from infrahub_sdk.exceptions import Error as SdkError from prefect import settings as prefect_settings from prefect.client.schemas.objects import FlowRun +from prefect.context import AsyncClientContext from prefect.flow_engine import run_flow_async from prefect.logging.handlers import APILogHandler from prefect.workers.base import BaseJobConfiguration, BaseVariables, BaseWorker, BaseWorkerResult @@ -27,6 +28,7 @@ get_cache, get_component, get_database, + get_http, get_message_bus, get_workflow, set_component_type, @@ -154,7 +156,9 @@ async def run( if task_status: task_status.started(True) - await run_flow_async(flow=flow_func, flow_run=flow_run, parameters=params, return_type="state") + async with AsyncClientContext(httpx_settings={"verify": get_http().verify_tls()}) as ctx: + ctx._httpx_settings = None # Hack to make all child task/flow runs use the same client + await run_flow_async(flow=flow_func, flow_run=flow_run, parameters=params, return_type="state") return InfrahubWorkerAsyncResult(status_code=0, identifier=str(flow_run.id)) diff --git a/backend/infrahub/workflows/utils.py b/backend/infrahub/workflows/utils.py index 9b434cae5d..62d58c747c 100644 --- a/backend/infrahub/workflows/utils.py +++ b/backend/infrahub/workflows/utils.py @@ -9,6 +9,7 @@ from infrahub.core.constants import GLOBAL_BRANCH_NAME from infrahub.core.registry import registry from infrahub.tasks.registry import refresh_branches +from infrahub.workers.dependencies import get_http from .constants import TAG_NAMESPACE, WorkflowTag @@ -26,7 +27,7 @@ async def add_tags( namespace: bool = True, db_change: bool = False, ) -> None: - client = get_client(sync_client=False) + client = get_client(httpx_settings={"verify": get_http().verify_tls()}, sync_client=False) current_flow_run_id = flow_run.id current_tags: list[str] = flow_run.tags branch_tags = ( diff --git a/backend/tests/adapters/http.py b/backend/tests/adapters/http.py index 066f59733b..c4c4c6de80 100644 --- a/backend/tests/adapters/http.py +++ b/backend/tests/adapters/http.py @@ -1,15 +1,25 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import httpx from infrahub.services.adapters.http import InfrahubHTTP +if TYPE_CHECKING: + import ssl + + import httpx + class MemoryHTTP(InfrahubHTTP): def __init__(self) -> None: self._get_response: dict[str, httpx.Response] = {} self._post_response: dict[str, httpx.Response] = {} + def verify_tls(self, verify: bool | None = None) -> bool | ssl.SSLContext: + return False + async def get( self, url: str,