|
19 | 19 | ModelResponse,
|
20 | 20 | RetryPromptPart,
|
21 | 21 | SystemPromptPart,
|
| 22 | + TextPart, |
22 | 23 | ToolCallPart,
|
23 | 24 | ToolReturnPart,
|
24 | 25 | UserPromptPart,
|
@@ -537,25 +538,58 @@ def handler(_: httpx.Request):
|
537 | 538 |
|
538 | 539 |
|
async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient):
    """Indicates that tool calls are prioritized over text in heterogeneous responses.

    The first model response mixes a text part with a tool call; the agent must run the
    tool and continue, only surfacing the final text-only response as ``result.data``.
    """
    # Two canned responses: a mixed text + tool-call message, then the final text answer.
    responses = [
        gemini_response(
            _content_model_response(
                ModelResponse(
                    parts=[TextPart(content='foo'), ToolCallPart.from_raw_args('get_location', {'loc_name': 'London'})]
                )
            )
        ),
        gemini_response(_content_model_response(ModelResponse.from_text('final response'))),
    ]

    gemini_client = get_gemini_client(responses)
    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
    agent = Agent(m)

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    result = await agent.run('Hello')
    assert result.data == 'final response'
    # Full exchange: user prompt -> mixed (text + tool call) response -> tool return -> final text.
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[
                    TextPart(content='foo'),
                    ToolCallPart(
                        tool_name='get_location',
                        args=ArgsDict(args_dict={'loc_name': 'London'}),
                    ),
                ],
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)
                    )
                ]
            ),
            ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
        ]
    )
559 | 593 |
|
560 | 594 |
|
561 | 595 | async def test_stream_text(get_gemini_client: GetGeminiClient):
|
|
0 commit comments