From 050ecbde8fbd6fa0d04613d9db1a15c6407f8eab Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 11:05:24 +0100 Subject: [PATCH 1/2] fix ollama --- .../inference/_mcp/mcp_client.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/huggingface_hub/inference/_mcp/mcp_client.py b/src/huggingface_hub/inference/_mcp/mcp_client.py index 7d7744574e..69383e2d68 100644 --- a/src/huggingface_hub/inference/_mcp/mcp_client.py +++ b/src/huggingface_hub/inference/_mcp/mcp_client.py @@ -3,22 +3,20 @@ from contextlib import AsyncExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload +from typing import (TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, + Optional, Union, overload) from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack from ...utils._runtime import get_hf_hub_version from .._generated._async_client import AsyncInferenceClient -from .._generated.types import ( - ChatCompletionInputMessage, - ChatCompletionInputTool, - ChatCompletionStreamOutput, - ChatCompletionStreamOutputDeltaToolCall, -) +from .._generated.types import (ChatCompletionInputMessage, + ChatCompletionInputTool, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputDeltaToolCall) from .._providers import PROVIDER_OR_POLICY_T from .utils import format_result - if TYPE_CHECKING: from mcp import ClientSession @@ -286,16 +284,19 @@ async def process_single_turn_with_tools( # Process tool calls if delta.tool_calls: for tool_call in delta.tool_calls: + idx = tool_call.index # Aggregate chunks into tool calls - if tool_call.index not in final_tool_calls: - if ( - tool_call.function.arguments is None or tool_call.function.arguments == "{}" - ): # Corner case (depends on provider) - tool_call.function.arguments = "" - final_tool_calls[tool_call.index] = tool_call - - elif tool_call.function.arguments: - final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments + if idx not in final_tool_calls: + final_tool_calls[idx] = tool_call + if final_tool_calls[idx].function.arguments is None: + final_tool_calls[idx].function.arguments = "" + continue + + if final_tool_calls[idx].function.arguments is None: + final_tool_calls[idx].function.arguments = "" + + if tool_call.function.arguments: + final_tool_calls[idx].function.arguments += tool_call.function.arguments # Optionally exit early if no tools in first chunks if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0: From 41011d22a68bef330b6393354741ff5decc77e1c Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 11:13:11 +0100 Subject: [PATCH 2/2] style --- .../inference/_mcp/mcp_client.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/huggingface_hub/inference/_mcp/mcp_client.py b/src/huggingface_hub/inference/_mcp/mcp_client.py index 69383e2d68..a836b7a36c 100644 --- a/src/huggingface_hub/inference/_mcp/mcp_client.py +++ b/src/huggingface_hub/inference/_mcp/mcp_client.py @@ -3,20 +3,22 @@ from contextlib import AsyncExitStack from datetime import timedelta from pathlib import Path -from typing import (TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, - Optional, Union, overload) +from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack from ...utils._runtime import get_hf_hub_version from .._generated._async_client import AsyncInferenceClient -from .._generated.types import (ChatCompletionInputMessage, - ChatCompletionInputTool, - ChatCompletionStreamOutput, - ChatCompletionStreamOutputDeltaToolCall) +from .._generated.types import ( + ChatCompletionInputMessage, + ChatCompletionInputTool, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputDeltaToolCall, +) from .._providers import PROVIDER_OR_POLICY_T from .utils import format_result + if TYPE_CHECKING: from mcp import ClientSession @@ -285,16 +287,16 @@ async def process_single_turn_with_tools( if delta.tool_calls: for tool_call in delta.tool_calls: idx = tool_call.index - # Aggregate chunks into tool calls + # first chunk for this tool call if idx not in final_tool_calls: final_tool_calls[idx] = tool_call if final_tool_calls[idx].function.arguments is None: final_tool_calls[idx].function.arguments = "" continue - + # safety before concatenating text to .function.arguments if final_tool_calls[idx].function.arguments is None: final_tool_calls[idx].function.arguments = "" - + if tool_call.function.arguments: final_tool_calls[idx].function.arguments += tool_call.function.arguments