Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions tests/unit/api/test_api_organization_invitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from tracecat.api.app import app
from tracecat.auth.types import Role
from tracecat.auth.users import current_active_user
from tracecat.authz.enums import OrgRole
from tracecat.db.engine import get_async_session

Expand Down Expand Up @@ -50,7 +51,8 @@ async def test_list_my_pending_invitations_success(
pending_result = Mock()
pending_result.tuples.return_value = tuples_result

mock_session.execute.side_effect = [user_result, pending_result]
app.dependency_overrides[current_active_user] = lambda: mock_user
mock_session.execute.side_effect = [pending_result]

response = client.get("/organization/invitations/pending/me")

Expand All @@ -69,13 +71,7 @@ async def test_list_my_pending_invitations_success(
async def test_list_my_pending_invitations_user_not_found(
client: TestClient, test_admin_role: Role
) -> None:
mock_session = await app.dependency_overrides[get_async_session]()

user_result = Mock()
user_result.scalar_one_or_none.return_value = None
mock_session.execute.side_effect = [user_result]

app.dependency_overrides.pop(current_active_user, None)
response = client.get("/organization/invitations/pending/me")

assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json()["detail"] == "User not found"
assert response.status_code == status.HTTP_401_UNAUTHORIZED
10 changes: 5 additions & 5 deletions tests/unit/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,24 @@ async def test_verify_auth_type_not_allowed(
@pytest.mark.anyio
async def test_verify_auth_type_setting_disabled(mocker: MockerFixture):
"""Test that disabled auth types raise HTTPException."""
mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.BASIC])
mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.SAML])
mocker.patch("tracecat.auth.dependencies.get_setting", return_value=False)

with pytest.raises(HTTPException) as exc:
await verify_auth_type(AuthType.BASIC)
await verify_auth_type(AuthType.SAML)

assert exc.value.status_code == status.HTTP_403_FORBIDDEN
assert exc.value.detail == f"Auth type {AuthType.BASIC.value} is not enabled"
assert exc.value.detail == f"Auth type {AuthType.SAML.value} is not enabled"


@pytest.mark.anyio
async def test_verify_auth_type_invalid_setting(mocker: MockerFixture):
"""Test that invalid settings raise HTTPException."""
mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.BASIC])
mocker.patch("tracecat.config.TRACECAT__AUTH_TYPES", [AuthType.SAML])
mocker.patch("tracecat.auth.dependencies.get_setting", return_value=None)

with pytest.raises(HTTPException) as exc:
await verify_auth_type(AuthType.BASIC)
await verify_auth_type(AuthType.SAML)

assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert exc.value.detail == "Invalid setting configuration"
Expand Down
144 changes: 135 additions & 9 deletions tests/unit/test_workflow_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

"""

import asyncio
import datetime
from typing import Any, cast
from unittest.mock import AsyncMock, Mock, patch
Expand Down Expand Up @@ -148,6 +149,19 @@ def create_mock_history_event(
return event


def create_test_dsl_input() -> DSLInput:
"""Create a minimal valid DSLInput for tests."""
return DSLInput.model_validate(
{
"title": "Webhook test workflow",
"description": "Test workflow",
"entrypoint": {"ref": "start"},
"actions": [{"ref": "start", "action": "core.noop"}],
"config": {"enable_runtime_tests": False},
}
)


@pytest.mark.anyio
class TestWorkflowExecutionEvents:
"""Test workflow execution events functionality."""
Expand Down Expand Up @@ -903,15 +917,7 @@ async def test_create_workflow_execution_wait_for_start_acknowledges_temporal_st
service = WorkflowExecutionsService(client=mock_client, role=mock_role)
mock_client.start_workflow = AsyncMock(return_value=Mock(spec=WorkflowHandle))

dsl = DSLInput.model_validate(
{
"title": "Webhook test workflow",
"description": "Test workflow",
"entrypoint": {"ref": "start"},
"actions": [{"ref": "start", "action": "core.noop"}],
"config": {"enable_runtime_tests": False},
}
)
dsl = create_test_dsl_input()
wf_id = WorkflowUUID.new("wf_4itKqkgCZrLhgYiq5L211X")

with patch.object(
Expand All @@ -932,3 +938,123 @@ async def test_create_workflow_execution_wait_for_start_acknowledges_temporal_st
assert (
mock_client.start_workflow.await_args.kwargs["id"] == response["wf_exec_id"]
)


class TestWorkflowNowaitBackgroundBehavior:
@pytest.mark.anyio
async def test_create_workflow_execution_nowait_schedules_start_only(
self,
mock_client: Mock,
mock_role: Role,
) -> None:
service = WorkflowExecutionsService(client=mock_client, role=mock_role)
dsl = create_test_dsl_input()
wf_id = WorkflowUUID.new("wf_5K8NL5TYLRM8JqkDnGzYdE")

mock_wait_for_start = AsyncMock(
return_value={
"message": "Workflow execution started",
"wf_id": wf_id,
"wf_exec_id": "unused",
}
)
mock_wait_for_completion = AsyncMock()

with (
patch.object(
service,
"create_workflow_execution_wait_for_start",
mock_wait_for_start,
),
patch.object(
service,
"create_workflow_execution",
mock_wait_for_completion,
),
):
response = service.create_workflow_execution_nowait(
dsl=dsl,
wf_id=wf_id,
payload=None,
trigger_type=TriggerType.MANUAL,
)
await asyncio.sleep(0)

mock_wait_for_start.assert_awaited_once()
mock_wait_for_completion.assert_not_called()
assert mock_wait_for_start.await_args is not None
assert response["wf_exec_id"] == mock_wait_for_start.await_args.kwargs["wf_exec_id"]

@pytest.mark.anyio
async def test_create_draft_workflow_execution_nowait_schedules_start_only(
self,
mock_client: Mock,
mock_role: Role,
) -> None:
service = WorkflowExecutionsService(client=mock_client, role=mock_role)
dsl = create_test_dsl_input()
wf_id = WorkflowUUID.new("wf_4fHecX13GwQY74HCAS4j7L")

mock_wait_for_start = AsyncMock(
return_value={
"message": "Draft workflow execution started",
"wf_id": wf_id,
"wf_exec_id": "unused",
}
)
mock_wait_for_completion = AsyncMock()

with (
patch.object(
service,
"create_draft_workflow_execution_wait_for_start",
mock_wait_for_start,
),
patch.object(
service,
"create_draft_workflow_execution",
mock_wait_for_completion,
),
):
response = service.create_draft_workflow_execution_nowait(
dsl=dsl,
wf_id=wf_id,
payload=None,
trigger_type=TriggerType.MANUAL,
)
await asyncio.sleep(0)

mock_wait_for_start.assert_awaited_once()
mock_wait_for_completion.assert_not_called()
assert mock_wait_for_start.await_args is not None
assert response["wf_exec_id"] == mock_wait_for_start.await_args.kwargs["wf_exec_id"]

@pytest.mark.anyio
async def test_background_start_swallows_exceptions(
self,
mock_client: Mock,
mock_role: Role,
) -> None:
service = WorkflowExecutionsService(client=mock_client, role=mock_role)
dsl = create_test_dsl_input()
wf_id = WorkflowUUID.new("wf_5eP17nM2Cc7Bf1tBwNRar9")

mock_wait_for_start = AsyncMock(side_effect=RuntimeError("boom"))
with (
patch.object(
service,
"create_workflow_execution_wait_for_start",
mock_wait_for_start,
),
patch.object(service.logger, "error") as mock_error,
):
_ = service.create_workflow_execution_nowait(
dsl=dsl,
wf_id=wf_id,
payload=None,
trigger_type=TriggerType.MANUAL,
)
await asyncio.sleep(0)

mock_wait_for_start.assert_awaited_once()
mock_error.assert_called_once()
73 changes: 69 additions & 4 deletions tracecat/workflow/executions/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,71 @@ def _handle_background_task_exception(self, task: asyncio.Task[Any]) -> None:
exception=str(exc),
)

async def _start_workflow_execution_background(
self,
dsl: DSLInput,
*,
wf_id: WorkflowID,
wf_exec_id: WorkflowExecutionID,
payload: TriggerInputs | None = None,
trigger_type: TriggerType = TriggerType.MANUAL,
time_anchor: datetime.datetime | None = None,
registry_lock: RegistryLock | None = None,
memo: dict[str, Any] | None = None,
) -> None:
"""Start a published workflow in the background and swallow start errors.

Background dispatches should never bubble exceptions back to the event loop.
"""
try:
await self.create_workflow_execution_wait_for_start(
dsl=dsl,
wf_id=wf_id,
payload=payload,
trigger_type=trigger_type,
wf_exec_id=wf_exec_id,
time_anchor=time_anchor,
registry_lock=registry_lock,
memo=memo,
)
except Exception as e:
self.logger.error(
"Failed to start background workflow execution",
role=self.role,
wf_exec_id=wf_exec_id,
error=str(e),
)

async def _start_draft_workflow_execution_background(
self,
dsl: DSLInput,
*,
wf_id: WorkflowID,
wf_exec_id: WorkflowExecutionID,
payload: TriggerInputs | None = None,
trigger_type: TriggerType = TriggerType.MANUAL,
time_anchor: datetime.datetime | None = None,
registry_lock: RegistryLock | None = None,
) -> None:
"""Start a draft workflow in the background and swallow start errors."""
try:
await self.create_draft_workflow_execution_wait_for_start(
dsl=dsl,
wf_id=wf_id,
payload=payload,
trigger_type=trigger_type,
wf_exec_id=wf_exec_id,
time_anchor=time_anchor,
registry_lock=registry_lock,
)
except Exception as e:
self.logger.error(
"Failed to start background draft workflow execution",
role=self.role,
wf_exec_id=wf_exec_id,
error=str(e),
)

def handle(
self, wf_exec_id: WorkflowExecutionID
) -> WorkflowHandle[DSLWorkflow, StoredObject]:
Expand Down Expand Up @@ -700,14 +765,14 @@ def create_workflow_execution_nowait(
) -> WorkflowExecutionCreateResponse:
"""Create a new workflow execution.

Note: This method schedules the workflow execution and returns immediately.
Note: This method schedules the workflow start and returns immediately.

Args:
memo: Optional memo dict to store with the workflow execution.
Useful for correlation (e.g., parent_wf_exec_id).
"""
wf_exec_id = generate_exec_id(wf_id)
coro = self.create_workflow_execution(
coro = self._start_workflow_execution_background(
dsl=dsl,
wf_id=wf_id,
payload=payload,
Expand Down Expand Up @@ -773,10 +838,10 @@ def create_draft_workflow_execution_nowait(
"""Create a new draft workflow execution.

Draft executions use the draft workflow graph and resolve aliases from draft workflows.
This method schedules the workflow execution and returns immediately.
This method schedules workflow start and returns immediately.
"""
wf_exec_id = generate_exec_id(wf_id)
coro = self.create_draft_workflow_execution(
coro = self._start_draft_workflow_execution_background(
dsl=dsl,
wf_id=wf_id,
payload=payload,
Expand Down