diff --git a/chatkit/agents.py b/chatkit/agents.py
index 38543d1..8bd49af 100644
--- a/chatkit/agents.py
+++ b/chatkit/agents.py
@@ -25,6 +25,7 @@
 from openai.types.responses import (
     EasyInputMessageParam,
     ResponseFunctionToolCallParam,
+    ResponseFunctionWebSearch,
     ResponseInputContentParam,
     ResponseInputMessageContentListParam,
     ResponseInputTextParam,
@@ -55,6 +56,7 @@
     EndOfTurnItem,
     FileSource,
     HiddenContextItem,
+    SearchTask,
     Task,
     TaskItem,
     ThoughtTask,
@@ -340,6 +342,11 @@ class StreamingThoughtTracker(BaseModel):
     task: ThoughtTask
 
 
+class SearchTaskTracker(BaseModel):
+    item_id: str
+    task: SearchTask
+
+
 async def stream_agent_response(
     context: AgentContext, result: RunResultStreaming
 ) -> AsyncIterator[ThreadStreamEvent]:
@@ -350,6 +357,7 @@ async def stream_agent_response(
     queue_iterator = _AsyncQueueIterator(context._events)
     produced_items = set()
     streaming_thought: None | StreamingThoughtTracker = None
+    search_tasks: dict[str, SearchTaskTracker] = {}
 
     # check if the last item in the thread was a workflow or a client tool call
     # if it was a client tool call, check if the second last item was a workflow
@@ -371,6 +379,7 @@
             ctx.workflow_item = second_last_item
 
     def end_workflow(item: WorkflowItem):
+        nonlocal search_tasks
         if item == ctx.workflow_item:
             ctx.workflow_item = None
             delta = datetime.now() - item.created_at
@@ -381,8 +390,165 @@ def end_workflow(item: WorkflowItem):
         # To keep a workflow open on completion, close it explicitly with
         # AgentContext.end_workflow(expanded=True)
         item.workflow.expanded = False
+        search_tasks.clear()
         return ThreadItemDoneEvent(item=item)
 
+    def ensure_workflow() -> list[ThreadStreamEvent]:
+        events: list[ThreadStreamEvent] = []
+        if not ctx.workflow_item:
+            ctx.workflow_item = WorkflowItem(
+                id=ctx.generate_id("workflow"),
+                created_at=datetime.now(),
+                workflow=Workflow(type="reasoning", tasks=[]),
+                thread_id=thread.id,
+            )
+            produced_items.add(ctx.workflow_item.id)
+            events.append(ThreadItemAddedEvent(item=ctx.workflow_item))
+        return events
+
+    def ensure_search_task(
+        item_id: str,
+    ) -> tuple[SearchTaskTracker, bool, list[ThreadStreamEvent]]:
+        events = ensure_workflow()
+        tracker = search_tasks.get(item_id)
+        if not tracker:
+            tracker = SearchTaskTracker(
+                item_id=item_id,
+                task=SearchTask(status_indicator="loading"),
+            )
+            search_tasks[item_id] = tracker
+        task_added = False
+        if ctx.workflow_item and tracker.task not in ctx.workflow_item.workflow.tasks:
+            ctx.workflow_item.workflow.tasks.append(tracker.task)
+            task_added = True
+        return tracker, task_added, events
+
+    def apply_search_task_updates(
+        tracker: SearchTaskTracker,
+        *,
+        status: str | None = None,
+        call: ResponseFunctionWebSearch | None = None,
+    ) -> bool:
+        updated = False
+        task = tracker.task
+        if status is not None and task.status_indicator != status:
+            task.status_indicator = status
+            updated = True
+        if call is None:
+            return updated
+
+        action = call.action
+        action_type = getattr(action, "type", None)
+        if action_type == "search":
+            query = getattr(action, "query", None)
+            if query:
+                if task.title != query:
+                    task.title = query
+                    updated = True
+                if task.title_query != query:
+                    task.title_query = query
+                    updated = True
+                if query not in task.queries:
+                    task.queries.append(query)
+                    updated = True
+            sources = getattr(action, "sources", None) or []
+            if sources:
+                existing_urls = {source.url for source in task.sources}
+                new_sources = []
+                for source in sources:
+                    if source.url not in existing_urls:
+                        new_sources.append(
+                            URLSource(
+                                title=source.url,
+                                url=source.url,
+                            )
+                        )
+                        existing_urls.add(source.url)
+                if new_sources:
+                    task.sources.extend(new_sources)
+                    updated = True
+        elif action_type in {"open_page", "find"}:
+            url = getattr(action, "url", None)
+            if url:
+                if task.title is None:
+                    task.title = url
+                    updated = True
+                existing_urls = {source.url for source in task.sources}
+                if url not in existing_urls:
+                    task.sources.append(URLSource(title=url, url=url))
+                    updated = True
+        return updated
+
+    def search_status_from_call(call: ResponseFunctionWebSearch) -> str:
+        if call.status == "completed":
+            return "complete"
+        if call.status in {"in_progress", "searching"}:
+            return "loading"
+        return "none"
+
+    def upsert_search_task(
+        call: ResponseFunctionWebSearch, *, status: str | None = None
+    ) -> list[ThreadStreamEvent]:
+        tracker, task_added, events = ensure_search_task(call.id)
+        effective_status = status or search_status_from_call(call)
+        updated = apply_search_task_updates(
+            tracker,
+            status=effective_status,
+            call=call,
+        )
+        if ctx.workflow_item:
+            task_index = ctx.workflow_item.workflow.tasks.index(tracker.task)
+            if task_added:
+                events.append(
+                    ThreadItemUpdated(
+                        item_id=ctx.workflow_item.id,
+                        update=WorkflowTaskAdded(
+                            task=tracker.task,
+                            task_index=task_index,
+                        ),
+                    )
+                )
+            elif updated:
+                events.append(
+                    ThreadItemUpdated(
+                        item_id=ctx.workflow_item.id,
+                        update=WorkflowTaskUpdated(
+                            task=tracker.task,
+                            task_index=task_index,
+                        ),
+                    )
+                )
+        return events
+
+    def update_search_task_status(
+        item_id: str, status: str
+    ) -> list[ThreadStreamEvent]:
+        tracker, task_added, events = ensure_search_task(item_id)
+        updated = apply_search_task_updates(tracker, status=status)
+        if ctx.workflow_item:
+            task_index = ctx.workflow_item.workflow.tasks.index(tracker.task)
+            if task_added:
+                events.append(
+                    ThreadItemUpdated(
+                        item_id=ctx.workflow_item.id,
+                        update=WorkflowTaskAdded(
+                            task=tracker.task,
+                            task_index=task_index,
+                        ),
+                    )
+                )
+            elif updated:
+                events.append(
+                    ThreadItemUpdated(
+                        item_id=ctx.workflow_item.id,
+                        update=WorkflowTaskUpdated(
+                            task=tracker.task,
+                            task_index=task_index,
+                        ),
+                    )
+                )
+        return events
+
     try:
         async for event in _merge_generators(result.stream_events(), queue_iterator):
             # Events emitted from agent context helpers
@@ -407,6 +573,7 @@ def end_workflow(item: WorkflowItem):
                 and event.item.type == "workflow"
             ):
                 ctx.workflow_item = event.item
+                search_tasks.clear()
 
             # track integration produced items so we can clean them up if
             # there is a guardrail tripwire
@@ -424,6 +591,15 @@ def end_workflow(item: WorkflowItem):
                     current_item_id = event.raw_item.id
                     assert current_item_id
                     produced_items.add(current_item_id)
+                elif (
+                    event.type == "tool_call_item"
+                    and event.raw_item.type == "web_search_call"
+                ):
+                    for search_event in upsert_search_task(
+                        cast(ResponseFunctionWebSearch, event.raw_item),
+                        status="loading",
+                    ):
+                        yield search_event
                 continue
 
             if event.type != "raw_response_event":
@@ -468,14 +644,8 @@ def end_workflow(item: WorkflowItem):
             elif event.type == "response.output_item.added":
                 item = event.item
                 if item.type == "reasoning" and not ctx.workflow_item:
-                    ctx.workflow_item = WorkflowItem(
-                        id=ctx.generate_id("workflow"),
-                        created_at=datetime.now(),
-                        workflow=Workflow(type="reasoning", tasks=[]),
-                        thread_id=thread.id,
-                    )
-                    produced_items.add(ctx.workflow_item.id)
-                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
+                    for workflow_event in ensure_workflow():
+                        yield workflow_event
                 if item.type == "message":
                     if ctx.workflow_item:
                         yield end_workflow(ctx.workflow_item)
@@ -489,6 +659,11 @@ def end_workflow(item: WorkflowItem):
                             created_at=datetime.now(),
                         ),
                     )
+                elif item.type == "web_search_call":
+                    for search_event in upsert_search_task(
+                        cast(ResponseFunctionWebSearch, item)
+                    ):
+                        yield search_event
             elif event.type == "response.reasoning_summary_text.delta":
                 if not ctx.workflow_item:
                     continue
@@ -566,6 +741,26 @@ def end_workflow(item: WorkflowItem):
                             created_at=datetime.now(),
                         ),
                     )
+                elif item.type == "web_search_call":
+                    for search_event in upsert_search_task(
+                        cast(ResponseFunctionWebSearch, item), status="complete"
+                    ):
+                        yield search_event
+            elif event.type == "response.web_search_call.in_progress":
+                for search_event in update_search_task_status(
+                    event.item_id, "loading"
+                ):
+                    yield search_event
+            elif event.type == "response.web_search_call.searching":
+                for search_event in update_search_task_status(
+                    event.item_id, "loading"
+                ):
+                    yield search_event
+            elif event.type == "response.web_search_call.completed":
+                for search_event in update_search_task_status(
+                    event.item_id, "complete"
+                ):
+                    yield search_event
 
     except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
         for item_id in produced_items:
diff --git a/tests/test_agents.py b/tests/test_agents.py
index c0fc4b0..ea19412 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -27,6 +27,7 @@
 from openai.types.responses import (
     EasyInputMessageParam,
     ResponseFileSearchToolCall,
+    ResponseFunctionWebSearch,
     ResponseInputContentParam,
     ResponseInputTextParam,
     ResponseOutputItemAddedEvent,
@@ -38,6 +39,10 @@
     ResponseContentPartAddedEvent,
 )
 from openai.types.responses.response_file_search_tool_call import Result
+from openai.types.responses.response_function_web_search import (
+    ActionSearch,
+    ActionSearchSource,
+)
 from openai.types.responses.response_output_text import (
     AnnotationFileCitation as ResponsesAnnotationFileCitation,
 )
@@ -52,6 +57,9 @@
 )
 from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
 from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
+from openai.types.responses.response_web_search_call_searching_event import (
+    ResponseWebSearchCallSearchingEvent,
+)
 
 from chatkit.agents import (
     AgentContext,
@@ -75,6 +83,7 @@
     FileSource,
     InferenceOptions,
     Page,
+    SearchTask,
     TaskItem,
     ThoughtTask,
     Thread,
@@ -1170,6 +1179,106 @@ async def test_workflow_streams_first_thought():
         pass
 
 
+async def test_stream_agent_response_tracks_web_search_tasks():
+    mock_store.add_thread_item.reset_mock()
+    context = AgentContext(
+        previous_response_id=None, thread=thread, store=mock_store, request_context=None
+    )
+    result = make_result()
+
+    call = ResponseFunctionWebSearch(
+        id="ws_1",
+        action=ActionSearch(type="search", query="latest news", sources=[]),
+        status="in_progress",
+        type="web_search_call",
+    )
+    tool_call_item = ToolCallItem(agent=Agent(name="Assistant"), raw_item=call)
+    result.add_event(
+        RunItemStreamEvent(
+            name="tool_called",
+            item=tool_call_item,
+        )
+    )
+    result.add_event(
+        RawResponsesStreamEvent(
+            type="raw_response_event",
+            data=ResponseWebSearchCallSearchingEvent(
+                item_id=call.id,
+                output_index=0,
+                sequence_number=0,
+                type="response.web_search_call.searching",
+            ),
+        )
+    )
+    completed_call = ResponseFunctionWebSearch(
+        id=call.id,
+        action=ActionSearch(
+            type="search",
+            query="latest news",
+            sources=[ActionSearchSource(type="url", url="https://example.com")],
+        ),
+        status="completed",
+        type="web_search_call",
+    )
+    result.add_event(
+        RawResponsesStreamEvent(
+            type="raw_response_event",
+            data=ResponseOutputItemDoneEvent(
+                type="response.output_item.done",
+                item=completed_call,
+                output_index=0,
+                sequence_number=1,
+            ),
+        )
+    )
+
+    result.done()
+
+    events = await all_events(stream_agent_response(context, result))
+
+    workflow_added = next(
+        (
+            event
+            for event in events
+            if isinstance(event, ThreadItemAddedEvent)
+            and event.item.type == "workflow"
+        ),
+        None,
+    )
+    assert workflow_added is not None
+
+    search_task_added = next(
+        (
+            event
+            for event in events
+            if isinstance(event, ThreadItemUpdated)
+            and isinstance(event.update, WorkflowTaskAdded)
+            and isinstance(event.update.task, SearchTask)
+        ),
+        None,
+    )
+    assert search_task_added is not None
+    assert search_task_added.update.task.queries == ["latest news"]
+
+    search_task_completed = next(
+        (
+            event
+            for event in events
+            if isinstance(event, ThreadItemUpdated)
+            and isinstance(event.update, WorkflowTaskUpdated)
+            and isinstance(event.update.task, SearchTask)
+            and event.update.task.status_indicator == "complete"
+        ),
+        None,
+    )
+    assert search_task_completed is not None
+    assert any(
+        isinstance(source, URLSource) and source.url == "https://example.com"
+        for source in search_task_completed.update.task.sources
+    )
+    assert mock_store.add_thread_item.await_count == 1
+
+
 async def test_workflow_ends_on_message():
     context = AgentContext(
         previous_response_id=None, thread=thread, store=mock_store, request_context=None