Skip to content

Commit c50be28

Browse files
Fix for chat stream tool choice (#51)
* Partially revert changes from PR #50
* Improve the condition for setting `name` in `tool_call_chunk`
* Fix mypy errors
1 parent 54d67d2 commit c50be28

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

libs/ibm/langchain_ibm/chat_models.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -274,6 +274,7 @@ def _convert_delta_to_message_chunk(
274274
_dict: Mapping[str, Any],
275275
default_class: Type[BaseMessageChunk],
276276
call_id: str,
277+
is_first_tool_chunk: bool,
277278
) -> BaseMessageChunk:
278279
id_ = call_id
279280
role = cast(str, _dict.get("role"))
@@ -290,8 +291,12 @@ def _convert_delta_to_message_chunk(
290291
try:
291292
tool_call_chunks = [
292293
tool_call_chunk(
293-
name=rtc["function"].get("name"),
294+
name=rtc["function"].get("name")
295+
if is_first_tool_chunk or (rtc.get("id") is not None)
296+
else None,
294297
args=rtc["function"].get("arguments"),
298+
# `id` is provided only for the first delta with unique tool_calls
299+
# (multiple tool calls scenario)
295300
id=rtc.get("id"),
296301
index=rtc["index"],
297302
)
@@ -328,6 +333,7 @@ def _convert_chunk_to_generation_chunk(
328333
default_chunk_class: Type,
329334
base_generation_info: Optional[Dict],
330335
is_first_chunk: bool,
336+
is_first_tool_chunk: bool,
331337
) -> Optional[ChatGenerationChunk]:
332338
token_usage = chunk.get("usage")
333339
choices = chunk.get("choices", [])
@@ -348,7 +354,7 @@ def _convert_chunk_to_generation_chunk(
348354
return None
349355

350356
message_chunk = _convert_delta_to_message_chunk(
351-
choice["delta"], default_chunk_class, chunk["id"]
357+
choice["delta"], default_chunk_class, chunk["id"], is_first_tool_chunk
352358
)
353359
generation_info = {**base_generation_info} if base_generation_info else {}
354360

@@ -722,6 +728,7 @@ def _stream(
722728
base_generation_info: dict = {}
723729

724730
is_first_chunk = True
731+
is_first_tool_chunk = True
725732

726733
for chunk in self.watsonx_model.chat_stream(
727734
messages=message_dicts, **(kwargs | {"params": updated_params})
@@ -733,6 +740,7 @@ def _stream(
733740
default_chunk_class,
734741
base_generation_info if is_first_chunk else {},
735742
is_first_chunk,
743+
is_first_tool_chunk,
736744
)
737745
if generation_chunk is None:
738746
continue
@@ -742,6 +750,16 @@ def _stream(
742750
run_manager.on_llm_new_token(
743751
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
744752
)
753+
if hasattr(generation_chunk.message, "tool_calls") and isinstance(
754+
generation_chunk.message.tool_calls, list
755+
):
756+
first_tool_call = (
757+
generation_chunk.message.tool_calls[0]
758+
if generation_chunk.message.tool_calls
759+
else None
760+
)
761+
if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
762+
is_first_tool_chunk = False
745763

746764
is_first_chunk = False
747765

libs/ibm/tests/integration_tests/test_chat_models.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -667,7 +667,7 @@ def get_weather(city: Literal["nyc"]) -> str:
667667
if ai_message is None:
668668
ai_message = chunk
669669
else:
670-
ai_message += chunk
670+
ai_message += chunk # type: ignore[assignment]
671671
print(chunk.id, type(chunk.id))
672672
assert isinstance(chunk, AIMessageChunk)
673673
assert chunk.content == ""

libs/ibm/tests/integration_tests/test_chat_models_standard.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ def chat_model_params(self) -> dict:
6060
"url": URL,
6161
"apikey": WX_APIKEY,
6262
"project_id": WX_PROJECT_ID,
63+
"temperature": 0,
6364
}
6465

6566
@pytest.mark.xfail(reason="Supported for vision model.")

0 commit comments

Comments (0)