|
44 | 44 | "ModelRequest",
|
45 | 45 | "OmitFromSchema",
|
46 | 46 | "PublicAgentState",
|
| 47 | + "after_agent", |
| 48 | + "after_model", |
| 49 | + "before_agent", |
| 50 | + "before_model", |
47 | 51 | "dynamic_prompt",
|
48 | 52 | "hook_config",
|
| 53 | + "modify_model_request", |
49 | 54 | ]
|
50 | 55 |
|
51 | 56 | JumpTo = Literal["tools", "model", "end"]
|
@@ -93,7 +98,7 @@ class AgentState(TypedDict, Generic[ResponseT]):
|
93 | 98 |
|
94 | 99 | messages: Required[Annotated[list[AnyMessage], add_messages]]
|
95 | 100 | jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
|
96 |
| - structured_response: NotRequired[ResponseT] |
| 101 | + structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]] |
97 | 102 | thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
|
98 | 103 | run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
|
99 | 104 |
|
@@ -133,6 +138,14 @@ def name(self) -> str:
|
133 | 138 | """
|
134 | 139 | return self.__class__.__name__
|
135 | 140 |
|
| 141 | + def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: |
| 142 | + """Logic to run before the agent execution starts.""" |
| 143 | + |
| 144 | + async def abefore_agent( |
| 145 | + self, state: StateT, runtime: Runtime[ContextT] |
| 146 | + ) -> dict[str, Any] | None: |
| 147 | + """Async logic to run before the agent execution starts.""" |
| 148 | + |
136 | 149 | def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
137 | 150 | """Logic to run before the model is called."""
|
138 | 151 |
|
@@ -215,6 +228,14 @@ async def aretry_model_request(
|
215 | 228 | None, self.retry_model_request, error, request, state, runtime, attempt
|
216 | 229 | )
|
217 | 230 |
|
| 231 | + def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: |
| 232 | + """Logic to run after the agent execution completes.""" |
| 233 | + |
| 234 | + async def aafter_agent( |
| 235 | + self, state: StateT, runtime: Runtime[ContextT] |
| 236 | + ) -> dict[str, Any] | None: |
| 237 | + """Async logic to run after the agent execution completes.""" |
| 238 | + |
218 | 239 |
|
219 | 240 | class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
220 | 241 | """Callable with AgentState and Runtime as arguments."""
|
@@ -707,6 +728,279 @@ def wrapped(
|
707 | 728 | return decorator
|
708 | 729 |
|
709 | 730 |
|
| 731 | +@overload |
| 732 | +def before_agent( |
| 733 | + func: _CallableWithStateAndRuntime[StateT, ContextT], |
| 734 | +) -> AgentMiddleware[StateT, ContextT]: ... |
| 735 | + |
| 736 | + |
| 737 | +@overload |
| 738 | +def before_agent( |
| 739 | + func: None = None, |
| 740 | + *, |
| 741 | + state_schema: type[StateT] | None = None, |
| 742 | + tools: list[BaseTool] | None = None, |
| 743 | + can_jump_to: list[JumpTo] | None = None, |
| 744 | + name: str | None = None, |
| 745 | +) -> Callable[ |
| 746 | + [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT] |
| 747 | +]: ... |
| 748 | + |
| 749 | + |
| 750 | +def before_agent( |
| 751 | + func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None, |
| 752 | + *, |
| 753 | + state_schema: type[StateT] | None = None, |
| 754 | + tools: list[BaseTool] | None = None, |
| 755 | + can_jump_to: list[JumpTo] | None = None, |
| 756 | + name: str | None = None, |
| 757 | +) -> ( |
| 758 | + Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]] |
| 759 | + | AgentMiddleware[StateT, ContextT] |
| 760 | +): |
| 761 | + """Decorator used to dynamically create a middleware with the before_agent hook. |
| 762 | +
|
| 763 | + Args: |
| 764 | + func: The function to be decorated. Must accept: |
| 765 | + `state: StateT, runtime: Runtime[ContextT]` - State and runtime context |
| 766 | + state_schema: Optional custom state schema type. If not provided, uses the default |
| 767 | + AgentState schema. |
| 768 | + tools: Optional list of additional tools to register with this middleware. |
| 769 | + can_jump_to: Optional list of valid jump destinations for conditional edges. |
| 770 | + Valid values are: "tools", "model", "end" |
| 771 | + name: Optional name for the generated middleware class. If not provided, |
| 772 | + uses the decorated function's name. |
| 773 | +
|
| 774 | + Returns: |
| 775 | + Either an AgentMiddleware instance (if func is provided directly) or a decorator function |
| 776 | + that can be applied to a function its wrapping. |
| 777 | +
|
| 778 | + The decorated function should return: |
| 779 | + - `dict[str, Any]` - State updates to merge into the agent state |
| 780 | + - `Command` - A command to control flow (e.g., jump to different node) |
| 781 | + - `None` - No state updates or flow control |
| 782 | +
|
| 783 | + Examples: |
| 784 | + Basic usage: |
| 785 | + ```python |
| 786 | + @before_agent |
| 787 | + def log_before_agent(state: AgentState, runtime: Runtime) -> None: |
| 788 | + print(f"Starting agent with {len(state['messages'])} messages") |
| 789 | + ``` |
| 790 | +
|
| 791 | + With conditional jumping: |
| 792 | + ```python |
| 793 | + @before_agent(can_jump_to=["end"]) |
| 794 | + def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: |
| 795 | + if some_condition(state): |
| 796 | + return {"jump_to": "end"} |
| 797 | + return None |
| 798 | + ``` |
| 799 | +
|
| 800 | + With custom state schema: |
| 801 | + ```python |
| 802 | + @before_agent(state_schema=MyCustomState) |
| 803 | + def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]: |
| 804 | + return {"custom_field": "initialized_value"} |
| 805 | + ``` |
| 806 | + """ |
| 807 | + |
| 808 | + def decorator( |
| 809 | + func: _CallableWithStateAndRuntime[StateT, ContextT], |
| 810 | + ) -> AgentMiddleware[StateT, ContextT]: |
| 811 | + is_async = iscoroutinefunction(func) |
| 812 | + |
| 813 | + func_can_jump_to = ( |
| 814 | + can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", []) |
| 815 | + ) |
| 816 | + |
| 817 | + if is_async: |
| 818 | + |
| 819 | + async def async_wrapped( |
| 820 | + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 |
| 821 | + state: StateT, |
| 822 | + runtime: Runtime[ContextT], |
| 823 | + ) -> dict[str, Any] | Command | None: |
| 824 | + return await func(state, runtime) # type: ignore[misc] |
| 825 | + |
| 826 | + # Preserve can_jump_to metadata on the wrapped function |
| 827 | + if func_can_jump_to: |
| 828 | + async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] |
| 829 | + |
| 830 | + middleware_name = name or cast( |
| 831 | + "str", getattr(func, "__name__", "BeforeAgentMiddleware") |
| 832 | + ) |
| 833 | + |
| 834 | + return type( |
| 835 | + middleware_name, |
| 836 | + (AgentMiddleware,), |
| 837 | + { |
| 838 | + "state_schema": state_schema or AgentState, |
| 839 | + "tools": tools or [], |
| 840 | + "abefore_agent": async_wrapped, |
| 841 | + }, |
| 842 | + )() |
| 843 | + |
| 844 | + def wrapped( |
| 845 | + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 |
| 846 | + state: StateT, |
| 847 | + runtime: Runtime[ContextT], |
| 848 | + ) -> dict[str, Any] | Command | None: |
| 849 | + return func(state, runtime) # type: ignore[return-value] |
| 850 | + |
| 851 | + # Preserve can_jump_to metadata on the wrapped function |
| 852 | + if func_can_jump_to: |
| 853 | + wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] |
| 854 | + |
| 855 | + # Use function name as default if no name provided |
| 856 | + middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware")) |
| 857 | + |
| 858 | + return type( |
| 859 | + middleware_name, |
| 860 | + (AgentMiddleware,), |
| 861 | + { |
| 862 | + "state_schema": state_schema or AgentState, |
| 863 | + "tools": tools or [], |
| 864 | + "before_agent": wrapped, |
| 865 | + }, |
| 866 | + )() |
| 867 | + |
| 868 | + if func is not None: |
| 869 | + return decorator(func) |
| 870 | + return decorator |
| 871 | + |
| 872 | + |
| 873 | +@overload |
| 874 | +def after_agent( |
| 875 | + func: _CallableWithStateAndRuntime[StateT, ContextT], |
| 876 | +) -> AgentMiddleware[StateT, ContextT]: ... |
| 877 | + |
| 878 | + |
| 879 | +@overload |
| 880 | +def after_agent( |
| 881 | + func: None = None, |
| 882 | + *, |
| 883 | + state_schema: type[StateT] | None = None, |
| 884 | + tools: list[BaseTool] | None = None, |
| 885 | + can_jump_to: list[JumpTo] | None = None, |
| 886 | + name: str | None = None, |
| 887 | +) -> Callable[ |
| 888 | + [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT] |
| 889 | +]: ... |
| 890 | + |
| 891 | + |
| 892 | +def after_agent( |
| 893 | + func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None, |
| 894 | + *, |
| 895 | + state_schema: type[StateT] | None = None, |
| 896 | + tools: list[BaseTool] | None = None, |
| 897 | + can_jump_to: list[JumpTo] | None = None, |
| 898 | + name: str | None = None, |
| 899 | +) -> ( |
| 900 | + Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]] |
| 901 | + | AgentMiddleware[StateT, ContextT] |
| 902 | +): |
| 903 | + """Decorator used to dynamically create a middleware with the after_agent hook. |
| 904 | +
|
| 905 | + Args: |
| 906 | + func: The function to be decorated. Must accept: |
| 907 | + `state: StateT, runtime: Runtime[ContextT]` - State and runtime context |
| 908 | + state_schema: Optional custom state schema type. If not provided, uses the default |
| 909 | + AgentState schema. |
| 910 | + tools: Optional list of additional tools to register with this middleware. |
| 911 | + can_jump_to: Optional list of valid jump destinations for conditional edges. |
| 912 | + Valid values are: "tools", "model", "end" |
| 913 | + name: Optional name for the generated middleware class. If not provided, |
| 914 | + uses the decorated function's name. |
| 915 | +
|
| 916 | + Returns: |
| 917 | + Either an AgentMiddleware instance (if func is provided) or a decorator function |
| 918 | + that can be applied to a function. |
| 919 | +
|
| 920 | + The decorated function should return: |
| 921 | + - `dict[str, Any]` - State updates to merge into the agent state |
| 922 | + - `Command` - A command to control flow (e.g., jump to different node) |
| 923 | + - `None` - No state updates or flow control |
| 924 | +
|
| 925 | + Examples: |
| 926 | + Basic usage for logging agent completion: |
| 927 | + ```python |
| 928 | + @after_agent |
| 929 | + def log_completion(state: AgentState, runtime: Runtime) -> None: |
| 930 | + print(f"Agent completed with {len(state['messages'])} messages") |
| 931 | + ``` |
| 932 | +
|
| 933 | + With custom state schema: |
| 934 | + ```python |
| 935 | + @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware") |
| 936 | + def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]: |
| 937 | + return {"custom_field": "finalized_value"} |
| 938 | + ``` |
| 939 | + """ |
| 940 | + |
| 941 | + def decorator( |
| 942 | + func: _CallableWithStateAndRuntime[StateT, ContextT], |
| 943 | + ) -> AgentMiddleware[StateT, ContextT]: |
| 944 | + is_async = iscoroutinefunction(func) |
| 945 | + # Extract can_jump_to from decorator parameter or from function metadata |
| 946 | + func_can_jump_to = ( |
| 947 | + can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", []) |
| 948 | + ) |
| 949 | + |
| 950 | + if is_async: |
| 951 | + |
| 952 | + async def async_wrapped( |
| 953 | + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 |
| 954 | + state: StateT, |
| 955 | + runtime: Runtime[ContextT], |
| 956 | + ) -> dict[str, Any] | Command | None: |
| 957 | + return await func(state, runtime) # type: ignore[misc] |
| 958 | + |
| 959 | + # Preserve can_jump_to metadata on the wrapped function |
| 960 | + if func_can_jump_to: |
| 961 | + async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] |
| 962 | + |
| 963 | + middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) |
| 964 | + |
| 965 | + return type( |
| 966 | + middleware_name, |
| 967 | + (AgentMiddleware,), |
| 968 | + { |
| 969 | + "state_schema": state_schema or AgentState, |
| 970 | + "tools": tools or [], |
| 971 | + "aafter_agent": async_wrapped, |
| 972 | + }, |
| 973 | + )() |
| 974 | + |
| 975 | + def wrapped( |
| 976 | + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 |
| 977 | + state: StateT, |
| 978 | + runtime: Runtime[ContextT], |
| 979 | + ) -> dict[str, Any] | Command | None: |
| 980 | + return func(state, runtime) # type: ignore[return-value] |
| 981 | + |
| 982 | + # Preserve can_jump_to metadata on the wrapped function |
| 983 | + if func_can_jump_to: |
| 984 | + wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined] |
| 985 | + |
| 986 | + # Use function name as default if no name provided |
| 987 | + middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware")) |
| 988 | + |
| 989 | + return type( |
| 990 | + middleware_name, |
| 991 | + (AgentMiddleware,), |
| 992 | + { |
| 993 | + "state_schema": state_schema or AgentState, |
| 994 | + "tools": tools or [], |
| 995 | + "after_agent": wrapped, |
| 996 | + }, |
| 997 | + )() |
| 998 | + |
| 999 | + if func is not None: |
| 1000 | + return decorator(func) |
| 1001 | + return decorator |
| 1002 | + |
| 1003 | + |
710 | 1004 | @overload
|
711 | 1005 | def dynamic_prompt(
|
712 | 1006 | func: _CallableReturningPromptString[StateT, ContextT],
|
|
0 commit comments