Skip to content

Commit f0f99dd

Browse files
committed
Set self.prev_tool_call_arr after generating the first DeltaToolCall
Signed-off-by: avigny <[email protected]>
1 parent 4d63b38 commit f0f99dd

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

tests/tool_use/test_mistral_tool_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ def _test_extract_tool_calls_streaming(tool_parser, tokenizer, model_output,
302302
assert len(streamed_tool_calls) == 1
303303
tool_call = streamed_tool_calls[0]
304304

305+
assert len(tool_parser.prev_tool_call_arr) > 0
306+
305307
# if a new tool is being called, set up empty arguments
306308
if tool_call.index != tool_call_idx:
307309
tool_call_idx = tool_call.index

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,13 @@ def _extract_tool_calls_streaming(
245245
content=additional_content,
246246
tool_calls=delta_tool_calls,
247247
)
248+
# HACK: serving_chat.py inspects the internal state of tool parsers
249+
# when determining its final streaming delta, automatically
250+
# adding autocompleted JSON.
251+
# These two lines avoid that nonsense while ensuring finish_reason
252+
# is set to tool_calls when at least one tool is called.
253+
if delta and not self.prev_tool_call_arr:
254+
self.prev_tool_call_arr = [{"arguments": {}}]
248255
return delta
249256

250257
def _generate_delta_tool_call(self,
@@ -476,6 +483,15 @@ def _extract_tool_calls_streaming_pre_v11_tokenizer(
476483

477484
if current_tool_call_modified:
478485
delta_tool_calls.append(current_tool_call)
486+
487+
# HACK: serving_chat.py inspects the internal state of tool parsers
488+
# when determining its final streaming delta, automatically
489+
# adding autocompleted JSON.
490+
# These two lines avoid that nonsense while ensuring finish_reason
491+
# is set to tool_calls when at least one tool is called.
492+
if delta_tool_calls and not self.prev_tool_call_arr:
493+
self.prev_tool_call_arr = [{"arguments": {}}]
494+
479495
return DeltaMessage(content=content, tool_calls=delta_tool_calls)
480496

481497
def _split_delta(

0 commit comments

Comments
 (0)