Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 203 additions & 8 deletions chatkit/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from openai.types.responses import (
EasyInputMessageParam,
ResponseFunctionToolCallParam,
ResponseFunctionWebSearch,
ResponseInputContentParam,
ResponseInputMessageContentListParam,
ResponseInputTextParam,
Expand Down Expand Up @@ -55,6 +56,7 @@
EndOfTurnItem,
FileSource,
HiddenContextItem,
SearchTask,
Task,
TaskItem,
ThoughtTask,
Expand Down Expand Up @@ -340,6 +342,11 @@ class StreamingThoughtTracker(BaseModel):
task: ThoughtTask


class SearchTaskTracker(BaseModel):
item_id: str
task: SearchTask


async def stream_agent_response(
context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
Expand All @@ -350,6 +357,7 @@ async def stream_agent_response(
queue_iterator = _AsyncQueueIterator(context._events)
produced_items = set()
streaming_thought: None | StreamingThoughtTracker = None
search_tasks: dict[str, SearchTaskTracker] = {}

# check if the last item in the thread was a workflow or a client tool call
# if it was a client tool call, check if the second last item was a workflow
Expand All @@ -371,6 +379,7 @@ async def stream_agent_response(
ctx.workflow_item = second_last_item

def end_workflow(item: WorkflowItem):
nonlocal search_tasks
if item == ctx.workflow_item:
ctx.workflow_item = None
delta = datetime.now() - item.created_at
Expand All @@ -381,8 +390,165 @@ def end_workflow(item: WorkflowItem):
# To keep a workflow open on completion, close it explicitly with
# AgentContext.end_workflow(expanded=True)
item.workflow.expanded = False
search_tasks.clear()
return ThreadItemDoneEvent(item=item)

def ensure_workflow() -> list[ThreadStreamEvent]:
events: list[ThreadStreamEvent] = []
if not ctx.workflow_item:
ctx.workflow_item = WorkflowItem(
id=ctx.generate_id("workflow"),
created_at=datetime.now(),
workflow=Workflow(type="reasoning", tasks=[]),
thread_id=thread.id,
)
produced_items.add(ctx.workflow_item.id)
events.append(ThreadItemAddedEvent(item=ctx.workflow_item))
return events

def ensure_search_task(
item_id: str,
) -> tuple[SearchTaskTracker, bool, list[ThreadStreamEvent]]:
events = ensure_workflow()
tracker = search_tasks.get(item_id)
if not tracker:
tracker = SearchTaskTracker(
item_id=item_id,
task=SearchTask(status_indicator="loading"),
)
search_tasks[item_id] = tracker
task_added = False
if ctx.workflow_item and tracker.task not in ctx.workflow_item.workflow.tasks:
ctx.workflow_item.workflow.tasks.append(tracker.task)
task_added = True
return tracker, task_added, events

def apply_search_task_updates(
tracker: SearchTaskTracker,
*,
status: str | None = None,
call: ResponseFunctionWebSearch | None = None,
) -> bool:
updated = False
task = tracker.task
if status is not None and task.status_indicator != status:
task.status_indicator = status
updated = True
if call is None:
return updated

action = call.action
action_type = getattr(action, "type", None)
if action_type == "search":
query = getattr(action, "query", None)
if query:
if task.title != query:
task.title = query
updated = True
if task.title_query != query:
task.title_query = query
updated = True
if query not in task.queries:
task.queries.append(query)
updated = True
sources = getattr(action, "sources", None) or []
if sources:
existing_urls = {source.url for source in task.sources}
new_sources = []
for source in sources:
if source.url not in existing_urls:
new_sources.append(
URLSource(
title=source.url,
url=source.url,
)
)
existing_urls.add(source.url)
if new_sources:
task.sources.extend(new_sources)
updated = True
elif action_type in {"open_page", "find"}:
url = getattr(action, "url", None)
if url:
if task.title is None:
task.title = url
updated = True
existing_urls = {source.url for source in task.sources}
if url not in existing_urls:
task.sources.append(URLSource(title=url, url=url))
updated = True
return updated

def search_status_from_call(call: ResponseFunctionWebSearch) -> str:
if call.status == "completed":
return "complete"
if call.status in {"in_progress", "searching"}:
return "loading"
return "none"

def upsert_search_task(
call: ResponseFunctionWebSearch, *, status: str | None = None
) -> list[ThreadStreamEvent]:
tracker, task_added, events = ensure_search_task(call.id)
effective_status = status or search_status_from_call(call)
updated = apply_search_task_updates(
tracker,
status=effective_status,
call=call,
)
if ctx.workflow_item:
task_index = ctx.workflow_item.workflow.tasks.index(tracker.task)
if task_added:
events.append(
ThreadItemUpdated(
item_id=ctx.workflow_item.id,
update=WorkflowTaskAdded(
task=tracker.task,
task_index=task_index,
),
)
)
elif updated:
events.append(
ThreadItemUpdated(
item_id=ctx.workflow_item.id,
update=WorkflowTaskUpdated(
task=tracker.task,
task_index=task_index,
),
)
)
return events

def update_search_task_status(
item_id: str, status: str
) -> list[ThreadStreamEvent]:
tracker, task_added, events = ensure_search_task(item_id)
updated = apply_search_task_updates(tracker, status=status)
if ctx.workflow_item:
task_index = ctx.workflow_item.workflow.tasks.index(tracker.task)
if task_added:
events.append(
ThreadItemUpdated(
item_id=ctx.workflow_item.id,
update=WorkflowTaskAdded(
task=tracker.task,
task_index=task_index,
),
)
)
elif updated:
events.append(
ThreadItemUpdated(
item_id=ctx.workflow_item.id,
update=WorkflowTaskUpdated(
task=tracker.task,
task_index=task_index,
),
)
)
return events

try:
async for event in _merge_generators(result.stream_events(), queue_iterator):
# Events emitted from agent context helpers
Expand All @@ -407,6 +573,7 @@ def end_workflow(item: WorkflowItem):
and event.item.type == "workflow"
):
ctx.workflow_item = event.item
search_tasks.clear()

# track integration produced items so we can clean them up if
# there is a guardrail tripwire
Expand All @@ -424,6 +591,15 @@ def end_workflow(item: WorkflowItem):
current_item_id = event.raw_item.id
assert current_item_id
produced_items.add(current_item_id)
elif (
event.type == "tool_call_item"
and event.raw_item.type == "web_search_call"
):
for search_event in upsert_search_task(
cast(ResponseFunctionWebSearch, event.raw_item),
status="loading",
):
yield search_event
continue

if event.type != "raw_response_event":
Expand Down Expand Up @@ -468,14 +644,8 @@ def end_workflow(item: WorkflowItem):
elif event.type == "response.output_item.added":
item = event.item
if item.type == "reasoning" and not ctx.workflow_item:
ctx.workflow_item = WorkflowItem(
id=ctx.generate_id("workflow"),
created_at=datetime.now(),
workflow=Workflow(type="reasoning", tasks=[]),
thread_id=thread.id,
)
produced_items.add(ctx.workflow_item.id)
yield ThreadItemAddedEvent(item=ctx.workflow_item)
for workflow_event in ensure_workflow():
yield workflow_event
if item.type == "message":
if ctx.workflow_item:
yield end_workflow(ctx.workflow_item)
Expand All @@ -489,6 +659,11 @@ def end_workflow(item: WorkflowItem):
created_at=datetime.now(),
),
)
elif item.type == "web_search_call":
for search_event in upsert_search_task(
cast(ResponseFunctionWebSearch, item)
):
yield search_event
elif event.type == "response.reasoning_summary_text.delta":
if not ctx.workflow_item:
continue
Expand Down Expand Up @@ -566,6 +741,26 @@ def end_workflow(item: WorkflowItem):
created_at=datetime.now(),
),
)
elif item.type == "web_search_call":
for search_event in upsert_search_task(
cast(ResponseFunctionWebSearch, item), status="complete"
):
yield search_event
elif event.type == "response.web_search_call.in_progress":
for search_event in update_search_task_status(
event.item_id, "loading"
):
yield search_event
elif event.type == "response.web_search_call.searching":
for search_event in update_search_task_status(
event.item_id, "loading"
):
yield search_event
elif event.type == "response.web_search_call.completed":
for search_event in update_search_task_status(
event.item_id, "complete"
):
yield search_event

except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
for item_id in produced_items:
Expand Down
Loading