|
16 | 16 | import time |
17 | 17 | import unittest |
18 | 18 | from pathlib import Path |
| 19 | +from typing import Optional |
19 | 20 | from unittest.mock import MagicMock, patch |
20 | 21 |
|
21 | 22 | import numpy as np |
@@ -918,27 +919,6 @@ def test_model_and_base_url_mutually_exclusive(self): |
918 | 919 | InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000") |
919 | 920 |
|
920 | 921 |
|
921 | | -@pytest.mark.parametrize( |
922 | | - "base_url", |
923 | | - [ |
924 | | - "http://0.0.0.0:8080/v1", # expected from OpenAI client |
925 | | - "http://0.0.0.0:8080", # but not mandatory |
926 | | - "http://0.0.0.0:8080/v1/", # ok with trailing '/' as well |
927 | | - "http://0.0.0.0:8080/", # ok with trailing '/' as well |
928 | | - ], |
929 | | -) |
930 | | -def test_chat_completion_base_url_works_with_v1(base_url: str): |
931 | | - """Test that `/v1/chat/completions` is correctly appended to the base URL. |
932 | | -
|
933 | | - This is a regression test for https://github.com/huggingface/huggingface_hub/issues/2414 |
934 | | - """ |
935 | | - with patch("huggingface_hub.inference._client.InferenceClient.post") as post_mock: |
936 | | - client = InferenceClient(base_url=base_url) |
937 | | - post_mock.return_value = "{}" |
938 | | - client.chat_completion(messages=CHAT_COMPLETION_MESSAGES, stream=False) |
939 | | - assert post_mock.call_args_list[0].kwargs["model"] == "http://0.0.0.0:8080/v1/chat/completions" |
940 | | - |
941 | | - |
942 | 922 | @pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "]) |
943 | 923 | def test_stream_text_generation_response(stop_signal: bytes): |
944 | 924 | data = [ |
@@ -970,3 +950,108 @@ def test_stream_chat_completion_response(stop_signal: bytes): |
970 | 950 | assert len(output) == 2 |
971 | 951 | assert output[0].choices[0].delta.content == "Both" |
972 | 952 | assert output[1].choices[0].delta.content == " Rust" |
| 953 | + |
| 954 | + |
# Base URL of the serverless Inference API; the model id gets appended to it.
INFERENCE_API_URL = "https://api-inference.huggingface.co/models"
# Example URL of a dedicated Inference Endpoint (not a live endpoint).
INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud"  # example
# Example URL of a locally-deployed TGI server (default TGI port).
LOCAL_TGI_URL = "http://0.0.0.0:8080"
| 958 | + |
| 959 | + |
@pytest.mark.parametrize(
    ("client_model", "client_base_url", "model", "expected_url"),
    [
        # Inference API: model set once on the client.
        pytest.param(
            "username/repo_name",
            None,
            None,
            f"{INFERENCE_API_URL}/username/repo_name/v1/chat/completions",
            id="api-model-on-client",
        ),
        # Inference API: model passed per request.
        pytest.param(
            None,
            None,
            "username/repo_name",
            f"{INFERENCE_API_URL}/username/repo_name/v1/chat/completions",
            id="api-model-on-request",
        ),
        # Inference Endpoint: URL set on the client through 'model'.
        pytest.param(
            INFERENCE_ENDPOINT_URL,
            None,
            None,
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-url-as-client-model",
        ),
        # Inference Endpoint: URL set on the client through 'base_url'.
        pytest.param(
            None,
            INFERENCE_ENDPOINT_URL,
            None,
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-url-as-client-base-url",
        ),
        # Inference Endpoint: URL passed per request.
        pytest.param(
            None,
            None,
            INFERENCE_ENDPOINT_URL,
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-url-on-request",
        ),
        # Inference Endpoint: client 'base_url' wins over an explicit model id.
        pytest.param(
            None,
            INFERENCE_ENDPOINT_URL,
            "username/repo_name",
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-base-url-with-model-id",
        ),
        # Inference Endpoint: already-complete URL set on the client as 'model'.
        pytest.param(
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            None,
            None,
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-full-url-as-client-model",
        ),
        # Inference Endpoint: already-complete URL set on the client as 'base_url'.
        pytest.param(
            None,
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            None,
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-full-url-as-client-base-url",
        ),
        # Inference Endpoint: already-complete URL passed per request.
        pytest.param(
            None,
            None,
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-full-url-on-request",
        ),
        # OpenAI-compatible URL ending in '/v1'.
        # Regression test for https://github.com/huggingface/huggingface_hub/pull/2418
        pytest.param(
            None,
            None,
            f"{INFERENCE_ENDPOINT_URL}/v1",
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-url-with-v1-suffix",
        ),
        # OpenAI-compatible URL ending in '/v1/'.
        # Regression test for https://github.com/huggingface/huggingface_hub/pull/2418
        pytest.param(
            None,
            None,
            f"{INFERENCE_ENDPOINT_URL}/v1/",
            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
            id="endpoint-url-with-v1-trailing-slash",
        ),
        # Local TGI server with an OpenAI-client-style '/v1' suffix.
        # Regression test for https://github.com/huggingface/huggingface_hub/issues/2414
        pytest.param(
            f"{LOCAL_TGI_URL}/v1",  # expected from OpenAI client
            None,
            None,
            f"{LOCAL_TGI_URL}/v1/chat/completions",
            id="local-tgi-with-v1-suffix",
        ),
    ],
)
def test_resolve_chat_completion_url(
    client_model: Optional[str], client_base_url: Optional[str], model: Optional[str], expected_url: str
):
    """Check that the chat-completion request URL is resolved correctly.

    Whether the model/URL comes from the client constructor ('model' or
    'base_url') or from the request itself, the resolved URL must always end
    in '/v1/chat/completions' exactly once.
    """
    client = InferenceClient(model=client_model, base_url=client_base_url)
    resolved = client._resolve_chat_completion_url(model)
    assert resolved == expected_url
0 commit comments