1
1
from enum import Enum
2
- from typing import Any , Optional
2
+ from typing import Any , Optional , Union
3
3
4
4
from openai .types .chat import ChatCompletionMessageParam
5
- from pydantic import BaseModel
5
+ from pydantic import BaseModel , Field
6
+ from pydantic_ai .messages import ModelRequest , ModelResponse
6
7
7
8
8
9
class AIChatRoles (str , Enum ):
@@ -40,6 +41,30 @@ class ChatRequest(BaseModel):
40
41
context : ChatRequestContext
41
42
sessionState : Optional [Any ] = None
42
43
44
+
45
+ class ItemPublic (BaseModel ):
46
+ id : int
47
+ name : str
48
+ location : str
49
+ cuisine : str
50
+ rating : int
51
+ price_level : int
52
+ review_count : int
53
+ hours : int
54
+ tags : str
55
+ description : str
56
+ menu_summary : str
57
+ top_reviews : str
58
+ vibe : str
59
+
60
+
61
+ class ItemWithDistance (ItemPublic ):
62
+ distance : float
63
+
64
+ def __init__ (self , ** data ):
65
+ super ().__init__ (** data )
66
+ self .distance = round (self .distance , 2 )
67
+
43
68
44
69
class ThoughtStep (BaseModel ):
45
70
title : str
@@ -48,7 +73,7 @@ class ThoughtStep(BaseModel):
48
73
49
74
50
75
class RAGContext (BaseModel ):
51
- data_points : dict [int , dict [ str , Any ] ]
76
+ data_points : dict [int , ItemPublic ]
52
77
thoughts : list [ThoughtStep ]
53
78
followup_questions : Optional [list [str ]] = None
54
79
@@ -69,34 +94,39 @@ class RetrievalResponseDelta(BaseModel):
69
94
sessionState : Optional [Any ] = None
70
95
71
96
72
- class ItemPublic (BaseModel ):
73
- id : int
74
- name : str
75
- location : str
76
- cuisine : str
77
- rating : int
78
- price_level : int
79
- review_count : int
80
- hours : int
81
- tags : str
82
- description : str
83
- menu_summary : str
84
- top_reviews : str
85
- vibe : str
86
-
87
-
88
- class ItemWithDistance (ItemPublic ):
89
- distance : float
90
-
91
- def __init__ (self , ** data ):
92
- super ().__init__ (** data )
93
- self .distance = round (self .distance , 2 )
94
-
95
-
96
97
class ChatParams (ChatRequestOverrides ):
97
98
prompt_template : str
98
99
response_token_limit : int = 1024
99
100
enable_text_search : bool
100
101
enable_vector_search : bool
101
102
original_user_query : str
102
- past_messages : list [ChatCompletionMessageParam ]
103
+ past_messages : list [Union [ModelRequest , ModelResponse ]]
104
+
105
+
106
+ class Filter (BaseModel ):
107
+ column : str
108
+ comparison_operator : str
109
+ value : Any
110
+
111
+
112
+ class PriceLevelFilter (Filter ):
113
+ column : str = Field (default = "price_level" , description = "The column to filter on (always 'price_level' for this filter)" )
114
+ comparison_operator : str = Field (description = "The operator for price level comparison ('>', '<', '>=', '<=', '=')" )
115
+ value : float = Field (description = "Value to compare against, either 1, 2, 3, 4" )
116
+
117
+
118
+ class RatingFilter (Filter ):
119
+ column : str = Field (default = "rating" , description = "The column to filter on (always 'rating' for this filter)" )
120
+ comparison_operator : str = Field (description = "The operator for rating comparison ('>', '<', '>=', '<=', '=')" )
121
+ value : str = Field (description = "Value to compare against, either 0 1 2 3 4" )
122
+
123
+
124
+ class SearchResults (BaseModel ):
125
+ query : str
126
+ """The original search query"""
127
+
128
+ items : list [ItemPublic ]
129
+ """List of items that match the search query and filters"""
130
+
131
+ filters : list [Filter ]
132
+ """List of filters applied to the search results"""
0 commit comments