Commit a49ca75

Fix resolve chat completion URL (#2540)
* Add tests for test_resolve_chat_completion_url
* Fix passing chat completion url
1 parent 3cd3286 commit a49ca75

File tree

5 files changed: +161 -64 lines changed


src/huggingface_hub/hub_mixin.py

Lines changed: 2 additions & 2 deletions
@@ -829,7 +829,7 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b
 
     @classmethod
     def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
-        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
+        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):  # type: ignore [attr-defined]
             load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
             if map_location != "cpu":
                 logger.warning(
@@ -840,7 +840,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric
                 )
                 model.to(map_location)  # type: ignore [attr-defined]
         else:
-            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
+            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
         return model
 
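For context, the gate that both new `# type: ignore` hints attach to is a plain version comparison. Below is a minimal standalone sketch of that check (only the `packaging` dependency is assumed; `supports_device_loading` is a hypothetical helper, and the "0.4.3" cutoff mirrors the literal in the diff, presumably the first safetensors release whose `load_model` accepts `device=`):

```python
# Minimal sketch of the version gate above; not part of huggingface_hub itself.
from packaging import version

def supports_device_loading(safetensors_version: str) -> bool:
    """True if this safetensors release can load weights directly onto a device."""
    return version.parse(safetensors_version) >= version.parse("0.4.3")

print(supports_device_loading("0.4.2"))  # False -> load on CPU, then model.to(device)
print(supports_device_loading("0.4.3"))  # True  -> safetensors.torch.load_model(..., device=...)
```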
src/huggingface_hub/inference/_client.py

Lines changed: 26 additions & 20 deletions
@@ -810,26 +810,7 @@ def chat_completion(
             '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
         """
-        # Determine model
-        # `self.xxx` takes precedence over the method argument only in `chat_completion`
-        # since `chat_completion(..., model=xxx)` is also a payload parameter for the
-        # server, we need to handle it differently
-        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-        is_url = model_id_or_url.startswith(("http://", "https://"))
-
-        # First, resolve the model chat completions URL
-        if model_id_or_url == self.base_url:
-            # base_url passed => add server route
-            model_url = model_id_or_url.rstrip("/")
-            if not model_url.endswith("/v1"):
-                model_url += "/v1"
-            model_url += "/chat/completions"
-        elif is_url:
-            # model is a URL => use it directly
-            model_url = model_id_or_url
-        else:
-            # model is a model ID => resolve it + add server route
-            model_url = self._resolve_url(model_id_or_url).rstrip("/") + "/v1/chat/completions"
+        model_url = self._resolve_chat_completion_url(model)
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
         # If it's an ID on the Hub => use it. Otherwise, we use a random string.
@@ -865,6 +846,31 @@ def chat_completion(
 
         return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
 
+    def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
+        # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+        # `self.base_url` and `self.model` take precedence over the 'model' argument only in `chat_completion`.
+        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+
+        # Resolve URL if it's a model ID
+        model_url = (
+            model_id_or_url
+            if model_id_or_url.startswith(("http://", "https://"))
+            else self._resolve_url(model_id_or_url, task="text-generation")
+        )
+
+        # Strip trailing '/'
+        model_url = model_url.rstrip("/")
+
+        # Append '/chat/completions' if the URL already ends with '/v1'
+        if model_url.endswith("/v1"):
+            model_url += "/chat/completions"
+
+        # Otherwise, append '/v1/chat/completions' if not already present
+        if not model_url.endswith("/chat/completions"):
+            model_url += "/v1/chat/completions"
+
+        return model_url
+
     def document_question_answering(
         self,
         image: ContentT,
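The routing rules of the new helper can be read off in isolation. Here is a minimal standalone sketch (a hypothetical free function that mirrors only the URL-suffix logic, skipping the model-ID resolution the real method delegates to `self._resolve_url`):

```python
# Hypothetical standalone version of the suffix logic in `_resolve_chat_completion_url`.
def resolve_chat_completion_url(url: str) -> str:
    url = url.rstrip("/")                      # tolerate trailing '/'
    if url.endswith("/v1"):                    # OpenAI-style base URL
        url += "/chat/completions"
    if not url.endswith("/chat/completions"):  # bare server URL
        url += "/v1/chat/completions"
    return url

# All spellings converge on the same route:
for u in ("http://0.0.0.0:8080", "http://0.0.0.0:8080/v1/", "http://0.0.0.0:8080/v1/chat/completions"):
    assert resolve_chat_completion_url(u) == "http://0.0.0.0:8080/v1/chat/completions"
```

The suffix is appended at most once, so URLs that already end in `/chat/completions` pass through unchanged.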

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 26 additions & 20 deletions
@@ -850,26 +850,7 @@ async def chat_completion(
             '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
         """
-        # Determine model
-        # `self.xxx` takes precedence over the method argument only in `chat_completion`
-        # since `chat_completion(..., model=xxx)` is also a payload parameter for the
-        # server, we need to handle it differently
-        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-        is_url = model_id_or_url.startswith(("http://", "https://"))
-
-        # First, resolve the model chat completions URL
-        if model_id_or_url == self.base_url:
-            # base_url passed => add server route
-            model_url = model_id_or_url.rstrip("/")
-            if not model_url.endswith("/v1"):
-                model_url += "/v1"
-            model_url += "/chat/completions"
-        elif is_url:
-            # model is a URL => use it directly
-            model_url = model_id_or_url
-        else:
-            # model is a model ID => resolve it + add server route
-            model_url = self._resolve_url(model_id_or_url).rstrip("/") + "/v1/chat/completions"
+        model_url = self._resolve_chat_completion_url(model)
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
         # If it's an ID on the Hub => use it. Otherwise, we use a random string.
@@ -905,6 +886,31 @@ async def chat_completion(
 
         return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
 
+    def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
+        # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+        # `self.base_url` and `self.model` take precedence over the 'model' argument only in `chat_completion`.
+        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+
+        # Resolve URL if it's a model ID
+        model_url = (
+            model_id_or_url
+            if model_id_or_url.startswith(("http://", "https://"))
+            else self._resolve_url(model_id_or_url, task="text-generation")
+        )
+
+        # Strip trailing '/'
+        model_url = model_url.rstrip("/")
+
+        # Append '/chat/completions' if the URL already ends with '/v1'
+        if model_url.endswith("/v1"):
+            model_url += "/chat/completions"
+
+        # Otherwise, append '/v1/chat/completions' if not already present
+        if not model_url.endswith("/chat/completions"):
+            model_url += "/v1/chat/completions"
+
+        return model_url
+
     async def document_question_answering(
         self,
         image: ContentT,
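Since the async client now shares exactly the same resolution logic, a short usage sketch (the server URL, model, and messages are illustrative; this assumes a local TGI-compatible server listening on the example address):

```python
# Illustrative usage; http://0.0.0.0:8080 is a placeholder for a local TGI server.
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    # With this fix, base_url may be passed with or without '/v1';
    # '/v1/chat/completions' is appended exactly once either way.
    client = AsyncInferenceClient(base_url="http://0.0.0.0:8080/v1")
    response = await client.chat_completion(
        messages=[{"role": "user", "content": "Say hello"}],
        max_tokens=16,
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```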

src/huggingface_hub/utils/_http.py

Lines changed: 1 addition & 1 deletion
@@ -473,7 +473,7 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None)
 
         # Convert `HTTPError` into a `HfHubHTTPError` to display request information
        # as well (request id and/or server error message)
-        raise _format(HfHubHTTPError, "", response) from e
+        raise _format(HfHubHTTPError, str(e), response) from e
 
 
 def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError:
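This one-liner stops discarding the original `HTTPError` text when re-raising. The pattern, sketched with a toy wrapper class (not the library's real `_format`):

```python
# Toy sketch of the re-raise pattern; WrappedError stands in for HfHubHTTPError.
import requests

class WrappedError(Exception):
    pass

def raise_for_status(response: requests.Response) -> None:
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # Before the fix: WrappedError("") lost the message.
        # After: str(e) keeps e.g. "404 Client Error: Not Found for url: ...".
        raise WrappedError(str(e)) from e
```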

tests/test_inference_client.py

Lines changed: 106 additions & 21 deletions
@@ -16,6 +16,7 @@
 import time
 import unittest
 from pathlib import Path
+from typing import Optional
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -918,27 +919,6 @@ def test_model_and_base_url_mutually_exclusive(self):
             InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")
 
 
-@pytest.mark.parametrize(
-    "base_url",
-    [
-        "http://0.0.0.0:8080/v1",  # expected from OpenAI client
-        "http://0.0.0.0:8080",  # but not mandatory
-        "http://0.0.0.0:8080/v1/",  # ok with trailing '/' as well
-        "http://0.0.0.0:8080/",  # ok with trailing '/' as well
-    ],
-)
-def test_chat_completion_base_url_works_with_v1(base_url: str):
-    """Test that `/v1/chat/completions` is correctly appended to the base URL.
-
-    This is a regression test for https://github.com/huggingface/huggingface_hub/issues/2414
-    """
-    with patch("huggingface_hub.inference._client.InferenceClient.post") as post_mock:
-        client = InferenceClient(base_url=base_url)
-        post_mock.return_value = "{}"
-        client.chat_completion(messages=CHAT_COMPLETION_MESSAGES, stream=False)
-        assert post_mock.call_args_list[0].kwargs["model"] == "http://0.0.0.0:8080/v1/chat/completions"
-
-
 @pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
 def test_stream_text_generation_response(stop_signal: bytes):
     data = [
@@ -970,3 +950,108 @@ def test_stream_chat_completion_response(stop_signal: bytes):
     assert len(output) == 2
     assert output[0].choices[0].delta.content == "Both"
     assert output[1].choices[0].delta.content == " Rust"
+
+
+INFERENCE_API_URL = "https://api-inference.huggingface.co/models"
+INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud"  # example
+LOCAL_TGI_URL = "http://0.0.0.0:8080"
+
+
+@pytest.mark.parametrize(
+    ("client_model", "client_base_url", "model", "expected_url"),
+    [
+        (
+            # Inference API - model global to client
+            "username/repo_name",
+            None,
+            None,
+            f"{INFERENCE_API_URL}/username/repo_name/v1/chat/completions",
+        ),
+        (
+            # Inference API - model specific to request
+            None,
+            None,
+            "username/repo_name",
+            f"{INFERENCE_API_URL}/username/repo_name/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url global to client as 'model'
+            INFERENCE_ENDPOINT_URL,
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url global to client as 'base_url'
+            None,
+            INFERENCE_ENDPOINT_URL,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url specific to request
+            None,
+            None,
+            INFERENCE_ENDPOINT_URL,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url global to client as 'base_url' - explicit model id
+            None,
+            INFERENCE_ENDPOINT_URL,
+            "username/repo_name",
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - full url global to client as 'model'
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - full url global to client as 'base_url'
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - full url specific to request
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url with '/v1' (OpenAI compatibility)
+            # Regression test for https://github.com/huggingface/huggingface_hub/pull/2418
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1",
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url with '/v1/' (OpenAI compatibility)
+            # Regression test for https://github.com/huggingface/huggingface_hub/pull/2418
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/",
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Local TGI with trailing '/v1'
+            # Regression test for https://github.com/huggingface/huggingface_hub/issues/2414
+            f"{LOCAL_TGI_URL}/v1",  # expected from OpenAI client
+            None,
+            None,
+            f"{LOCAL_TGI_URL}/v1/chat/completions",
+        ),
+    ],
+)
+def test_resolve_chat_completion_url(
+    client_model: Optional[str], client_base_url: Optional[str], model: Optional[str], expected_url: str
+):
+    client = InferenceClient(model=client_model, base_url=client_base_url)
+    url = client._resolve_chat_completion_url(model)
+    assert url == expected_url
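To exercise just this new matrix locally, standard pytest filtering should work, e.g. `pytest tests/test_inference_client.py -k test_resolve_chat_completion_url` (the exact invocation may depend on the repo's test setup).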
