Skip to content

Commit 0b15112

Browse files
hanouticelina, Wauplin, and github-actions[bot]
authored and committed
[Inference] Correctly build chat completion URL with query parameters (huggingface#3200)
* fix base url parsing

* add comments

* add test case

* Apply suggestions from code review

Co-authored-by: Lucain <[email protected]>

* Apply style fixes

---------

Co-authored-by: Lucain <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 018156f commit 0b15112

File tree

4 files changed

+39
-14
lines changed

4 files changed

+39
-14
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,7 @@ class InferenceClient:
130130
or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
131131
automatically selected for the task.
132132
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
133-
arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
134-
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
135-
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
133+
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
136134
provider (`str`, *optional*):
137135
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"`, `"swarmind"` or `"together"`.
138136
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,7 @@ class AsyncInferenceClient:
118118
or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
119119
automatically selected for the task.
120120
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
121-
arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
122-
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
123-
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
121+
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
124122
provider (`str`, *optional*):
125123
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
126124
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import lru_cache
33
from pathlib import Path
44
from typing import Any, Dict, Optional, Union
5+
from urllib.parse import urlparse, urlunparse
56

67
from huggingface_hub import constants
78
from huggingface_hub.hf_api import InferenceProviderMapping
@@ -125,18 +126,25 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str:
125126

126127

127128
def _build_chat_completion_url(model_url: str) -> str:
128-
# Strip trailing /
129-
model_url = model_url.rstrip("/")
129+
parsed = urlparse(model_url)
130+
path = parsed.path.rstrip("/")
130131

131-
# Append /chat/completions if not already present
132-
if model_url.endswith("/v1"):
133-
model_url += "/chat/completions"
132+
# If the path already ends with /chat/completions, we're done!
133+
if path.endswith("/chat/completions"):
134+
return model_url
134135

136+
# Append /chat/completions if not already present
137+
if path.endswith("/v1"):
138+
new_path = path + "/chat/completions"
139+
# If path was empty or just "/", set the full path
140+
elif not path:
141+
new_path = "/v1/chat/completions"
135142
# Append /v1/chat/completions if not already present
136-
if not model_url.endswith("/chat/completions"):
137-
model_url += "/v1/chat/completions"
143+
else:
144+
new_path = path + "/v1/chat/completions"
138145

139-
return model_url
146+
# Reconstruct the URL with the new path and original query parameters.
147+
return urlunparse(parsed._replace(path=new_path))
140148

141149

142150
@lru_cache(maxsize=1)

tests/test_inference_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,27 @@ def test_chat_completion_error_in_stream():
10341034
f"{LOCAL_TGI_URL}/v1",
10351035
f"{LOCAL_TGI_URL}/v1/chat/completions",
10361036
),
1037+
# With query parameters
1038+
(
1039+
f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions?api-version=1",
1040+
f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions?api-version=1",
1041+
),
1042+
(
1043+
f"{INFERENCE_ENDPOINT_URL}/chat/completions?api-version=1",
1044+
f"{INFERENCE_ENDPOINT_URL}/chat/completions?api-version=1",
1045+
),
1046+
(
1047+
f"{INFERENCE_ENDPOINT_URL}?api-version=1",
1048+
f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions?api-version=1",
1049+
),
1050+
(
1051+
f"{INFERENCE_ENDPOINT_URL}/v1?api-version=1",
1052+
f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions?api-version=1",
1053+
),
1054+
(
1055+
f"{INFERENCE_ENDPOINT_URL}/?api-version=1",
1056+
f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions?api-version=1",
1057+
),
10371058
],
10381059
)
10391060
def test_resolve_chat_completion_url(model_url: str, expected_url: str):

0 commit comments

Comments
 (0)