Skip to content

Commit 9e05812

Browse files
committed
Fix chat completion url for OpenAI compatibility (#2418)
1 parent 09fc1d4 commit 9e05812

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,13 +819,16 @@ def chat_completion(
819819
# First, resolve the model chat completions URL
820820
if model == self.base_url:
821821
# base_url passed => add server route
822-
model_url = model + "/v1/chat/completions"
822+
model_url = model.rstrip("/")
823+
if not model_url.endswith("/v1"):
824+
model_url += "/v1"
825+
model_url += "/chat/completions"
823826
elif is_url:
824827
# model is a URL => use it directly
825828
model_url = model
826829
else:
827830
# model is a model ID => resolve it + add server route
828-
model_url = self._resolve_url(model) + "/v1/chat/completions"
831+
model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
829832

830833
# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
831834
# If it's an ID on the Hub => use it. Otherwise, we use a random string.

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,13 +825,16 @@ async def chat_completion(
825825
# First, resolve the model chat completions URL
826826
if model == self.base_url:
827827
# base_url passed => add server route
828-
model_url = model + "/v1/chat/completions"
828+
model_url = model.rstrip("/")
829+
if not model_url.endswith("/v1"):
830+
model_url += "/v1"
831+
model_url += "/chat/completions"
829832
elif is_url:
830833
# model is a URL => use it directly
831834
model_url = model
832835
else:
833836
# model is a model ID => resolve it + add server route
834-
model_url = self._resolve_url(model) + "/v1/chat/completions"
837+
model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
835838

836839
# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
837840
# If it's an ID on the Hub => use it. Otherwise, we use a random string.

tests/test_inference_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,27 @@ def test_model_and_base_url_mutually_exclusive(self):
953953
InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")
954954

955955

956+
@pytest.mark.parametrize(
957+
"base_url",
958+
[
959+
"http://0.0.0.0:8080/v1", # expected from OpenAI client
960+
"http://0.0.0.0:8080", # but not mandatory
961+
"http://0.0.0.0:8080/v1/", # ok with trailing '/' as well
962+
"http://0.0.0.0:8080/", # ok with trailing '/' as well
963+
],
964+
)
965+
def test_chat_completion_base_url_works_with_v1(base_url: str):
966+
"""Test that `/v1/chat/completions` is correctly appended to the base URL.
967+
968+
This is a regression test for https://github.com/huggingface/huggingface_hub/issues/2414
969+
"""
970+
with patch("huggingface_hub.inference._client.InferenceClient.post") as post_mock:
971+
client = InferenceClient(base_url=base_url)
972+
post_mock.return_value = "{}"
973+
client.chat_completion(messages=CHAT_COMPLETION_MESSAGES, stream=False)
974+
assert post_mock.call_args_list[0].kwargs["model"] == "http://0.0.0.0:8080/v1/chat/completions"
975+
976+
956977
def test_stream_text_generation_response():
957978
data = [
958979
b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}',

0 commit comments

Comments (0)