1
1
from collections .abc import AsyncGenerator
2
- from typing import Optional , TypedDict , Union
2
+ from typing import Optional , Union
3
3
4
4
from openai import AsyncAzureOpenAI , AsyncOpenAI
5
5
from openai .types .chat import ChatCompletionMessageParam
11
11
12
12
from fastapi_app .api_models import (
13
13
AIChatRoles ,
14
+ BrandFilter ,
14
15
ChatRequestOverrides ,
16
+ Filter ,
15
17
ItemPublic ,
16
18
Message ,
19
+ PriceFilter ,
17
20
RAGContext ,
18
21
RetrievalResponse ,
19
22
RetrievalResponseDelta ,
23
+ SearchResults ,
20
24
ThoughtStep ,
21
25
)
22
26
from fastapi_app .postgres_searcher import PostgresSearcher
23
27
from fastapi_app .rag_base import ChatParams , RAGChatBase
24
28
25
29
26
- class PriceFilter (TypedDict ):
27
- column : str = "price"
28
- """The column to filter on (always 'price' for this filter)"""
29
-
30
- comparison_operator : str
31
- """The operator for price comparison ('>', '<', '>=', '<=', '=')"""
32
-
33
- value : float
34
- """ The price value to compare against (e.g., 30.00) """
35
-
36
-
37
- class BrandFilter (TypedDict ):
38
- column : str = "brand"
39
- """The column to filter on (always 'brand' for this filter)"""
40
-
41
- comparison_operator : str
42
- """The operator for brand comparison ('=' or '!=')"""
43
-
44
- value : str
45
- """The brand name to compare against (e.g., 'AirStrider')"""
46
-
47
-
48
- class SearchResults (TypedDict ):
49
- query : str
50
- """The original search query"""
51
-
52
- items : list [ItemPublic ]
53
- """List of items that match the search query and filters"""
54
-
55
- filters : list [Union [PriceFilter , BrandFilter ]]
56
- """List of filters applied to the search results"""
57
-
58
-
59
30
class AdvancedRAGChat (RAGChatBase ):
60
31
query_prompt_template = open (RAGChatBase .prompts_dir / "query.txt" ).read ()
61
32
query_fewshots = open (RAGChatBase .prompts_dir / "query_fewshots.json" ).read ()
@@ -79,9 +50,13 @@ def __init__(
79
50
chat_model if chat_deployment is None else chat_deployment ,
80
51
provider = OpenAIProvider (openai_client = openai_chat_client ),
81
52
)
82
- self .search_agent = Agent (
53
+ self .search_agent = Agent [ ChatParams , SearchResults ] (
83
54
pydantic_chat_model ,
84
- model_settings = ModelSettings (temperature = 0.0 , max_tokens = 500 , seed = self .chat_params .seed ),
55
+ model_settings = ModelSettings (
56
+ temperature = 0.0 ,
57
+ max_tokens = 500 ,
58
+ ** ({"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {}),
59
+ ),
85
60
system_prompt = self .query_prompt_template ,
86
61
tools = [self .search_database ],
87
62
output_type = SearchResults ,
@@ -92,7 +67,7 @@ def __init__(
92
67
model_settings = ModelSettings (
93
68
temperature = self .chat_params .temperature ,
94
69
max_tokens = self .chat_params .response_token_limit ,
95
- seed = self .chat_params .seed ,
70
+ ** ({ " seed" : self .chat_params .seed } if self . chat_params . seed is not None else {}) ,
96
71
),
97
72
)
98
73
@@ -115,7 +90,7 @@ async def search_database(
115
90
List of formatted items that match the search query and filters
116
91
"""
117
92
# Only send non-None filters
118
- filters = []
93
+ filters : list [ Filter ] = []
119
94
if price_filter :
120
95
filters .append (price_filter )
121
96
if brand_filter :
@@ -134,12 +109,12 @@ async def search_database(
134
109
async def prepare_context (self ) -> tuple [list [ItemPublic ], list [ThoughtStep ]]:
135
110
few_shots = ModelMessagesTypeAdapter .validate_json (self .query_fewshots )
136
111
user_query = f"Find search results for user query: { self .chat_params .original_user_query } "
137
- results = await self .search_agent .run (
112
+ results = await self .search_agent .run ( # type: ignore[call-overload]
138
113
user_query ,
139
114
message_history = few_shots + self .chat_params .past_messages ,
140
115
deps = self .chat_params ,
141
116
)
142
- items = results .output [ " items" ]
117
+ items = results .output . items
143
118
thoughts = [
144
119
ThoughtStep (
145
120
title = "Prompt to generate search arguments" ,
@@ -148,12 +123,12 @@ async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
148
123
),
149
124
ThoughtStep (
150
125
title = "Search using generated search arguments" ,
151
- description = results .output [ " query" ] ,
126
+ description = results .output . query ,
152
127
props = {
153
128
"top" : self .chat_params .top ,
154
129
"vector_search" : self .chat_params .enable_vector_search ,
155
130
"text_search" : self .chat_params .enable_text_search ,
156
- "filters" : results .output [ " filters" ] ,
131
+ "filters" : results .output . filters ,
157
132
},
158
133
),
159
134
ThoughtStep (
0 commit comments