diff --git a/tests/unit/api/test_api_organization_invitations.py b/tests/unit/api/test_api_organization_invitations.py index c36b46e324..d9d39db4e9 100644 --- a/tests/unit/api/test_api_organization_invitations.py +++ b/tests/unit/api/test_api_organization_invitations.py @@ -11,6 +11,7 @@ from tracecat.api.app import app from tracecat.auth.types import Role +from tracecat.auth.users import current_active_user from tracecat.authz.enums import OrgRole from tracecat.db.engine import get_async_session @@ -50,7 +51,8 @@ async def test_list_my_pending_invitations_success( pending_result = Mock() pending_result.tuples.return_value = tuples_result - mock_session.execute.side_effect = [user_result, pending_result] + app.dependency_overrides[current_active_user] = lambda: mock_user + mock_session.execute.side_effect = [pending_result] response = client.get("/organization/invitations/pending/me") @@ -69,13 +71,7 @@ async def test_list_my_pending_invitations_success( async def test_list_my_pending_invitations_user_not_found( client: TestClient, test_admin_role: Role ) -> None: - mock_session = await app.dependency_overrides[get_async_session]() - - user_result = Mock() - user_result.scalar_one_or_none.return_value = None - mock_session.execute.side_effect = [user_result] - + app.dependency_overrides.pop(current_active_user, None) response = client.get("/organization/invitations/pending/me") - assert response.status_code == status.HTTP_404_NOT_FOUND - assert response.json()["detail"] == "User not found" + assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/tests/unit/test_dependencies.py b/tests/unit/test_dependencies.py index 159465c0a7..93442029df 100644 --- a/tests/unit/test_dependencies.py +++ b/tests/unit/test_dependencies.py @@ -45,24 +45,24 @@ async def test_verify_auth_type_not_allowed( @pytest.mark.anyio async def test_verify_auth_type_setting_disabled(mocker: MockerFixture): """Test that disabled auth types raise HTTPException.""" - mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.BASIC]) + mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.SAML]) mocker.patch("tracecat.auth.dependencies.get_setting", return_value=False) with pytest.raises(HTTPException) as exc: - await verify_auth_type(AuthType.BASIC) + await verify_auth_type(AuthType.SAML) assert exc.value.status_code == status.HTTP_403_FORBIDDEN - assert exc.value.detail == f"Auth type {AuthType.BASIC.value} is not enabled" + assert exc.value.detail == f"Auth type {AuthType.SAML.value} is not enabled" @pytest.mark.anyio async def test_verify_auth_type_invalid_setting(mocker: MockerFixture): """Test that invalid settings raise HTTPException.""" - mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.BASIC]) + mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.SAML]) mocker.patch("tracecat.auth.dependencies.get_setting", return_value=None) with pytest.raises(HTTPException) as exc: - await verify_auth_type(AuthType.BASIC) + await verify_auth_type(AuthType.SAML) assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert exc.value.detail == "Invalid setting configuration" diff --git a/tests/unit/test_workflow_executions.py b/tests/unit/test_workflow_executions.py index b6625cbcfb..47c9fef895 100644 --- a/tests/unit/test_workflow_executions.py +++ b/tests/unit/test_workflow_executions.py @@ -10,6 +10,7 @@ """ +import asyncio import datetime from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch @@ -148,6 +149,19 @@ def create_mock_history_event( return event +def create_test_dsl_input() -> DSLInput: + """Create a minimal valid DSLInput for tests.""" + return DSLInput.model_validate( + { + "title": "Webhook test workflow", + "description": "Test workflow", + "entrypoint": {"ref": "start"}, + "actions": [{"ref": "start", "action": "core.noop"}], + "config": {"enable_runtime_tests": False}, + } + ) + + @pytest.mark.anyio class TestWorkflowExecutionEvents: """Test workflow execution events functionality.""" @@ -903,15 +917,7 @@ async def test_create_workflow_execution_wait_for_start_acknowledges_temporal_st service = WorkflowExecutionsService(client=mock_client, role=mock_role) mock_client.start_workflow = AsyncMock(return_value=Mock(spec=WorkflowHandle)) - dsl = DSLInput.model_validate( - { - "title": "Webhook test workflow", - "description": "Test workflow", - "entrypoint": {"ref": "start"}, - "actions": [{"ref": "start", "action": "core.noop"}], - "config": {"enable_runtime_tests": False}, - } - ) + dsl = create_test_dsl_input() wf_id = WorkflowUUID.new("wf_4itKqkgCZrLhgYiq5L211X") with patch.object( @@ -932,3 +938,123 @@ async def test_create_workflow_execution_wait_for_start_acknowledges_temporal_st assert ( mock_client.start_workflow.await_args.kwargs["id"] == response["wf_exec_id"] ) + + +class TestWorkflowNowaitBackgroundBehavior: + @pytest.mark.anyio + async def test_create_workflow_execution_nowait_schedules_start_only( + self, + mock_client: Mock, + mock_role: Role, + ) -> None: + service = WorkflowExecutionsService(client=mock_client, role=mock_role) + dsl = create_test_dsl_input() + wf_id = WorkflowUUID.new("wf_5K8NL5TYLRM8JqkDnGzYdE") + + mock_wait_for_start = AsyncMock( + return_value={ + "message": "Workflow execution started", + "wf_id": wf_id, + "wf_exec_id": "unused", + } + ) + mock_wait_for_completion = AsyncMock() + + with ( + patch.object( + service, + "create_workflow_execution_wait_for_start", + mock_wait_for_start, + ), + patch.object( + service, + "create_workflow_execution", + mock_wait_for_completion, + ), + ): + response = service.create_workflow_execution_nowait( + dsl=dsl, + wf_id=wf_id, + payload=None, + trigger_type=TriggerType.MANUAL, + ) + await asyncio.sleep(0) + + mock_wait_for_start.assert_awaited_once() + mock_wait_for_completion.assert_not_called() + assert mock_wait_for_start.await_args is not None + assert response["wf_exec_id"] == mock_wait_for_start.await_args.kwargs["wf_exec_id"] + + @pytest.mark.anyio + async def test_create_draft_workflow_execution_nowait_schedules_start_only( + self, + mock_client: Mock, + mock_role: Role, + ) -> None: + service = WorkflowExecutionsService(client=mock_client, role=mock_role) + dsl = create_test_dsl_input() + wf_id = WorkflowUUID.new("wf_4fHecX13GwQY74HCAS4j7L") + + mock_wait_for_start = AsyncMock( + return_value={ + "message": "Draft workflow execution started", + "wf_id": wf_id, + "wf_exec_id": "unused", + } + ) + mock_wait_for_completion = AsyncMock() + + with ( + patch.object( + service, + "create_draft_workflow_execution_wait_for_start", + mock_wait_for_start, + ), + patch.object( + service, + "create_draft_workflow_execution", + mock_wait_for_completion, + ), + ): + response = service.create_draft_workflow_execution_nowait( + dsl=dsl, + wf_id=wf_id, + payload=None, + trigger_type=TriggerType.MANUAL, + ) + await asyncio.sleep(0) + + mock_wait_for_start.assert_awaited_once() + mock_wait_for_completion.assert_not_called() + assert mock_wait_for_start.await_args is not None + assert response["wf_exec_id"] == mock_wait_for_start.await_args.kwargs["wf_exec_id"] + + @pytest.mark.anyio + async def test_background_start_swallows_exceptions( + self, + mock_client: Mock, + mock_role: Role, + ) -> None: + service = WorkflowExecutionsService(client=mock_client, role=mock_role) + dsl = create_test_dsl_input() + wf_id = WorkflowUUID.new("wf_5eP17nM2Cc7Bf1tBwNRar9") + + mock_wait_for_start = AsyncMock(side_effect=RuntimeError("boom")) + with ( + patch.object( + service, + "create_workflow_execution_wait_for_start", + mock_wait_for_start, + ), + patch.object(service.logger, "error") as mock_error, + ): + _ = service.create_workflow_execution_nowait( + dsl=dsl, + wf_id=wf_id, + payload=None, + trigger_type=TriggerType.MANUAL, + ) + await asyncio.sleep(0) + + mock_wait_for_start.assert_awaited_once() + mock_error.assert_called_once() diff --git a/tracecat/workflow/executions/service.py b/tracecat/workflow/executions/service.py index 52e3d5eee4..f4154033b2 100644 --- a/tracecat/workflow/executions/service.py +++ b/tracecat/workflow/executions/service.py @@ -102,6 +102,71 @@ def _handle_background_task_exception(self, task: asyncio.Task[Any]) -> None: exception=str(exc), ) + async def _start_workflow_execution_background( + self, + dsl: DSLInput, + *, + wf_id: WorkflowID, + wf_exec_id: WorkflowExecutionID, + payload: TriggerInputs | None = None, + trigger_type: TriggerType = TriggerType.MANUAL, + time_anchor: datetime.datetime | None = None, + registry_lock: RegistryLock | None = None, + memo: dict[str, Any] | None = None, + ) -> None: + """Start a published workflow in the background and swallow start errors. + + Background dispatches should never bubble exceptions back to the event loop. + """ + try: + await self.create_workflow_execution_wait_for_start( + dsl=dsl, + wf_id=wf_id, + payload=payload, + trigger_type=trigger_type, + wf_exec_id=wf_exec_id, + time_anchor=time_anchor, + registry_lock=registry_lock, + memo=memo, + ) + except Exception as e: + self.logger.error( + "Failed to start background workflow execution", + role=self.role, + wf_exec_id=wf_exec_id, + error=str(e), + ) + + async def _start_draft_workflow_execution_background( + self, + dsl: DSLInput, + *, + wf_id: WorkflowID, + wf_exec_id: WorkflowExecutionID, + payload: TriggerInputs | None = None, + trigger_type: TriggerType = TriggerType.MANUAL, + time_anchor: datetime.datetime | None = None, + registry_lock: RegistryLock | None = None, + ) -> None: + """Start a draft workflow in the background and swallow start errors.""" + try: + await self.create_draft_workflow_execution_wait_for_start( + dsl=dsl, + wf_id=wf_id, + payload=payload, + trigger_type=trigger_type, + wf_exec_id=wf_exec_id, + time_anchor=time_anchor, + registry_lock=registry_lock, + ) + except Exception as e: + self.logger.error( + "Failed to start background draft workflow execution", + role=self.role, + wf_exec_id=wf_exec_id, + error=str(e), + ) + def handle( self, wf_exec_id: WorkflowExecutionID ) -> WorkflowHandle[DSLWorkflow, StoredObject]: @@ -700,14 +765,14 @@ def create_workflow_execution_nowait( ) -> WorkflowExecutionCreateResponse: """Create a new workflow execution. - Note: This method schedules the workflow execution and returns immediately. + Note: This method schedules the workflow start and returns immediately. Args: memo: Optional memo dict to store with the workflow execution. Useful for correlation (e.g., parent_wf_exec_id). """ wf_exec_id = generate_exec_id(wf_id) - coro = self.create_workflow_execution( + coro = self._start_workflow_execution_background( dsl=dsl, wf_id=wf_id, payload=payload, @@ -773,10 +838,10 @@ def create_draft_workflow_execution_nowait( """Create a new draft workflow execution. Draft executions use the draft workflow graph and resolve aliases from draft workflows. - This method schedules the workflow execution and returns immediately. + This method schedules workflow start and returns immediately. """ wf_exec_id = generate_exec_id(wf_id) - coro = self.create_draft_workflow_execution( + coro = self._start_draft_workflow_execution_background( dsl=dsl, wf_id=wf_id, payload=payload,