2020
2121from elasticsearch import AsyncElasticsearch
2222from loguru import logger
23- from pydantic import BaseModel , ConfigDict , Field , PlainSerializer , ValidationError , computed_field
23+ from pydantic import (
24+ BaseModel ,
25+ ConfigDict ,
26+ Field ,
27+ PlainSerializer ,
28+ ValidationError ,
29+ computed_field ,
30+ )
2431
2532from rigging .error import MaxDepthError , UnknownToolError
2633from rigging .generator import GenerateParams , Generator , get_generator
3037from rigging .tool .api import ApiToolCall , ApiToolChoice
3138from rigging .tool .base import Tool , ToolMode
3239from rigging .tool .native import (
40+ TOOL_CALLS_TAG ,
3341 JsonInXmlToolCall ,
3442 JsonInXmlToolDefinition ,
3543 XmlToolCall ,
@@ -101,10 +109,17 @@ class Chat(BaseModel):
101109 params : GenerateParams | None = Field (None , exclude = True , repr = False )
102110 """Any additional generation params used for this chat."""
103111
104- error : t .Annotated [
105- BaseException ,
106- PlainSerializer (lambda x : str (x ), return_type = str , when_used = "json-unless-none" ),
107- ] | None = Field (None , repr = False )
112+ error : (
113+ t .Annotated [
114+ BaseException ,
115+ PlainSerializer (
116+ lambda x : str (x ),
117+ return_type = str ,
118+ when_used = "json-unless-none" ,
119+ ),
120+ ]
121+ | None
122+ ) = Field (None , repr = False )
108123 """Holds any exception that was caught during the generation pipeline."""
109124 failed : bool = Field (False , exclude = False , repr = True )
110125 """
@@ -199,7 +214,10 @@ def message_dicts(self) -> list[MessageDict]:
199214 The MessageDict list
200215 """
201216 return [
202- t .cast (MessageDict , m .model_dump (include = {"role" , "content_parts" }, exclude_none = True ))
217+ t .cast (
218+ MessageDict ,
219+ m .model_dump (include = {"role" , "content_parts" }, exclude_none = True ),
220+ )
203221 for m in self .all
204222 ]
205223
@@ -364,7 +382,11 @@ def inject_system_content(self, content: str) -> "Chat":
364382 self .messages [0 ].content += "\n \n " + content
365383 return self
366384
367- def inject_tool_prompt (self , tools : t .Sequence [Tool [..., t .Any ]], mode : ToolMode ) -> "Chat" :
385+ def inject_tool_prompt (
386+ self ,
387+ tools : t .Sequence [Tool [..., t .Any ]],
388+ mode : ToolMode ,
389+ ) -> "Chat" :
368390 """
369391 Injects a default tool use prompt into the system prompt.
370392
@@ -699,14 +721,19 @@ def __init__(
699721 self .tool_mode : ToolMode = "auto"
700722 self .api_tool_choice : ApiToolChoice | None = None
701723 self .inject_tool_prompt = True
724+ self .stop_on_tool_calls = True
702725 self .then_callbacks : list [tuple [ThenChatCallback , int ]] = []
703726 self .map_callbacks : list [tuple [MapChatCallback , int ]] = []
704727 self .watch_callbacks : list [WatchChatCallback ] = watch_callbacks or []
705728
706729 def __len__ (self ) -> int :
707730 return len (self .chat )
708731
709- def with_ (self , params : GenerateParams | None = None , ** kwargs : t .Any ) -> "ChatPipeline" :
732+ def with_ (
733+ self ,
734+ params : GenerateParams | None = None ,
735+ ** kwargs : t .Any ,
736+ ) -> "ChatPipeline" :
710737 """
711738 Assign specific generation parameter overloads for this chat.
712739
@@ -850,7 +877,12 @@ def fork(
850877 """
851878 return self .clone ().add (messages )
852879
853- def clone (self , * , only_messages : bool = False , chat : Chat | None = None ) -> "ChatPipeline" :
880+ def clone (
881+ self ,
882+ * ,
883+ only_messages : bool = False ,
884+ chat : Chat | None = None ,
885+ ) -> "ChatPipeline" :
854886 """
855887 Creates a clone of the current `ChatPipeline` instance.
856888
@@ -1036,7 +1068,10 @@ def apply_to_all(self, **kwargs: str) -> "ChatPipeline":
10361068 new .chat .apply_to_all (** kwargs )
10371069 return new
10381070
1039- def cache (self , mode : CacheMode | None | t .Literal [False ] = "latest" ) -> "ChatPipeline" :
1071+ def cache (
1072+ self ,
1073+ mode : CacheMode | None | t .Literal [False ] = "latest" ,
1074+ ) -> "ChatPipeline" :
10401075 """
10411076 Sets the caching mode for the pipeline.
10421077
@@ -1072,6 +1107,7 @@ def using(
10721107 mode : ToolMode | None = None ,
10731108 choice : ApiToolChoice | None = None ,
10741109 max_depth : int = DEFAULT_MAX_DEPTH ,
1110+ stop_on_tool_calls : bool | None = None ,
10751111 ) -> "ChatPipeline" :
10761112 """
10771113 Adds a tool or a sequence of tools to participate in the generation process.
@@ -1085,6 +1121,7 @@ def using(
10851121 mode: The tool calling mode to use (e.g., "xml", "json-in-xml", "api").
10861122 choice: The API tool choice to use. This is only relevant when using the "api" tool mode.
10871123 max_depth: The maximum depth for recursive tool calls (this is shared between all tools).
1124+ stop_on_tool_calls: When using natively parsed tools, whether to stop generation when a tool call block is observed.
10881125
10891126 Returns:
10901127 The updated pipeline.
@@ -1115,7 +1152,9 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
11151152 existing_names = {tool .name for tool in self .tools }
11161153 for tool in new_tools :
11171154 if tool .name in existing_names :
1118- raise ValueError (f"Tool with name '{ tool .name } ' already exists in the pipeline." )
1155+ raise ValueError (
1156+ f"Tool with name '{ tool .name } ' already exists in the pipeline." ,
1157+ )
11191158
11201159 self .tools += new_tools
11211160
@@ -1124,14 +1163,20 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
11241163 for callback , max_depth in self .then_callbacks
11251164 if callback != self ._then_tools
11261165 ]
1127- self .then_callbacks .insert (0 , (self ._then_tools , max_depth )) # make sure this is first
1166+ self .then_callbacks .insert (
1167+ 0 ,
1168+ (self ._then_tools , max_depth ),
1169+ ) # make sure this is first
11281170
11291171 if mode is not None :
11301172 self .tool_mode = mode
11311173
11321174 if choice is not None :
11331175 self .api_tool_choice = choice
11341176
1177+ if stop_on_tool_calls is not None :
1178+ self .stop_on_tool_calls = stop_on_tool_calls
1179+
11351180 return self
11361181
11371182 def until_parsed_as (
@@ -1193,7 +1238,30 @@ def until_parsed_as(
11931238 # Internal callbacks for handling tools and parsing
11941239
11951240 async def _then_tools (self , chat : Chat ) -> PipelineStepContextManager | None :
1196- tool_calls : list [ApiToolCall ] | list [XmlToolCall ] | list [JsonInXmlToolCall ] | None = None
1241+ if (
1242+ self .stop_on_tool_calls
1243+ and self .tool_mode in ["xml" , "json-in-xml" ]
1244+ and chat .stop_reason == "stop"
1245+ ):
1246+ # If we:
1247+ # 1. Are using native tools
1248+ # 2. Set a stop token for the tool calls
1249+ # 3. Hit that stop token
1250+ #
1251+ # Then we should re-inject the closing tag for completeness.
1252+
1253+ for part in chat .last .content_parts :
1254+ if (
1255+ part .type == "text"
1256+ and f"<{ TOOL_CALLS_TAG } >" in part .text
1257+ and f"</{ TOOL_CALLS_TAG } >" not in part .text
1258+ ):
1259+ part .text += f"</{ TOOL_CALLS_TAG } >"
1260+ break
1261+
1262+ # Parse the actual tool calls
1263+
1264+ tool_calls : (list [ApiToolCall ] | list [XmlToolCall ] | list [JsonInXmlToolCall ] | None ) = None
11971265 if self .tool_mode == "api" :
11981266 tool_calls = chat .last .tool_calls
11991267 if self .tool_mode == "xml" :
@@ -1265,9 +1333,15 @@ async def _pre_run(self) -> None:
12651333 if self .tool_mode == "auto" and self .tools :
12661334 self .tool_mode = "api" if await self .generator .supports_function_calling () else "xml"
12671335
1268- if self .tools and self .tool_mode in ["xml" , "json-in-xml" ] and self .inject_tool_prompt :
1269- self .chat .inject_tool_prompt (self .tools , self .tool_mode )
1270- self .inject_native_tool_prompt = False
1336+ if self .tools and self .tool_mode in ["xml" , "json-in-xml" ]:
1337+ if self .inject_tool_prompt :
1338+ self .chat .inject_tool_prompt (self .tools , self .tool_mode )
1339+ self .inject_native_tool_prompt = False
1340+
1341+ if self .stop_on_tool_calls :
1342+ self .params = self .params = GenerateParams ()
1343+ self .params .stop = self .params .stop or []
1344+ self .params .stop .append (f"</{ TOOL_CALLS_TAG } >" )
12711345
12721346 if self .tools and self .tool_mode == "api" :
12731347 if self .params is None :
@@ -1287,17 +1361,25 @@ def _fit_params(
12871361 params = [self .params .merge_with (p ) for p in params ]
12881362 return [(p or GenerateParams ()) for p in params ]
12891363
1290- def _apply_cache_mode_to_messages (self , messages : list [list [Message ]]) -> list [list [Message ]]:
1364+ def _apply_cache_mode_to_messages (
1365+ self ,
1366+ messages : list [list [Message ]],
1367+ ) -> list [list [Message ]]:
12911368 if self .caching is None :
12921369 return messages
12931370
12941371 if self .caching != "latest" :
1295- logger .warning (f"Unknown caching mode '{ self .caching } ', defaulting to 'latest'" )
1372+ logger .warning (
1373+ f"Unknown caching mode '{ self .caching } ', defaulting to 'latest'" ,
1374+ )
12961375
12971376 # first remove existing cache settings
12981377 updated : list [list [Message ]] = []
12991378 for _messages in messages :
1300- updated = [* updated , [m .clone ().cache (cache_control = False ) for m in _messages ]]
1379+ updated = [
1380+ * updated ,
1381+ [m .clone ().cache (cache_control = False ) for m in _messages ],
1382+ ]
13011383
13021384 # then apply the latest cache settings
13031385 for _messages in updated :
@@ -1463,15 +1545,6 @@ async def _step( # noqa: PLR0915, PLR0912
14631545
14641546 # Check if we should immediately raise
14651547
1466- # FailMode = t.Literal["raise", "skip", "include"]
1467- # self.on_failed: FailMode = "raise"
1468- # """How to handle failures in the pipeline unless overriden in calls."""
1469-
1470- # self.errors_to_catch: set[type[Exception]] = {MaxDepthError, ValidationError}
1471- # """The list of exceptions to catch during generation if you are including or skipping failures."""
1472- # self.errors_to_exclude: set[type[Exception]] = set()
1473- # """The list of exceptions to exclude from the catch list."""
1474-
14751548 for chat in chats :
14761549 if chat .error is not None and (
14771550 on_failed == "raise"
@@ -1532,7 +1605,11 @@ async def _step( # noqa: PLR0915, PLR0912
15321605 step = state .step .with_parent (current_step )
15331606
15341607 if step .depth > max_depth :
1535- max_depth_error = MaxDepthError (max_depth , step , callback_name )
1608+ max_depth_error = MaxDepthError (
1609+ max_depth ,
1610+ step ,
1611+ callback_name ,
1612+ )
15361613 if on_failed == "raise" :
15371614 raise max_depth_error
15381615
@@ -1597,12 +1674,18 @@ async def _step( # noqa: PLR0915, PLR0912
15971674 if inspect .isasyncgen (chats_or_generator ):
15981675 generator = t .cast (
15991676 PipelineStepGenerator ,
1600- await exit_stack .enter_async_context (aclosing (chats_or_generator )),
1677+ await exit_stack .enter_async_context (
1678+ aclosing (chats_or_generator ),
1679+ ),
16011680 )
16021681 async for step in generator :
16031682 _step = step .with_parent (current_step )
16041683 if _step .depth > max_depth :
1605- max_depth_error = MaxDepthError (max_depth , _step , callback_name )
1684+ max_depth_error = MaxDepthError (
1685+ max_depth ,
1686+ _step ,
1687+ callback_name ,
1688+ )
16061689 if on_failed == "raise" :
16071690 raise max_depth_error
16081691
@@ -1676,10 +1759,17 @@ async def step(
16761759 generator_id = self .generator .to_identifier (),
16771760 params = self .params .to_dict () if self .params is not None else {},
16781761 ) as span :
1679- async with aclosing (self ._step (span , messages , params , on_failed )) as generator :
1762+ async with aclosing (
1763+ self ._step (span , messages , params , on_failed ),
1764+ ) as generator :
16801765 yield generator
16811766
1682- async def run (self , * , on_failed : FailMode | None = None , allow_failed : bool = False ) -> Chat :
1767+ async def run (
1768+ self ,
1769+ * ,
1770+ on_failed : FailMode | None = None ,
1771+ allow_failed : bool = False ,
1772+ ) -> Chat :
16831773 """
16841774 Execute the generation process for a single message.
16851775
@@ -1749,7 +1839,9 @@ async def step_many(
17491839 generator_id = self .generator .to_identifier (),
17501840 params = self .params .to_dict () if self .params is not None else {},
17511841 ) as span :
1752- async with aclosing (self ._step (span , messages , params , on_failed )) as generator :
1842+ async with aclosing (
1843+ self ._step (span , messages , params , on_failed ),
1844+ ) as generator :
17531845 yield generator
17541846
17551847 async def run_many (
@@ -1827,7 +1919,9 @@ async def step_batch(
18271919 messages = [[* self .chat .all , * Message .fit_as_list (m )] for m in many ]
18281920 if len (messages ) < count :
18291921 if len (messages ) != 1 :
1830- raise ValueError (f"Can't fit { len (messages )} messages to { count } params" )
1922+ raise ValueError (
1923+ f"Can't fit { len (messages )} messages to { count } params" ,
1924+ )
18311925 messages = messages * count
18321926
18331927 params = self ._fit_params (count , params )
@@ -1838,7 +1932,9 @@ async def step_batch(
18381932 generator_id = self .generator .to_identifier (),
18391933 params = self .params .to_dict () if self .params is not None else {},
18401934 ) as span :
1841- async with aclosing (self ._step (span , messages , params , on_failed )) as generator :
1935+ async with aclosing (
1936+ self ._step (span , messages , params , on_failed ),
1937+ ) as generator :
18421938 yield generator
18431939
18441940 async def run_batch (
0 commit comments