From b195983f5a5d4b1d934cf120f89065648676c3e3 Mon Sep 17 00:00:00 2001 From: Jordan Umusu Date: Wed, 21 Jan 2026 18:39:45 -0500 Subject: [PATCH] feat(agent): add agent interrupt support --- .../d8d9404cdee2_add_agent_session_status.py | 40 ++++++++ frontend/src/client/schemas.gen.ts | 31 +++++++ frontend/src/client/services.gen.ts | 37 ++++++++ frontend/src/client/types.gen.ts | 47 ++++++++++ .../components/ai-elements/prompt-input.tsx | 13 ++- .../src/components/chat/chat-session-pane.tsx | 25 +++-- .../components/copilot/copilot-chat-pane.tsx | 23 +++-- frontend/src/hooks/use-chat.ts | 91 +++++++++++++++++++ frontend/src/lib/agents.ts | 7 +- .../tracecat_ee/agent/workflows/durable.py | 53 ++++++++++- tracecat/agent/executor/activity.py | 13 +++ tracecat/agent/executor/loopback.py | 42 ++++++++- tracecat/agent/schemas.py | 1 + tracecat/agent/session/activities.py | 36 +++++++- tracecat/agent/session/router.py | 37 +++++++- tracecat/agent/session/schemas.py | 4 +- tracecat/agent/session/service.py | 73 ++++++++++++++- tracecat/agent/session/types.py | 18 ++++ tracecat/db/models.py | 8 ++ 19 files changed, 567 insertions(+), 32 deletions(-) create mode 100644 alembic/versions/d8d9404cdee2_add_agent_session_status.py diff --git a/alembic/versions/d8d9404cdee2_add_agent_session_status.py b/alembic/versions/d8d9404cdee2_add_agent_session_status.py new file mode 100644 index 0000000000..039694ea98 --- /dev/null +++ b/alembic/versions/d8d9404cdee2_add_agent_session_status.py @@ -0,0 +1,40 @@ +"""add_agent_session_status + +Revision ID: d8d9404cdee2 +Revises: c7737fa6338a +Create Date: 2026-01-21 17:40:55.526398 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d8d9404cdee2" +down_revision: str | None = "c7737fa6338a" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "agent_session", + sa.Column( + "status", sa.String(length=20), server_default="idle", nullable=False + ), + ) + op.create_index( + op.f("ix_agent_session_status"), "agent_session", ["status"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_agent_session_status"), table_name="agent_session") + op.drop_column("agent_session", "status") + # ### end Alembic commands ### diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index fd24645726..3dcb743a3f 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -891,6 +891,11 @@ export const $AgentOutput = { format: "uuid", title: "Session Id", }, + interrupted: { + type: "boolean", + title: "Interrupted", + default: false, + }, }, type: "object", required: ["output", "duration", "session_id"], @@ -1675,6 +1680,10 @@ export const $AgentSessionRead = { ], title: "Harness Type", }, + status: { + $ref: "#/components/schemas/AgentSessionStatus", + default: "idle", + }, last_stream_id: { anyOf: [ { @@ -1801,6 +1810,10 @@ export const $AgentSessionReadVercel = { ], title: "Harness Type", }, + status: { + $ref: "#/components/schemas/AgentSessionStatus", + default: "idle", + }, last_stream_id: { anyOf: [ { @@ -1935,6 +1948,10 @@ export const $AgentSessionReadWithMessages = { ], title: "Harness Type", }, + status: { + $ref: "#/components/schemas/AgentSessionStatus", + default: "idle", + }, last_stream_id: { anyOf: [ { @@ -1993,6 +2010,20 @@ export const $AgentSessionReadWithMessages = { description: "Response schema for agent session with message history.", } as const +export const $AgentSessionStatus = { + type: "string", + enum: ["idle", "running", "interrupted", "completed", "failed"], + title: "AgentSessionStatus", + description: `Status of an agent session. + +Tracks the lifecycle state of an agent session: +- IDLE: No active workflow running +- RUNNING: Workflow currently executing +- INTERRUPTED: User requested interrupt (transient state) +- COMPLETED: Last run completed successfully +- FAILED: Last run failed`, +} as const + export const $AgentSessionUpdate = { properties: { title: { diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 48d99ede19..b717d8f05c 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -87,6 +87,8 @@ import type { AgentSessionsGetSessionResponse, AgentSessionsGetSessionVercelData, AgentSessionsGetSessionVercelResponse, + AgentSessionsInterruptSessionData, + AgentSessionsInterruptSessionResponse, AgentSessionsListSessionsData, AgentSessionsListSessionsResponse, AgentSessionsSendMessageData, @@ -3543,6 +3545,41 @@ export const agentSessionsForkSession = ( }) } +/** + * Interrupt Session + * Request interruption of a running agent session. + * + * Marks the session for interrupt. The agent executor will detect this + * status change and terminate execution cleanly, emitting stream.done() + * to prevent the frontend from hanging. + * + * Returns: + * {"interrupted": true} if the session was running and is now interrupted, + * {"interrupted": false} if the session was not in a running state. + * @param data The data for the request. + * @param data.sessionId + * @param data.workspaceId + * @returns boolean Successful Response + * @throws ApiError + */ +export const agentSessionsInterruptSession = ( + data: AgentSessionsInterruptSessionData +): CancelablePromise => { + return __request(OpenAPI, { + method: "POST", + url: "/agent/sessions/{session_id}/interrupt", + path: { + session_id: data.sessionId, + }, + query: { + workspace_id: data.workspaceId, + }, + errors: { + 422: "Validation Error", + }, + }) +} + /** * Submit Approvals * Submit approval decisions to a running agent workflow. diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index d507c26569..e9f5f068e6 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -223,6 +223,7 @@ export type AgentOutput = { duration: number usage?: RunUsage | null session_id: string + interrupted?: boolean } export type AgentPreset = { @@ -392,6 +393,7 @@ export type AgentSessionRead = { tools: Array | null agent_preset_id: string | null harness_type: string | null + status?: AgentSessionStatus last_stream_id?: string | null parent_session_id?: string | null created_at: string @@ -411,6 +413,7 @@ export type AgentSessionReadVercel = { tools: Array | null agent_preset_id: string | null harness_type: string | null + status?: AgentSessionStatus last_stream_id?: string | null parent_session_id?: string | null created_at: string @@ -434,6 +437,7 @@ export type AgentSessionReadWithMessages = { tools: Array | null agent_preset_id: string | null harness_type: string | null + status?: AgentSessionStatus last_stream_id?: string | null parent_session_id?: string | null created_at: string @@ -444,6 +448,23 @@ export type AgentSessionReadWithMessages = { messages?: Array } +/** + * Status of an agent session. + * + * Tracks the lifecycle state of an agent session: + * - IDLE: No active workflow running + * - RUNNING: Workflow currently executing + * - INTERRUPTED: User requested interrupt (transient state) + * - COMPLETED: Last run completed successfully + * - FAILED: Last run failed + */ +export type AgentSessionStatus = + | "idle" + | "running" + | "interrupted" + | "completed" + | "failed" + /** * Request schema for updating an agent session. */ @@ -6802,6 +6823,15 @@ export type AgentSessionsForkSessionData = { export type AgentSessionsForkSessionResponse = AgentSessionRead +export type AgentSessionsInterruptSessionData = { + sessionId: string + workspaceId: string +} + +export type AgentSessionsInterruptSessionResponse = { + [key: string]: boolean +} + export type ApprovalsSubmitApprovalsData = { requestBody: ApprovalSubmission sessionId: string @@ -9598,6 +9628,23 @@ export type $OpenApiTs = { } } } + "/agent/sessions/{session_id}/interrupt": { + post: { + req: AgentSessionsInterruptSessionData + res: { + /** + * Successful Response + */ + 200: { + [key: string]: boolean + } + /** + * Validation Error + */ + 422: HTTPValidationError + } + } + } "/approvals/{session_id}": { post: { req: ApprovalsSubmitApprovalsData diff --git a/frontend/src/components/ai-elements/prompt-input.tsx b/frontend/src/components/ai-elements/prompt-input.tsx index 29ef00c1c4..1de0a0046e 100644 --- a/frontend/src/components/ai-elements/prompt-input.tsx +++ b/frontend/src/components/ai-elements/prompt-input.tsx @@ -678,6 +678,7 @@ export const PromptInputActionMenuItem = ({ export type PromptInputSubmitProps = ComponentProps & { status?: ChatStatus + onInterrupt?: () => void } export const PromptInputSubmit = ({ @@ -686,6 +687,7 @@ export const PromptInputSubmit = ({ size = "icon", status, children, + onInterrupt, ...props }: PromptInputSubmitProps) => { let Icon = @@ -698,17 +700,22 @@ export const PromptInputSubmit = ({ Icon = } + // When streaming, the button becomes an interrupt/stop button + const isStreaming = status === "streaming" + const handleClick = isStreaming && onInterrupt ? onInterrupt : undefined + return ( ) } diff --git a/frontend/src/components/chat/chat-session-pane.tsx b/frontend/src/components/chat/chat-session-pane.tsx index d5ef6cfdf2..120a4346a1 100644 --- a/frontend/src/components/chat/chat-session-pane.tsx +++ b/frontend/src/components/chat/chat-session-pane.tsx @@ -153,13 +153,21 @@ export function ChatSessionPane({ () => (chat?.messages || []).map(toUIMessage), [chat?.messages] ) - const { sendMessage, messages, status, regenerate, lastError, clearError } = - useVercelChat({ - chatId: chat.id, - workspaceId, - messages: uiMessages, - modelInfo, - }) + + const { + sendMessage, + messages, + status, + regenerate, + lastError, + clearError, + interrupt, + } = useVercelChat({ + chatId: chat.id, + workspaceId, + messages: uiMessages, + modelInfo, + }) // Track whether we've sent the pending message to avoid double-sends const pendingMessageSentRef = useRef(false) @@ -451,8 +459,9 @@ export function ChatSessionPane({ )} diff --git a/frontend/src/components/copilot/copilot-chat-pane.tsx b/frontend/src/components/copilot/copilot-chat-pane.tsx index 6999416106..a63b6011a1 100644 --- a/frontend/src/components/copilot/copilot-chat-pane.tsx +++ b/frontend/src/components/copilot/copilot-chat-pane.tsx @@ -89,13 +89,21 @@ export function CopilotChatPane({ () => (chat?.messages || []).map(toUIMessage), [chat?.messages] ) - const { sendMessage, messages, status, regenerate, lastError, clearError } = - useVercelChat({ - chatId: chat.id, - workspaceId, - messages: uiMessages, - modelInfo, - }) + + const { + sendMessage, + messages, + status, + regenerate, + lastError, + clearError, + interrupt, + } = useVercelChat({ + chatId: chat.id, + workspaceId, + messages: uiMessages, + modelInfo, + }) const isWaitingForResponse = useMemo(() => { if (status === "submitted") return true @@ -248,6 +256,7 @@ export function CopilotChatPane({ diff --git a/frontend/src/hooks/use-chat.ts b/frontend/src/hooks/use-chat.ts index 538b6840f2..2e2cc156fd 100644 --- a/frontend/src/hooks/use-chat.ts +++ b/frontend/src/hooks/use-chat.ts @@ -13,6 +13,7 @@ import { type AgentSessionRead, type AgentSessionsGetSessionResponse, type AgentSessionsGetSessionVercelResponse, + type AgentSessionsInterruptSessionResponse, type AgentSessionsListSessionsResponse, type AgentSessionUpdate, type ApiError, @@ -20,6 +21,7 @@ import { agentSessionsDeleteSession, agentSessionsGetSession, agentSessionsGetSessionVercel, + agentSessionsInterruptSession, agentSessionsListSessions, agentSessionsUpdateSession, type ContinueRunRequest, @@ -273,6 +275,7 @@ export function useVercelChat({ }) { const queryClient = useQueryClient() const [lastError, setLastError] = useState(null) + const [isInterrupting, setIsInterrupting] = useState(false) // Build the Vercel streaming endpoint URL const apiEndpoint = useMemo(() => { @@ -334,10 +337,98 @@ export function useVercelChat({ }, }) + // Interrupt handler for stopping a running session + // Uses stop() to abort the client-side stream, then calls backend to kill the process. + const interrupt = async () => { + if (!chatId || isInterrupting) return + setIsInterrupting(true) + try { + // 1. Abort the client-side stream immediately + chat.stop() + + // 2. Tell backend to kill the process + const result = await agentSessionsInterruptSession({ + sessionId: chatId, + workspaceId, + }) + if (result.interrupted) { + toast({ + title: "Session interrupted", + description: "The agent has been stopped.", + }) + } + + // 3. Invalidate queries to refresh session state + queryClient.invalidateQueries({ + queryKey: ["chat", chatId, workspaceId], + }) + queryClient.invalidateQueries({ + queryKey: ["chat", chatId, workspaceId, "vercel"], + }) + } catch (error) { + console.error("Failed to interrupt session:", error) + toast({ + variant: "destructive", + title: "Failed to stop", + description: "Could not interrupt the session. Please try again.", + }) + } finally { + setIsInterrupting(false) + } + } + return { ...chat, lastError, clearError: () => setLastError(null), + interrupt, + isInterrupting, + } +} + +// Hook for interrupting a running chat session +export function useInterruptChat(workspaceId: string) { + const queryClient = useQueryClient() + + const mutation = useMutation< + AgentSessionsInterruptSessionResponse, + ApiError, + { chatId: string } + >({ + mutationFn: ({ chatId }) => + agentSessionsInterruptSession({ + sessionId: chatId, + workspaceId, + }), + onSuccess: (data, variables) => { + if (data.interrupted) { + toast({ + title: "Session interrupted", + description: "The agent has been stopped.", + }) + } + // Invalidate queries to refresh session state + queryClient.invalidateQueries({ + queryKey: ["chat", variables.chatId, workspaceId], + }) + queryClient.invalidateQueries({ + queryKey: ["chat", variables.chatId, workspaceId, "vercel"], + }) + }, + onError: (error) => { + console.error("Failed to interrupt session:", error) + toast({ + variant: "destructive", + title: "Failed to stop", + description: "Could not interrupt the session. Please try again.", + }) + }, + }) + + return { + interruptChat: mutation.mutateAsync, + isInterrupting: mutation.isPending, + interruptError: mutation.error, } } diff --git a/frontend/src/lib/agents.ts b/frontend/src/lib/agents.ts index 602d4dfc6b..68ee9ed0cf 100644 --- a/frontend/src/lib/agents.ts +++ b/frontend/src/lib/agents.ts @@ -28,7 +28,8 @@ export type AgentSessionReadWithMeta = SessionBase & { action_ref?: string | null action_title?: string | null approvals?: ApprovalRead[] | null - status?: WorkflowExecutionStatus | null + /** Temporal workflow execution status (numeric enum) */ + workflow_status?: WorkflowExecutionStatus | null parent_id?: string | null parent_run_id?: string | null root_id?: string | null @@ -158,8 +159,8 @@ export function enrichAgentSession( } const temporalStatus = - session.status != null - ? (TEMPORAL_STATUS_MAP[session.status] ?? null) + session.workflow_status != null + ? (TEMPORAL_STATUS_MAP[session.workflow_status] ?? null) : null const derivedStatus: AgentDerivedStatus = temporalStatus ?? "UNKNOWN" const metadata = STATUS_METADATA[derivedStatus] diff --git a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py index 6fb389bdc8..f5674d6937 100644 --- a/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py +++ b/packages/tracecat-ee/tracecat_ee/agent/workflows/durable.py @@ -29,10 +29,12 @@ from tracecat.agent.session.activities import ( CreateSessionInput, LoadSessionInput, + UpdateSessionStatusInput, create_session_activity, load_session_activity, + update_session_status_activity, ) - from tracecat.agent.session.types import AgentSessionEntity + from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.agent.tokens import ( InternalToolContext, mint_llm_token, @@ -332,11 +334,48 @@ async def _run_with_nsjail( ) if not result.success: + # Update status to failed + await workflow.execute_activity( + update_session_status_activity, + UpdateSessionStatusInput( + role=self.role, + session_id=self.session_id, + status=AgentSessionStatus.FAILED, + ), + start_to_close_timeout=timedelta(seconds=30), + retry_policy=RETRY_POLICIES["activity:fail_fast"], + ) raise ApplicationError( f"Agent execution failed: {result.error}", non_retryable=True, ) + # Check if execution was interrupted + if result.interrupted: + logger.info( + "Agent execution interrupted by user", + session_id=self.session_id, + ) + # Update status to idle (interrupt complete) + await workflow.execute_activity( + update_session_status_activity, + UpdateSessionStatusInput( + role=self.role, + session_id=self.session_id, + status=AgentSessionStatus.IDLE, + ), + start_to_close_timeout=timedelta(seconds=30), + retry_policy=RETRY_POLICIES["activity:fail_fast"], + ) + return AgentOutput( + output=None, + message_history=result.messages, + duration=(datetime.now(UTC) - info.start_time).total_seconds(), + usage=RunUsage(), + session_id=self.session_id, + interrupted=True, + ) + if result.approval_requested: logger.info("Agent waiting for approval", session_id=self.session_id) # Convert ToolCallContent to ToolCallPart for ApprovalManager @@ -422,7 +461,17 @@ async def _run_with_nsjail( self._turn += 1 continue - # Agent completed successfully + # Agent completed successfully - update status to completed + await workflow.execute_activity( + update_session_status_activity, + UpdateSessionStatusInput( + role=self.role, + session_id=self.session_id, + status=AgentSessionStatus.COMPLETED, + ), + start_to_close_timeout=timedelta(seconds=30), + retry_policy=RETRY_POLICIES["activity:fail_fast"], + ) return AgentOutput( output=None, # NSJail path doesn't return structured output yet message_history=result.messages, # Messages fetched from DB by activity diff --git a/tracecat/agent/executor/activity.py b/tracecat/agent/executor/activity.py index c6890555e0..a37f2a0cca 100644 --- a/tracecat/agent/executor/activity.py +++ b/tracecat/agent/executor/activity.py @@ -97,6 +97,7 @@ class AgentExecutorResult(BaseModel): approval_requested: bool = False approval_items: list[ToolCallContent] | None = None messages: list[ChatMessage] | None = None + interrupted: bool = False @dataclass @@ -368,11 +369,23 @@ async def wait_process_exit() -> tuple[int, str]: "Loopback result received", success=loopback_result.success, error=loopback_result.error, + interrupted=loopback_result.interrupted, ) result.success = loopback_result.success result.error = loopback_result.error result.approval_requested = loopback_result.approval_requested result.approval_items = loopback_result.approval_items or None + result.interrupted = loopback_result.interrupted + + # If interrupted, kill the process immediately + if loopback_result.interrupted: + logger.info( + "Killing runtime process due to interrupt", + session_id=self.input.session_id, + ) + if self._process and self._process.returncode is None: + self._process.kill() + break else: # Exceeded total timeout diff --git a/tracecat/agent/executor/loopback.py b/tracecat/agent/executor/loopback.py index ee12fa6461..4d015593ec 100644 --- a/tracecat/agent/executor/loopback.py +++ b/tracecat/agent/executor/loopback.py @@ -33,12 +33,15 @@ MCPToolDefinition, SandboxAgentConfig, ) +from tracecat.agent.session.types import AgentSessionStatus from tracecat.agent.stream.connector import AgentStream from tracecat.agent.types import AgentConfig from tracecat.db.engine import get_async_session_context_manager from tracecat.db.models import AgentSession, AgentSessionHistory from tracecat.logger import logger +INTERRUPT_POLL_INTERVAL_SEC = 0.3 + @dataclass(kw_only=True, slots=True) class LoopbackInput: @@ -77,6 +80,7 @@ class LoopbackResult: error: str | None = None approval_requested: bool = False approval_items: list[ToolCallContent] = field(default_factory=list) + interrupted: bool = False class LoopbackHandler: @@ -115,6 +119,23 @@ async def _emit_stream_done(self) -> None: except Exception as e: logger.warning("Failed to emit stream done", error=str(e)) + async def _check_interrupt(self) -> bool: + """Check if the session has been marked for interrupt. + + Queries the database to check if the session status is INTERRUPTED. + + Returns: + True if the session should be interrupted, False otherwise. + """ + async with get_async_session_context_manager() as db_session: + stmt = select(AgentSession.status).where( + AgentSession.id == self.input.session_id, + AgentSession.workspace_id == self.input.workspace_id, + ) + result = await db_session.execute(stmt) + status = result.scalar_one_or_none() + return status == AgentSessionStatus.INTERRUPTED + async def handle_connection( self, reader: asyncio.StreamReader, @@ -210,16 +231,31 @@ async def _process_runtime_events(self, reader: asyncio.StreamReader) -> None: """Read and process events from the runtime. Forwards streaming events to Redis, persists complete messages to DB, - and handles session updates. + and handles session updates. Also polls for interrupt requests. """ if self._stream is None: raise RuntimeError("Stream not initialized") while True: + # Use wait_for with timeout to allow interrupt polling between reads try: - _msg_type, payload_bytes = await read_message( - reader, expected_type=MessageType.EVENT + _, payload_bytes = await asyncio.wait_for( + read_message(reader, expected_type=MessageType.EVENT), + timeout=INTERRUPT_POLL_INTERVAL_SEC, ) + except TimeoutError: + # No message received within timeout - check for interrupt + if await self._check_interrupt(): + logger.info( + "Interrupt detected, stopping agent execution", + session_id=self.input.session_id, + ) + self._result.interrupted = True + self._result.success = True # Interrupt is a clean termination + await self._emit_stream_done() + return + # No interrupt - continue waiting for next message + continue except asyncio.IncompleteReadError: # Connection closed unexpectedly - treat as error, not silent break logger.warning( diff --git a/tracecat/agent/schemas.py b/tracecat/agent/schemas.py index 7040cd1b9b..4464875bf8 100644 --- a/tracecat/agent/schemas.py +++ b/tracecat/agent/schemas.py @@ -172,6 +172,7 @@ class AgentOutput(BaseModel): duration: float usage: RunUsage | None = None session_id: uuid.UUID + interrupted: bool = False class ExecuteToolCallArgs(BaseModel): diff --git a/tracecat/agent/session/activities.py b/tracecat/agent/session/activities.py index 79407585e0..0b09337ec1 100644 --- a/tracecat/agent/session/activities.py +++ b/tracecat/agent/session/activities.py @@ -15,7 +15,7 @@ from tracecat.agent.common.stream_types import HarnessType from tracecat.agent.session.schemas import AgentSessionCreate from tracecat.agent.session.service import AgentSessionService -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.auth.types import Role from tracecat.contexts import ctx_role from tracecat.logger import logger @@ -158,9 +158,43 @@ async def load_session_activity(input: LoadSessionInput) -> LoadSessionResult: return LoadSessionResult(found=False, error=str(e)) +class UpdateSessionStatusInput(BaseModel): + """Input for update_session_status_activity.""" + + role: Role + session_id: uuid.UUID + status: AgentSessionStatus + + +@activity.defn +async def update_session_status_activity(input: UpdateSessionStatusInput) -> None: + """Update the status of an agent session. + + Called by the workflow to update session status on completion, failure, or interrupt. + """ + ctx_role.set(input.role) + + try: + async with AgentSessionService.with_session(role=input.role) as service: + await service.update_session_status(input.session_id, input.status) + logger.info( + "Updated session status", + session_id=input.session_id, + status=input.status, + ) + except Exception as e: + logger.error( + "Failed to update session status", + session_id=input.session_id, + status=input.status, + error=str(e), + ) + + def get_session_activities() -> list: """Get all session-related activities for worker registration.""" return [ create_session_activity, load_session_activity, + update_session_status_activity, ] diff --git a/tracecat/agent/session/router.py b/tracecat/agent/session/router.py index 8ec7b47105..99210aeb4e 100644 --- a/tracecat/agent/session/router.py +++ b/tracecat/agent/session/router.py @@ -328,6 +328,9 @@ async def send_message( ) # Create stream and return with Vercel format + # The stream will automatically close when: + # 1. Client disconnects (is_disconnected returns True) + # 2. StreamEnd marker is read from Redis (emitted by interrupt or normal completion) stream = await AgentStream.new(agent_session.id, workspace_id) return StreamingResponse( stream.sse(http_request.is_disconnected, last_id=start_id, format="vercel"), @@ -387,8 +390,8 @@ async def stream_session_events( detail="Workspace access required", ) - # Try to get last_stream_id from session, but don't fail if session doesn't exist yet. - # This handles the race condition where frontend connects before session is created. + # Try to get last_stream_id from session, but don't fail if session doesn't exist yet. + # This handles the race condition where frontend connects before session is created. last_stream_id: str | None = None async with AgentSessionService.with_session(role=role) as svc: agent_session = await svc.get_session(session_id) @@ -404,6 +407,9 @@ async def stream_session_events( session_id=session_id, ) + # Stream will automatically close when: + # 1. Client disconnects (is_disconnected returns True) + # 2. StreamEnd marker is read from Redis (emitted by interrupt or normal completion) stream = await AgentStream.new(session_id, workspace_id) headers = { "Cache-Control": "no-cache, no-transform", @@ -445,3 +451,30 @@ async def fork_session( status_code=status.HTTP_404_NOT_FOUND, detail=str(e), ) from e + + +@router.post("/{session_id}/interrupt") +async def interrupt_session( + session_id: uuid.UUID, + role: WorkspaceUser, + session: AsyncDBSession, +) -> dict[str, bool]: + """Request interruption of a running agent session. + + Marks the session for interrupt. The agent executor will detect this + status change and terminate execution cleanly, emitting stream.done() + to prevent the frontend from hanging. + + Returns: + {"interrupted": true} if the session was running and is now interrupted, + {"interrupted": false} if the session was not in a running state. + """ + try: + svc = AgentSessionService(session, role) + interrupted = await svc.interrupt_session(session_id) + return {"interrupted": interrupted} + except TracecatNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e diff --git a/tracecat/agent/session/schemas.py b/tracecat/agent/session/schemas.py index 8fbae446b5..a66a695813 100644 --- a/tracecat/agent/session/schemas.py +++ b/tracecat/agent/session/schemas.py @@ -10,7 +10,7 @@ from tracecat.agent.adapter.vercel import UIMessage from tracecat.agent.common.stream_types import HarnessType -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus class AgentSessionCreate(BaseModel): @@ -101,6 +101,8 @@ class AgentSessionRead(BaseModel): agent_preset_id: uuid.UUID | None # Harness harness_type: str | None + # Status + status: AgentSessionStatus = AgentSessionStatus.IDLE # Stream tracking last_stream_id: str | None = None # Fork tracking diff --git a/tracecat/agent/session/service.py b/tracecat/agent/session/service.py index 660564a166..9cac57da04 100644 --- a/tracecat/agent/session/service.py +++ b/tracecat/agent/session/service.py @@ -29,7 +29,7 @@ AgentSessionRead, AgentSessionUpdate, ) -from tracecat.agent.session.types import AgentSessionEntity +from tracecat.agent.session.types import AgentSessionEntity, AgentSessionStatus from tracecat.agent.types import AgentConfig, ClaudeSDKMessageTA from tracecat.audit.logger import audit_log from tracecat.cases.prompts import CaseCopilotPrompts @@ -338,6 +338,74 @@ async def update_last_stream_id( await self.session.refresh(agent_session) return agent_session + async def interrupt_session( + self, + session_id: uuid.UUID, + ) -> bool: + """Request interruption of a running agent session. + + Sets the session status to INTERRUPTED. The loopback handler polls + for this status and terminates execution when detected. + + Args: + session_id: The session UUID to interrupt. + + Returns: + True if the session was running and is now marked for interrupt, + False if the session was not in a running state. + """ + agent_session = await self.get_session(session_id) + if not agent_session: + raise TracecatNotFoundError(f"Session with ID {session_id} not found") + + # Only interrupt if session is currently running + if agent_session.status != AgentSessionStatus.RUNNING: + logger.info( + "Session not in running state, cannot interrupt", + session_id=session_id, + current_status=agent_session.status, + ) + return False + + agent_session.status = AgentSessionStatus.INTERRUPTED + self.session.add(agent_session) + await self.session.commit() + + logger.info( + "Session marked for interrupt", + session_id=session_id, + ) + return True + + async def update_session_status( + self, + session_id: uuid.UUID, + status: AgentSessionStatus, + ) -> None: + """Update the status of an agent session. + + Args: + session_id: The session UUID. + status: The new status to set. + """ + agent_session = await self.get_session(session_id) + if not agent_session: + logger.warning( + "Cannot update status for non-existent session", + session_id=session_id, + ) + return + + agent_session.status = status + self.session.add(agent_session) + await self.session.commit() + + logger.debug( + "Session status updated", + session_id=session_id, + status=status, + ) + # ========================================================================= # Session History Management (for Claude SDK session persistence) # ========================================================================= @@ -553,8 +621,9 @@ async def run_turn( agent_preset_id=agent_session.agent_preset_id, ) - # Update session with current run_id for approval lookups + # Update session with current run_id and set status to running agent_session.curr_run_id = run_id + agent_session.status = AgentSessionStatus.RUNNING self.session.add(agent_session) await self.session.commit() diff --git a/tracecat/agent/session/types.py b/tracecat/agent/session/types.py index d0a571c053..186cddb110 100644 --- a/tracecat/agent/session/types.py +++ b/tracecat/agent/session/types.py @@ -3,6 +3,24 @@ from enum import StrEnum +class AgentSessionStatus(StrEnum): + """Status of an agent session. + + Tracks the lifecycle state of an agent session: + - IDLE: No active workflow running + - RUNNING: Workflow currently executing + - INTERRUPTED: User requested interrupt (transient state) + - COMPLETED: Last run completed successfully + - FAILED: Last run failed + """ + + IDLE = "idle" + RUNNING = "running" + INTERRUPTED = "interrupted" + COMPLETED = "completed" + FAILED = "failed" + + class AgentSessionEntity(StrEnum): """The type of entity associated with an agent session. diff --git a/tracecat/db/models.py b/tracecat/db/models.py index f3c19151a4..28c3790e8a 100644 --- a/tracecat/db/models.py +++ b/tracecat/db/models.py @@ -2030,6 +2030,14 @@ class AgentSession(WorkspaceModel): index=True, doc="Current workflow run ID - used to construct workflow handle for approvals", ) + # Session status tracking (for interrupts and lifecycle) + status: Mapped[str] = mapped_column( + String(20), + default="idle", + nullable=False, + index=True, + doc="Session status: idle, running, interrupted, completed, failed", + ) # Stream position tracking (for resuming from last event) last_stream_id: Mapped[str | None] = mapped_column( String(128),