1+ from collections .abc import Callable , Sequence
2+ from typing import TYPE_CHECKING , Any
3+
14from langchain_core .callbacks .manager import CallbackManagerForLLMRun
25from langchain_core .language_models .chat_models import BaseChatModel
36from langchain_core .messages import AIMessage , BaseMessage
47from langchain_core .outputs import ChatGeneration , ChatResult
8+ from langchain_core .tools import BaseTool
59from langgraph .checkpoint .memory import MemorySaver
610from langgraph .prebuilt import create_react_agent
711
812from langgraph_swarm import create_handoff_tool , create_swarm
913
14+ if TYPE_CHECKING :
15+ from langchain_core .runnables .config import RunnableConfig
16+
1017
1118class FakeChatModel (BaseChatModel ):
1219 idx : int = 0
@@ -21,13 +28,19 @@ def _generate(
2128 messages : list [BaseMessage ],
2229 stop : list [str ] | None = None ,
2330 run_manager : CallbackManagerForLLMRun | None = None ,
24- ** kwargs ,
31+ ** kwargs : Any ,
2532 ) -> ChatResult :
2633 generation = ChatGeneration (message = self .responses [self .idx ])
2734 self .idx += 1
2835 return ChatResult (generations = [generation ])
2936
30- def bind_tools (self , tools : list [any ]) -> "FakeChatModel" :
37+ def bind_tools (
38+ self ,
39+ tools : Sequence [dict [str , Any ] | type | Callable [..., Any ] | BaseTool ],
40+ * ,
41+ tool_choice : str | None = None ,
42+ ** kwargs : Any ,
43+ ) -> "FakeChatModel" :
3144 return self
3245
3346
@@ -80,7 +93,7 @@ def test_basic_swarm() -> None:
8093 ),
8194 ]
8295
83- model = FakeChatModel (responses = recorded_messages )
96+ model = FakeChatModel (responses = recorded_messages ) # type: ignore[arg-type]
8497
8598 def add (a : int , b : int ) -> int :
8699 """Add two numbers."""
@@ -109,9 +122,11 @@ def add(a: int, b: int) -> int:
109122 workflow = create_swarm ([alice , bob ], default_active_agent = "Alice" )
110123 app = workflow .compile (checkpointer = checkpointer )
111124
112- config = {"configurable" : {"thread_id" : "1" }}
125+ config : RunnableConfig = {"configurable" : {"thread_id" : "1" }}
113126 turn_1 = app .invoke (
114- {"messages" : [{"role" : "user" , "content" : "i'd like to speak to Bob" }]},
127+ { # type: ignore[arg-type]
128+ "messages" : [{"role" : "user" , "content" : "i'd like to speak to Bob" }]
129+ },
115130 config ,
116131 )
117132
@@ -122,7 +137,9 @@ def add(a: int, b: int) -> int:
122137 assert turn_1 ["active_agent" ] == "Bob"
123138
124139 turn_2 = app .invoke (
125- {"messages" : [{"role" : "user" , "content" : "what's 5 + 7?" }]},
140+ { # type: ignore[arg-type]
141+ "messages" : [{"role" : "user" , "content" : "what's 5 + 7?" }]
142+ },
126143 config ,
127144 )
128145
0 commit comments