@@ -720,7 +720,7 @@ def __init__(
720720 self .tool_mode : ToolMode = "auto"
721721 self .api_tool_choice : ApiToolChoice | None = None
722722 self .inject_tool_prompt = True
723- self .stop_on_tool_calls = True
723+ self .add_tool_stop_token = True
724724 self .then_callbacks : list [tuple [ThenChatCallback , int ]] = []
725725 self .map_callbacks : list [tuple [MapChatCallback , int ]] = []
726726 self .watch_callbacks : list [WatchChatCallback ] = watch_callbacks or []
@@ -880,6 +880,7 @@ def clone(
880880 * ,
881881 only_messages : bool = False ,
882882 chat : Chat | None = None ,
883+ callbacks : bool | t .Sequence [MapChatCallback | ThenChatCallback ] = True ,
883884 ) -> "ChatPipeline" :
884885 """
885886 Creates a clone of the current `ChatPipeline` instance.
@@ -890,6 +891,8 @@ def clone(
890891 including until callbacks, types, tools, metadata, etc.
891892 chat: An optional chat object clone for use in the new pipeline, otherwise the current
892893 internal chat object will be cloned.
894+ callbacks: If True (default), all callbacks will be cloned. If False, no callbacks will be cloned.
895+ Otherwise provide a sequence of callbacks which should be maintained in the new pipeline.
893896
894897 Returns:
895898 The cloned ChatPipeline.
@@ -906,16 +909,20 @@ def clone(
906909 new .tools = self .tools .copy ()
907910 new .tool_mode = self .tool_mode
908911 new .metadata = deepcopy (self .metadata )
909- new .map_callbacks = self .map_callbacks .copy ()
910912 new .on_failed = self .on_failed
911913 new .errors_to_catch = self .errors_to_catch .copy ()
912914 new .errors_to_exclude = self .errors_to_exclude .copy ()
913915 new .caching = self .caching
914916
917+ new .watch_callbacks = self .watch_callbacks .copy ()
918+
915919 # Check if any of our callbacks are bound methods to a ChatPipeline.
916920 # If so, we should rebind them to `self` to ensure they work correctly
917921 # and aren't operating with old state.
918922
923+ if callbacks is False :
924+ return new
925+
919926 new .then_callbacks = [
920927 (callback , max_depth )
921928 if not hasattr (callback , "__self__" )
@@ -931,6 +938,18 @@ def clone(
931938 for callback , max_depth in self .map_callbacks .copy ()
932939 ]
933940
941+ if not isinstance (callbacks , bool ):
942+ new .then_callbacks = [
943+ (callback , max_depth )
944+ for callback , max_depth in self .then_callbacks
945+ if callback in callbacks
946+ ]
947+ new .map_callbacks = [
948+ (callback , max_depth )
949+ for callback , max_depth in self .map_callbacks
950+ if callback in callbacks
951+ ]
952+
934953 return new
935954
936955 def meta (self , ** kwargs : t .Any ) -> "ChatPipeline" :
@@ -1105,7 +1124,7 @@ def using(
11051124 mode : ToolMode | None = None ,
11061125 choice : ApiToolChoice | None = None ,
11071126 max_depth : int = DEFAULT_MAX_DEPTH ,
1108- stop_on_tool_calls : bool | None = None ,
1127+ add_stop_token : bool | None = None ,
11091128 ) -> "ChatPipeline" :
11101129 """
11111130 Adds a tool or a sequence of tools to participate in the generation process.
@@ -1119,7 +1138,8 @@ def using(
11191138 mode: The tool calling mode to use (e.g., "xml", "json-in-xml", "api").
11201139 choice: The API tool choice to use. This is only relevant when using the "api" tool mode.
11211140 max_depth: The maximum depth for recursive tool calls (this is shared between all tools).
1122- stop_on_tool_calls: When using natively parsed tools, whether to stop generation when a tool call block is observed.
1141+ add_stop_token: When using natively parsed tools ("xml", "json-in-xml"), use stop tokens to
1142+ immediately process a tool call when observed.
11231143
11241144 Returns:
11251145 The updated pipeline.
@@ -1172,8 +1192,8 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
11721192 if choice is not None :
11731193 self .api_tool_choice = choice
11741194
1175- if stop_on_tool_calls is not None :
1176- self .stop_on_tool_calls = stop_on_tool_calls
1195+ if add_stop_token is not None :
1196+ self .add_tool_stop_token = add_stop_token
11771197
11781198 return self
11791199
@@ -1237,7 +1257,7 @@ def until_parsed_as(
12371257
12381258 async def _then_tools (self , chat : Chat ) -> PipelineStepContextManager | None :
12391259 if (
1240- self .stop_on_tool_calls
1260+ self .add_tool_stop_token
12411261 and self .tool_mode in ["xml" , "json-in-xml" ]
12421262 and chat .stop_reason == "stop"
12431263 ):
@@ -1270,54 +1290,49 @@ async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
12701290 if not tool_calls :
12711291 return None
12721292
1273- next_pipeline = self .clone (chat = chat )
1293+ next_pipeline = self .clone (chat = chat , callbacks = [ self . _then_tools ] )
12741294
1275- should_continue = True
1295+ stop = False
12761296
12771297 for tool_call in tool_calls :
12781298 tool = next ((t for t in self .tools if t .name == tool_call .name ), None )
12791299 if tool is None :
12801300 raise UnknownToolError (tool_call .name )
12811301
1282- message , _should_continue = await tool .handle_tool_call (tool_call )
1302+ message , _stop = await tool .handle_tool_call (tool_call )
1303+ stop = _stop if not _stop else stop
12831304 next_pipeline .add (message )
12841305
1285- # If the tool returns none, we should resolve tool calls, but
1286- # not continue the pipeline.
1287-
1288- if not _should_continue :
1289- should_continue = _should_continue
1290-
12911306 # Need to prevent infinite loops and treat tool_choice like
12921307 # an ephemeral setting which resets after the first tool call.
12931308
12941309 if self .tool_mode == "api" and next_pipeline .params :
12951310 next_pipeline .params .tool_choice = None
12961311
1297- if not should_continue :
1312+ if stop :
12981313 # TODO(nick): Type hints here stop us from mixing step generators
12991314 # and basic chat returns.
13001315 return next_pipeline .chat # type: ignore [return-value]
13011316
13021317 return next_pipeline .step ()
13031318
13041319 async def _then_parse (self , chat : Chat ) -> PipelineStepContextManager | None :
1305- next_pipeline = self .clone (chat = chat )
1320+ next_pipeline = self .clone (chat = chat , callbacks = [ self . _then_parse ] )
13061321
13071322 try :
13081323 chat .last .parse_many (* self .until_types )
13091324 except ValidationError as e :
13101325 next_pipeline .add (
13111326 Message .from_model (
13121327 ValidationErrorModel (content = str (e )),
1313- suffix = "Rewrite your entire message with all the required xml structure ." ,
1328+ suffix = "Rewrite your entire message with all of the required xml elements ." ,
13141329 ),
13151330 )
13161331 except Exception as e : # noqa: BLE001
13171332 next_pipeline .add (
13181333 Message .from_model (
13191334 SystemErrorModel (content = str (e )),
1320- suffix = "Rewrite your entire message with all the required xml structure ." ,
1335+ suffix = "Rewrite your entire message with all of the required xml elements ." ,
13211336 ),
13221337 )
13231338 else : # parsed successfully
@@ -1336,7 +1351,7 @@ async def _pre_run(self) -> None:
13361351 self .chat .inject_tool_prompt (self .tools , self .tool_mode )
13371352 self .inject_native_tool_prompt = False
13381353
1339- if self .stop_on_tool_calls :
1354+ if self .add_tool_stop_token :
13401355 self .params = self .params = GenerateParams ()
13411356 self .params .stop = self .params .stop or []
13421357 self .params .stop .append (f"</{ TOOL_CALLS_TAG } >" )
0 commit comments