|
16 | 16 | ModelRequest,
|
17 | 17 | before_model,
|
18 | 18 | after_model,
|
| 19 | + dynamic_prompt, |
19 | 20 | modify_model_request,
|
20 | 21 | hook_config,
|
21 | 22 | )
|
@@ -572,3 +573,145 @@ async def async_after_with_jumps(state: AgentState, runtime: Runtime) -> dict[st
|
572 | 573 | )
|
573 | 574 |
|
574 | 575 | assert agent_mixed.compile().get_graph().draw_mermaid() == snapshot
|
| 576 | + |
| 577 | + |
def test_dynamic_prompt_decorator() -> None:
    """Verify that `dynamic_prompt` wraps a plain function into working middleware."""

    @dynamic_prompt
    def my_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
        return "Dynamic test prompt"

    # The decorator must yield an AgentMiddleware instance whose class is
    # named after the wrapped function, with the default schema and no tools.
    assert isinstance(my_prompt, AgentMiddleware)
    assert my_prompt.state_schema == AgentState
    assert my_prompt.tools == []
    assert my_prompt.__class__.__name__ == "my_prompt"

    # Applying the middleware should replace the request's system prompt.
    request = ModelRequest(
        model="test-model",
        system_prompt="Original",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
    )
    updated = my_prompt.modify_model_request(
        request, {"messages": [HumanMessage("Hello")]}, None
    )
    assert updated.system_prompt == "Dynamic test prompt"
| 603 | + |
| 604 | + |
def test_dynamic_prompt_uses_state() -> None:
    """Verify a dynamic prompt can read values out of the agent state."""

    @dynamic_prompt
    def custom_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
        return f"Prompt with {len(state['messages'])} messages"

    request = ModelRequest(
        model="test-model",
        system_prompt="Original",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
    )
    # Two messages in state -> the count must show up in the generated prompt.
    state = {"messages": [HumanMessage("Hello"), HumanMessage("World")]}
    updated = custom_prompt.modify_model_request(request, state, None)
    assert updated.system_prompt == "Prompt with 2 messages"
| 626 | + |
| 627 | + |
def test_dynamic_prompt_integration() -> None:
    """Run a full agent and check the dynamic prompt is invoked exactly once."""

    prompt_calls = 0

    @dynamic_prompt
    def context_aware_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
        nonlocal prompt_calls
        prompt_calls += 1
        # Plain literal: the original used an f-string with no placeholders (ruff F541).
        return "you are a helpful assistant."

    agent = create_agent(model=FakeToolCallingModel(), middleware=[context_aware_prompt])
    agent = agent.compile()

    result = agent.invoke({"messages": [HumanMessage("Hello")]})

    # One model call -> one prompt invocation; the fake model appears to reply
    # with "<system_prompt>-<user message>" (consistent with the other tests here).
    assert prompt_calls == 1
    assert result["messages"][-1].content == "you are a helpful assistant.-Hello"
| 646 | + |
| 647 | + |
async def test_async_dynamic_prompt_decorator() -> None:
    """Verify that `dynamic_prompt` accepts an async prompt function."""

    @dynamic_prompt
    async def async_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
        return "Async dynamic prompt"

    # Same middleware surface as the sync variant: AgentMiddleware instance,
    # default state schema, no tools, class named after the function.
    assert isinstance(async_prompt, AgentMiddleware)
    assert async_prompt.tools == []
    assert async_prompt.state_schema == AgentState
    assert async_prompt.__class__.__name__ == "async_prompt"
| 659 | + |
| 660 | + |
async def test_async_dynamic_prompt_integration() -> None:
    """Run a full agent through `ainvoke` with an async dynamic prompt."""

    prompt_calls = 0

    @dynamic_prompt
    async def async_context_prompt(
        request: ModelRequest, state: AgentState, runtime: Runtime
    ) -> str:
        nonlocal prompt_calls
        prompt_calls += 1
        # Plain literal: the original used an f-string with no placeholders (ruff F541).
        return "Async assistant."

    agent = create_agent(model=FakeToolCallingModel(), middleware=[async_context_prompt])
    agent = agent.compile()

    result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
    # One async model call -> one prompt invocation; the fake model appears to
    # reply with "<system_prompt>-<user message>" (consistent with the sync test).
    assert prompt_calls == 1
    assert result["messages"][-1].content == "Async assistant.-Hello"
| 680 | + |
| 681 | + |
def test_dynamic_prompt_overwrites_system_prompt() -> None:
    """Verify the dynamic prompt takes precedence over a static `system_prompt`."""

    @dynamic_prompt
    def override_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
        return "Overridden prompt."

    # Even though the agent is configured with a static prompt, the middleware
    # should replace it before the model is called.
    agent = create_agent(
        model=FakeToolCallingModel(),
        system_prompt="Original static prompt",
        middleware=[override_prompt],
    ).compile()

    result = agent.invoke({"messages": [HumanMessage("Hello")]})
    assert result["messages"][-1].content == "Overridden prompt.-Hello"
| 698 | + |
| 699 | + |
def test_dynamic_prompt_multiple_in_sequence() -> None:
    """Verify that with several dynamic prompts, the last middleware wins."""

    @dynamic_prompt
    def first_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
        return "First prompt."

    @dynamic_prompt
    def second_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
        return "Second prompt."

    # Both are modify_model_request hooks run in list order, so the second
    # overwrites whatever the first set on the request.
    agent = create_agent(
        model=FakeToolCallingModel(),
        middleware=[first_prompt, second_prompt],
    ).compile()

    result = agent.invoke({"messages": [HumanMessage("Hello")]})
    assert result["messages"][-1].content == "Second prompt.-Hello"
0 commit comments