|
2 | 2 | from collections.abc import AsyncGenerator
|
3 | 3 | from typing import Optional, Union
|
4 | 4 |
|
5 |
| -from agents import Agent, ModelSettings, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled |
6 |
| -from openai import AsyncAzureOpenAI, AsyncOpenAI |
7 |
| -from openai.types.chat import ( |
8 |
| - ChatCompletionMessageParam, |
9 |
| -) |
10 |
| -from openai.types.responses import ( |
11 |
| - EasyInputMessageParam, |
12 |
| - ResponseFunctionToolCallParam, |
13 |
| - ResponseTextDeltaEvent, |
| 5 | +from agents import ( |
| 6 | + Agent, |
| 7 | + ModelSettings, |
| 8 | + OpenAIChatCompletionsModel, |
| 9 | + Runner, |
| 10 | + ToolCallOutputItem, |
| 11 | + function_tool, |
| 12 | + set_tracing_disabled, |
14 | 13 | )
|
15 |
| -from openai.types.responses.response_input_item_param import FunctionCallOutput |
| 14 | +from openai import AsyncAzureOpenAI, AsyncOpenAI |
| 15 | +from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam, ResponseTextDeltaEvent |
16 | 16 |
|
17 | 17 | from fastapi_app.api_models import (
|
18 | 18 | AIChatRoles,
|
@@ -41,7 +41,7 @@ class AdvancedRAGChat(RAGChatBase):
|
41 | 41 | def __init__(
|
42 | 42 | self,
|
43 | 43 | *,
|
44 |
| - messages: list[ChatCompletionMessageParam], |
| 44 | + messages: list[ResponseInputItemParam], |
45 | 45 | overrides: ChatRequestOverrides,
|
46 | 46 | searcher: PostgresSearcher,
|
47 | 47 | openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
@@ -109,34 +109,17 @@ async def search_database(
|
109 | 109 | )
|
110 | 110 |
|
111 | 111 | async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
|
112 |
| - few_shots = json.loads(self.query_fewshots) |
113 |
| - few_shot_inputs = [] |
114 |
| - for few_shot in few_shots: |
115 |
| - if few_shot["role"] == "user": |
116 |
| - message = EasyInputMessageParam(role="user", content=few_shot["content"]) |
117 |
| - elif few_shot["role"] == "assistant" and few_shot["tool_calls"] is not None: |
118 |
| - message = ResponseFunctionToolCallParam( |
119 |
| - id="madeup", |
120 |
| - call_id=few_shot["tool_calls"][0]["id"], |
121 |
| - name=few_shot["tool_calls"][0]["function"]["name"], |
122 |
| - arguments=few_shot["tool_calls"][0]["function"]["arguments"], |
123 |
| - type="function_call", |
124 |
| - ) |
125 |
| - elif few_shot["role"] == "tool" and few_shot["tool_call_id"] is not None: |
126 |
| - message = FunctionCallOutput( |
127 |
| - id="madeupoutput", |
128 |
| - call_id=few_shot["tool_call_id"], |
129 |
| - output=few_shot["content"], |
130 |
| - type="function_call_output", |
131 |
| - ) |
132 |
| - few_shot_inputs.append(message) |
133 |
| - |
| 112 | + few_shots: list[ResponseInputItemParam] = json.loads(self.query_fewshots) |
134 | 113 | user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
|
135 | 114 | new_user_message = EasyInputMessageParam(role="user", content=user_query)
|
136 |
| - all_messages = few_shot_inputs + self.chat_params.past_messages + [new_user_message] |
| 115 | + all_messages = few_shots + self.chat_params.past_messages + [new_user_message] |
137 | 116 |
|
138 | 117 | run_results = await Runner.run(self.search_agent, input=all_messages)
|
139 |
| - search_results = run_results.new_items[-1].output |
| 118 | + most_recent_response = run_results.new_items[-1] |
| 119 | + if isinstance(most_recent_response, ToolCallOutputItem): |
| 120 | + search_results = most_recent_response.output |
| 121 | + else: |
| 122 | + raise ValueError("Error retrieving search results, model did not call tool properly") |
140 | 123 |
|
141 | 124 | thoughts = [
|
142 | 125 | ThoughtStep(
|
|
0 commit comments