22
22
ArgsDict ,
23
23
Message ,
24
24
ModelResponse ,
25
+ RetryPrompt ,
25
26
ToolCallPart ,
27
+ ToolReturn ,
28
+ UserPrompt ,
26
29
)
27
30
from pydantic_ai .models import KnownModelName , Model
28
31
from pydantic_ai .models .function import AgentInfo , DeltaToolCall , DeltaToolCalls , FunctionModel
@@ -215,7 +218,7 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response:
215
218
216
219
async def model_logic (messages : list [Message ], info : AgentInfo ) -> ModelResponse : # pragma: no cover
217
220
m = messages [- 1 ]
218
- if m . message_kind == 'user-prompt' :
221
+ if isinstance ( m , UserPrompt ) :
219
222
if response := text_responses .get (m .content ):
220
223
if isinstance (response , str ):
221
224
return ModelResponse .from_text (content = response )
@@ -225,28 +228,28 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelResponse
225
228
if re .fullmatch (r'sql prompt \d+' , m .content ):
226
229
return ModelResponse .from_text (content = 'SELECT 1' )
227
230
228
- elif m . message_kind == 'tool-return' and m .tool_name == 'roulette_wheel' :
231
+ elif isinstance ( m , ToolReturn ) and m .tool_name == 'roulette_wheel' :
229
232
win = m .content == 'winner'
230
233
return ModelResponse (parts = [ToolCallPart (tool_name = 'final_result' , args = ArgsDict ({'response' : win }))])
231
- elif m . message_kind == 'tool-return' and m .tool_name == 'roll_die' :
234
+ elif isinstance ( m , ToolReturn ) and m .tool_name == 'roll_die' :
232
235
return ModelResponse (parts = [ToolCallPart (tool_name = 'get_player_name' , args = ArgsDict ({}))])
233
- elif m . message_kind == 'tool-return' and m .tool_name == 'get_player_name' :
236
+ elif isinstance ( m , ToolReturn ) and m .tool_name == 'get_player_name' :
234
237
return ModelResponse .from_text (content = "Congratulations Anne, you guessed correctly! You're a winner!" )
235
238
if (
236
- m . message_kind == 'retry-prompt'
239
+ isinstance ( m , RetryPrompt )
237
240
and isinstance (m .content , str )
238
241
and m .content .startswith ("No user found with name 'Joh" )
239
242
):
240
243
return ModelResponse (parts = [ToolCallPart (tool_name = 'get_user_by_name' , args = ArgsDict ({'name' : 'John Doe' }))])
241
- elif m . message_kind == 'tool-return' and m .tool_name == 'get_user_by_name' :
244
+ elif isinstance ( m , ToolReturn ) and m .tool_name == 'get_user_by_name' :
242
245
args = {
243
246
'message' : 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!' ,
244
247
'user_id' : 123 ,
245
248
}
246
249
return ModelResponse (parts = [ToolCallPart (tool_name = 'final_result' , args = ArgsDict (args ))])
247
- elif m . message_kind == 'retry-prompt' and m .tool_name == 'calc_volume' :
250
+ elif isinstance ( m , RetryPrompt ) and m .tool_name == 'calc_volume' :
248
251
return ModelResponse (parts = [ToolCallPart (tool_name = 'calc_volume' , args = ArgsDict ({'size' : 6 }))])
249
- elif m . message_kind == 'tool-return' and m .tool_name == 'customer_balance' :
252
+ elif isinstance ( m , ToolReturn ) and m .tool_name == 'customer_balance' :
250
253
args = {
251
254
'support_advice' : 'Hello John, your current account balance, including pending transactions, is $123.45.' ,
252
255
'block_card' : False ,
@@ -262,7 +265,7 @@ async def stream_model_logic(
262
265
messages : list [Message ], info : AgentInfo
263
266
) -> AsyncIterator [str | DeltaToolCalls ]: # pragma: no cover
264
267
m = messages [- 1 ]
265
- if m . message_kind == 'user-prompt' :
268
+ if isinstance ( m , UserPrompt ) :
266
269
if response := text_responses .get (m .content ):
267
270
if isinstance (response , str ):
268
271
words = response .split (' ' )
0 commit comments