@@ -5377,3 +5377,137 @@ def dynamic_instr() -> str:
5377
5377
sys_texts = [p .content for p in req .parts if isinstance (p , SystemPromptPart )]
5378
5378
# The dynamic system prompt should still be present since overrides target instructions only
5379
5379
assert dynamic_value in sys_texts
5380
+
5381
+
5382
+ def test_continue_conversation_that_ended_in_output_tool_call (allow_model_requests : None ):
5383
+ def llm (messages : list [ModelMessage ], info : AgentInfo ) -> ModelResponse :
5384
+ if any (isinstance (p , ToolReturnPart ) and p .tool_name == 'roll_dice' for p in messages [- 1 ].parts ):
5385
+ return ModelResponse (
5386
+ parts = [
5387
+ ToolCallPart (
5388
+ tool_name = 'final_result' ,
5389
+ args = {'dice_roll' : 4 },
5390
+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5391
+ )
5392
+ ]
5393
+ )
5394
+ return ModelResponse (
5395
+ parts = [ToolCallPart (tool_name = 'roll_dice' , args = {}, tool_call_id = 'pyd_ai_tool_call_id__roll_dice' )]
5396
+ )
5397
+
5398
+ class Result (BaseModel ):
5399
+ dice_roll : int
5400
+
5401
+ agent = Agent (FunctionModel (llm ), output_type = Result )
5402
+
5403
+ @agent .tool_plain
5404
+ def roll_dice () -> int :
5405
+ return 4
5406
+
5407
+ result = agent .run_sync ('Roll me a dice.' )
5408
+ messages = result .all_messages ()
5409
+ assert messages == snapshot (
5410
+ [
5411
+ ModelRequest (
5412
+ parts = [
5413
+ UserPromptPart (
5414
+ content = 'Roll me a dice.' ,
5415
+ timestamp = IsDatetime (),
5416
+ )
5417
+ ]
5418
+ ),
5419
+ ModelResponse (
5420
+ parts = [ToolCallPart (tool_name = 'roll_dice' , args = {}, tool_call_id = 'pyd_ai_tool_call_id__roll_dice' )],
5421
+ usage = RequestUsage (input_tokens = 55 , output_tokens = 2 ),
5422
+ model_name = 'function:llm:' ,
5423
+ timestamp = IsDatetime (),
5424
+ ),
5425
+ ModelRequest (
5426
+ parts = [
5427
+ ToolReturnPart (
5428
+ tool_name = 'roll_dice' ,
5429
+ content = 4 ,
5430
+ tool_call_id = 'pyd_ai_tool_call_id__roll_dice' ,
5431
+ timestamp = IsDatetime (),
5432
+ )
5433
+ ]
5434
+ ),
5435
+ ModelResponse (
5436
+ parts = [
5437
+ ToolCallPart (
5438
+ tool_name = 'final_result' ,
5439
+ args = {'dice_roll' : 4 },
5440
+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5441
+ )
5442
+ ],
5443
+ usage = RequestUsage (input_tokens = 56 , output_tokens = 6 ),
5444
+ model_name = 'function:llm:' ,
5445
+ timestamp = IsDatetime (),
5446
+ ),
5447
+ ModelRequest (
5448
+ parts = [
5449
+ ToolReturnPart (
5450
+ tool_name = 'final_result' ,
5451
+ content = 'Final result processed.' ,
5452
+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5453
+ timestamp = IsDatetime (),
5454
+ )
5455
+ ]
5456
+ ),
5457
+ ]
5458
+ )
5459
+
5460
+ result = agent .run_sync ('Roll me a dice again.' , message_history = messages )
5461
+ new_messages = result .new_messages ()
5462
+ assert new_messages == snapshot (
5463
+ [
5464
+ ModelRequest (
5465
+ parts = [
5466
+ UserPromptPart (
5467
+ content = 'Roll me a dice again.' ,
5468
+ timestamp = IsDatetime (),
5469
+ )
5470
+ ]
5471
+ ),
5472
+ ModelResponse (
5473
+ parts = [ToolCallPart (tool_name = 'roll_dice' , args = {}, tool_call_id = 'pyd_ai_tool_call_id__roll_dice' )],
5474
+ usage = RequestUsage (input_tokens = 66 , output_tokens = 8 ),
5475
+ model_name = 'function:llm:' ,
5476
+ timestamp = IsDatetime (),
5477
+ ),
5478
+ ModelRequest (
5479
+ parts = [
5480
+ ToolReturnPart (
5481
+ tool_name = 'roll_dice' ,
5482
+ content = 4 ,
5483
+ tool_call_id = 'pyd_ai_tool_call_id__roll_dice' ,
5484
+ timestamp = IsDatetime (),
5485
+ )
5486
+ ]
5487
+ ),
5488
+ ModelResponse (
5489
+ parts = [
5490
+ ToolCallPart (
5491
+ tool_name = 'final_result' ,
5492
+ args = {'dice_roll' : 4 },
5493
+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5494
+ )
5495
+ ],
5496
+ usage = RequestUsage (input_tokens = 67 , output_tokens = 12 ),
5497
+ model_name = 'function:llm:' ,
5498
+ timestamp = IsDatetime (),
5499
+ ),
5500
+ ModelRequest (
5501
+ parts = [
5502
+ ToolReturnPart (
5503
+ tool_name = 'final_result' ,
5504
+ content = 'Final result processed.' ,
5505
+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5506
+ timestamp = IsDatetime (),
5507
+ )
5508
+ ]
5509
+ ),
5510
+ ]
5511
+ )
5512
+
5513
+ assert not any (isinstance (p , ToolReturnPart ) and p .tool_name == 'final_result' for p in new_messages [0 ].parts )
0 commit comments