|
21 | 21 | import pytest |
22 | 22 | from langchain_core.messages import AIMessage, HumanMessage, ToolMessage |
23 | 23 | from langchain_core.tools import tool |
| 24 | +from langgraph.prebuilt import InjectedStore |
| 25 | +from langgraph.store.base import BaseStore |
24 | 26 | from langgraph.store.memory import InMemoryStore |
| 27 | +from typing_extensions import Annotated |
25 | 28 |
|
26 | 29 | from langchain.agents import create_agent |
27 | 30 | from langchain.agents.middleware.types import AgentState |
28 | | -from langchain.tools import ToolRuntime |
| 31 | +from langchain.tools import InjectedState, ToolRuntime |
29 | 32 |
|
30 | 33 | from .model import FakeToolCallingModel |
31 | 34 |
|
@@ -589,3 +592,243 @@ def name_based_tool(x: int, runtime: Any) -> str: |
589 | 592 | assert injected_data["tool_call_id"] == "name_call_123" |
590 | 593 | assert injected_data["state"] is not None |
591 | 594 | assert "messages" in injected_data["state"] |
| 595 | + |
| 596 | + |
| 597 | +def test_combined_injected_state_runtime_store() -> None: |
| 598 | + """Test that all injection mechanisms work together in create_agent. |
| 599 | +
|
| 600 | + This test verifies that a tool can receive injected state, tool runtime, |
| 601 | + and injected store simultaneously when specified in the function signature |
| 602 | + but not in the explicit args schema. This is modeled after the pattern |
| 603 | + from mre.py where multiple injection types are combined. |
| 604 | + """ |
| 605 | + # Track what was injected |
| 606 | + injected_data = {} |
| 607 | + |
| 608 | + # Custom state schema with additional fields |
| 609 | + class CustomState(AgentState): |
| 610 | + user_id: str |
| 611 | + session_id: str |
| 612 | + |
| 613 | + # Define explicit args schema that only includes LLM-controlled parameters |
| 614 | + weather_schema = { |
| 615 | + "type": "object", |
| 616 | + "properties": { |
| 617 | + "location": {"type": "string", "description": "The location to get weather for"}, |
| 618 | + }, |
| 619 | + "required": ["location"], |
| 620 | + } |
| 621 | + |
| 622 | + @tool(args_schema=weather_schema) |
| 623 | + def multi_injection_tool( |
| 624 | + location: str, |
| 625 | + state: Annotated[Any, InjectedState], |
| 626 | + runtime: ToolRuntime, |
| 627 | + store: Annotated[Any, InjectedStore()], |
| 628 | + ) -> str: |
| 629 | + """Tool that uses injected state, runtime, and store together. |
| 630 | +
|
| 631 | + Args: |
| 632 | + location: The location to get weather for (LLM-controlled). |
| 633 | + state: The graph state (injected). |
| 634 | + runtime: The tool runtime context (injected). |
| 635 | + store: The persistent store (injected). |
| 636 | + """ |
| 637 | + # Capture all injected parameters |
| 638 | + injected_data["state"] = state |
| 639 | + injected_data["user_id"] = state.get("user_id", "unknown") |
| 640 | + injected_data["session_id"] = state.get("session_id", "unknown") |
| 641 | + injected_data["runtime"] = runtime |
| 642 | + injected_data["tool_call_id"] = runtime.tool_call_id |
| 643 | + injected_data["store"] = store |
| 644 | + injected_data["store_is_none"] = store is None |
| 645 | + |
| 646 | + # Verify runtime.state matches the state parameter |
| 647 | + injected_data["runtime_state_matches"] = runtime.state == state |
| 648 | + |
| 649 | + return f"Weather info for {location}" |
| 650 | + |
| 651 | + # Create model that calls the tool |
| 652 | + model = FakeToolCallingModel( |
| 653 | + tool_calls=[ |
| 654 | + [ |
| 655 | + { |
| 656 | + "name": "multi_injection_tool", |
| 657 | + "args": {"location": "San Francisco"}, # Only LLM-controlled arg |
| 658 | + "id": "call_weather_123", |
| 659 | + } |
| 660 | + ], |
| 661 | + [], # End the loop |
| 662 | + ] |
| 663 | + ) |
| 664 | + |
| 665 | + # Create agent with custom state and store |
| 666 | + agent = create_agent( |
| 667 | + model=model, |
| 668 | + tools=[multi_injection_tool], |
| 669 | + state_schema=CustomState, |
| 670 | + store=InMemoryStore(), |
| 671 | + ) |
| 672 | + |
| 673 | + # Verify the tool's args schema only includes LLM-controlled parameters |
| 674 | + tool_args_schema = multi_injection_tool.args_schema |
| 675 | + assert "location" in tool_args_schema["properties"] |
| 676 | + assert "state" not in tool_args_schema["properties"] |
| 677 | + assert "runtime" not in tool_args_schema["properties"] |
| 678 | + assert "store" not in tool_args_schema["properties"] |
| 679 | + |
| 680 | + # Invoke with custom state fields |
| 681 | + result = agent.invoke( |
| 682 | + { |
| 683 | + "messages": [HumanMessage("What's the weather like?")], |
| 684 | + "user_id": "user_42", |
| 685 | + "session_id": "session_abc123", |
| 686 | + } |
| 687 | + ) |
| 688 | + |
| 689 | + # Verify tool executed successfully |
| 690 | + tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)] |
| 691 | + assert len(tool_messages) == 1 |
| 692 | + tool_message = tool_messages[0] |
| 693 | + assert tool_message.content == "Weather info for San Francisco" |
| 694 | + assert tool_message.tool_call_id == "call_weather_123" |
| 695 | + |
| 696 | + # Verify all injections worked correctly |
| 697 | + assert injected_data["state"] is not None |
| 698 | + assert "messages" in injected_data["state"] |
| 699 | + |
| 700 | + # Verify custom state fields were accessible |
| 701 | + assert injected_data["user_id"] == "user_42" |
| 702 | + assert injected_data["session_id"] == "session_abc123" |
| 703 | + |
| 704 | + # Verify runtime was injected |
| 705 | + assert injected_data["runtime"] is not None |
| 706 | + assert injected_data["tool_call_id"] == "call_weather_123" |
| 707 | + |
| 708 | + # Verify store was injected |
| 709 | + assert injected_data["store_is_none"] is False |
| 710 | + assert injected_data["store"] is not None |
| 711 | + |
| 712 | + # Verify runtime.state matches the injected state |
| 713 | + assert injected_data["runtime_state_matches"] is True |
| 714 | + |
| 715 | + |
| 716 | +async def test_combined_injected_state_runtime_store_async() -> None: |
| 717 | + """Test that all injection mechanisms work together in async execution. |
| 718 | +
|
| 719 | + This async version verifies that injected state, tool runtime, and injected |
| 720 | + store all work correctly with async tools in create_agent. |
| 721 | + """ |
| 722 | + # Track what was injected |
| 723 | + injected_data = {} |
| 724 | + |
| 725 | + # Custom state schema |
| 726 | + class CustomState(AgentState): |
| 727 | + api_key: str |
| 728 | + request_id: str |
| 729 | + |
| 730 | + # Define explicit args schema that only includes LLM-controlled parameters |
| 731 | + # Note: state, runtime, and store are NOT in this schema |
| 732 | + search_schema = { |
| 733 | + "type": "object", |
| 734 | + "properties": { |
| 735 | + "query": {"type": "string", "description": "The search query"}, |
| 736 | + "max_results": {"type": "integer", "description": "Maximum number of results"}, |
| 737 | + }, |
| 738 | + "required": ["query", "max_results"], |
| 739 | + } |
| 740 | + |
| 741 | + @tool(args_schema=search_schema) |
| 742 | + async def async_multi_injection_tool( |
| 743 | + query: str, |
| 744 | + max_results: int, |
| 745 | + state: Annotated[Any, InjectedState], |
| 746 | + runtime: ToolRuntime, |
| 747 | + store: Annotated[Any, InjectedStore()], |
| 748 | + ) -> str: |
| 749 | + """Async tool with multiple injection types. |
| 750 | +
|
| 751 | + Args: |
| 752 | + query: The search query (LLM-controlled). |
| 753 | + max_results: Maximum number of results (LLM-controlled). |
| 754 | + state: The graph state (injected). |
| 755 | + runtime: The tool runtime context (injected). |
| 756 | + store: The persistent store (injected). |
| 757 | + """ |
| 758 | + # Capture all injected parameters |
| 759 | + injected_data["state"] = state |
| 760 | + injected_data["api_key"] = state.get("api_key", "unknown") |
| 761 | + injected_data["request_id"] = state.get("request_id", "unknown") |
| 762 | + injected_data["runtime"] = runtime |
| 763 | + injected_data["tool_call_id"] = runtime.tool_call_id |
| 764 | + injected_data["config"] = runtime.config |
| 765 | + injected_data["store"] = store |
| 766 | + |
| 767 | + # Verify we can write to the store |
| 768 | + if store is not None: |
| 769 | + await store.aput(("test", "namespace"), "test_key", {"query": query}) |
| 770 | + # Read back to verify it worked |
| 771 | + item = await store.aget(("test", "namespace"), "test_key") |
| 772 | + injected_data["store_write_success"] = item is not None |
| 773 | + |
| 774 | + return f"Found {max_results} results for '{query}'" |
| 775 | + |
| 776 | + # Create model that calls the async tool |
| 777 | + model = FakeToolCallingModel( |
| 778 | + tool_calls=[ |
| 779 | + [ |
| 780 | + { |
| 781 | + "name": "async_multi_injection_tool", |
| 782 | + "args": {"query": "test search", "max_results": 10}, |
| 783 | + "id": "call_search_456", |
| 784 | + } |
| 785 | + ], |
| 786 | + [], |
| 787 | + ] |
| 788 | + ) |
| 789 | + |
| 790 | + # Create agent with custom state and store |
| 791 | + agent = create_agent( |
| 792 | + model=model, |
| 793 | + tools=[async_multi_injection_tool], |
| 794 | + state_schema=CustomState, |
| 795 | + store=InMemoryStore(), |
| 796 | + ) |
| 797 | + |
| 798 | + # Verify the tool's args schema only includes LLM-controlled parameters |
| 799 | + tool_args_schema = async_multi_injection_tool.args_schema |
| 800 | + assert "query" in tool_args_schema["properties"] |
| 801 | + assert "max_results" in tool_args_schema["properties"] |
| 802 | + assert "state" not in tool_args_schema["properties"] |
| 803 | + assert "runtime" not in tool_args_schema["properties"] |
| 804 | + assert "store" not in tool_args_schema["properties"] |
| 805 | + |
| 806 | + # Invoke async |
| 807 | + result = await agent.ainvoke( |
| 808 | + { |
| 809 | + "messages": [HumanMessage("Search for something")], |
| 810 | + "api_key": "sk-test-key-xyz", |
| 811 | + "request_id": "req_999", |
| 812 | + } |
| 813 | + ) |
| 814 | + |
| 815 | + # Verify tool executed successfully |
| 816 | + tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)] |
| 817 | + assert len(tool_messages) == 1 |
| 818 | + tool_message = tool_messages[0] |
| 819 | + assert tool_message.content == "Found 10 results for 'test search'" |
| 820 | + assert tool_message.tool_call_id == "call_search_456" |
| 821 | + |
| 822 | + # Verify all injections worked correctly |
| 823 | + assert injected_data["state"] is not None |
| 824 | + assert injected_data["api_key"] == "sk-test-key-xyz" |
| 825 | + assert injected_data["request_id"] == "req_999" |
| 826 | + |
| 827 | + # Verify runtime was injected |
| 828 | + assert injected_data["runtime"] is not None |
| 829 | + assert injected_data["tool_call_id"] == "call_search_456" |
| 830 | + assert injected_data["config"] is not None |
| 831 | + |
| 832 | + # Verify store was injected and writable |
| 833 | + assert injected_data["store"] is not None |
| 834 | + assert injected_data["store_write_success"] is True |
0 commit comments