Commit cd85541: Fix Inference Client VCR tests (#2858)
Parent: 7553646

36 files changed (+51,831 / -101,305 lines)


.github/workflows/python-tests.yml (10 additions, 2 deletions)

@@ -26,6 +26,7 @@ jobs:
           [
             "Repository only",
             "Everything else",
+            "Inference only"
 
           ]
         include:
@@ -64,7 +65,7 @@ jobs:
 
           case "${{ matrix.test_name }}" in
 
-            "Repository only" | "Everything else")
+            "Repository only" | "Everything else" | "Inference only")
              sudo apt update
              sudo apt install -y libsndfile1-dev
              ;;
@@ -112,8 +113,15 @@ jobs:
              eval $PYTEST
              ;;
 
+            "Inference only")
+              # Run inference tests concurrently
+              PYTEST="$PYTEST ../tests -k 'test_inference' -n 4"
+              echo $PYTEST
+              eval $PYTEST
+              ;;
+
             "Everything else")
-              PYTEST="$PYTEST ../tests -k 'not TestRepository' -n 4"
+              PYTEST="$PYTEST ../tests -k 'not TestRepository and not test_inference' -n 4"
              echo $PYTEST
              eval $PYTEST
              ;;
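Taken together, the matrix change gives the inference tests their own job while the catch-all job now excludes them. A rough local equivalent of the two selections (a sketch only; the workflow composes `$PYTEST` with options not shown in this hunk, and `-n` requires pytest-xdist):

```python
# Rough local equivalent of the two CI selections above (sketch only).
import pytest

# "Inference only": just the inference tests, on 4 workers (pytest-xdist).
pytest.main(["../tests", "-k", "test_inference", "-n", "4"])

# "Everything else": skip repository tests and the inference tests.
pytest.main(["../tests", "-k", "not TestRepository and not test_inference", "-n", "4"])
```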

src/huggingface_hub/inference/_client.py (6 additions, 6 deletions)

@@ -40,7 +40,7 @@
 
 from requests import HTTPError
 
-from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
+from huggingface_hub import constants
 from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
@@ -3300,9 +3300,9 @@ def list_deployed_models(
 
         # Resolve which frameworks to check
         if frameworks is None:
-            frameworks = MAIN_INFERENCE_API_FRAMEWORKS
+            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
         elif frameworks == "all":
-            frameworks = ALL_INFERENCE_API_FRAMEWORKS
+            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
         elif isinstance(frameworks, str):
             frameworks = [frameworks]
         frameworks = list(set(frameworks))
@@ -3322,7 +3322,7 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:
 
         for framework in frameworks:
             response = get_session().get(
-                f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
+                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
             )
             hf_raise_for_status(response)
             _unpack_response(framework, response.json())
@@ -3384,7 +3384,7 @@ def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
         if model.startswith(("http://", "https://")):
             url = model.rstrip("/") + "/info"
         else:
-            url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+            url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"
 
         response = get_session().get(url, headers=build_hf_headers(token=self.token))
         hf_raise_for_status(response)
@@ -3472,7 +3472,7 @@ def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
             raise ValueError("Model id not provided.")
         if model.startswith("https://"):
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{INFERENCE_ENDPOINT}/status/{model}"
+        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
 
         response = get_session().get(url, headers=build_hf_headers(token=self.token))
         hf_raise_for_status(response)
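The import change is what makes the endpoint patchable in tests: `constants.INFERENCE_ENDPOINT` is looked up on the module at call time, whereas the old `from huggingface_hub.constants import INFERENCE_ENDPOINT` copied the string into this module at import time, out of reach of later patches. A minimal pytest-style sketch (the helper names are illustrative):

```python
# Sketch: module-attribute access is resolved at call time, so tests can
# monkeypatch it; a `from ... import INFERENCE_ENDPOINT` copy would not see
# the patch. Helper names here are illustrative.
from huggingface_hub import constants


def status_url(model: str) -> str:
    # Looked up on the module at call time, exactly like the client code above.
    return f"{constants.INFERENCE_ENDPOINT}/status/{model}"


def test_status_url(monkeypatch) -> None:
    monkeypatch.setattr(constants, "INFERENCE_ENDPOINT", "http://localhost:8080")
    assert status_url("gpt2") == "http://localhost:8080/status/gpt2"
```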

src/huggingface_hub/inference/_generated/_async_client.py (6 additions, 6 deletions)

@@ -25,7 +25,7 @@
 import warnings
 from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload
 
-from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
+from huggingface_hub import constants
 from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
@@ -3365,9 +3365,9 @@ async def list_deployed_models(
 
         # Resolve which frameworks to check
         if frameworks is None:
-            frameworks = MAIN_INFERENCE_API_FRAMEWORKS
+            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
         elif frameworks == "all":
-            frameworks = ALL_INFERENCE_API_FRAMEWORKS
+            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
         elif isinstance(frameworks, str):
             frameworks = [frameworks]
         frameworks = list(set(frameworks))
@@ -3387,7 +3387,7 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:
 
         for framework in frameworks:
             response = get_session().get(
-                f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
+                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
             )
             hf_raise_for_status(response)
             _unpack_response(framework, response.json())
@@ -3491,7 +3491,7 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
         if model.startswith(("http://", "https://")):
             url = model.rstrip("/") + "/info"
         else:
-            url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+            url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"
 
         async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
             response = await client.get(url, proxy=self.proxies)
@@ -3583,7 +3583,7 @@ async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
             raise ValueError("Model id not provided.")
         if model.startswith("https://"):
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{INFERENCE_ENDPOINT}/status/{model}"
+        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
 
         async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
             response = await client.get(url, proxy=self.proxies)

src/huggingface_hub/inference/_providers/__init__.py (2 additions, 2 deletions)

@@ -46,9 +46,9 @@
     "image-classification": HFInferenceBinaryInputTask("image-classification"),
     "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
     "document-question-answering": HFInferenceTask("document-question-answering"),
-    "image-to-text": HFInferenceTask("image-to-text"),
+    "image-to-text": HFInferenceBinaryInputTask("image-to-text"),
     "object-detection": HFInferenceBinaryInputTask("object-detection"),
-    "audio-to-audio": HFInferenceTask("audio-to-audio"),
+    "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"),
     "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
     "zero-shot-classification": HFInferenceTask("zero-shot-classification"),
     "image-to-image": HFInferenceBinaryInputTask("image-to-image"),

src/huggingface_hub/inference/_providers/_common.py (9 additions, 8 deletions)

@@ -18,6 +18,7 @@
     # Example:
     # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
     "fal-ai": {},
+    "fireworks-ai": {},
     "hf-inference": {},
     "replicate": {},
     "sambanova": {},
@@ -65,12 +66,12 @@ def prepare_request(
         url = self._prepare_url(api_key, mapped_model)
 
         # prepare payload (to customize in subclasses)
-        payload = self._prepare_payload(inputs, parameters, mapped_model=mapped_model)
+        payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model)
         if payload is not None:
             payload = recursive_merge(payload, extra_payload or {})
 
         # body data (to customize in subclasses)
-        data = self._prepare_body(inputs, parameters, mapped_model, extra_payload)
+        data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)
 
         # check if both payload and data are set and return
         if payload is not None and data is not None:
@@ -159,21 +160,21 @@ def _prepare_route(self, mapped_model: str) -> str:
         """
         return ""
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         """Return the payload to use for the request, as a dict.
 
         Override this method in subclasses for customized payloads.
-        Only one of `_prepare_payload` and `_prepare_body` should return a value.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
        """
         return None
 
-    def _prepare_body(
+    def _prepare_payload_as_bytes(
         self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
     ) -> Optional[bytes]:
         """Return the body to use for the request, as bytes.
 
         Override this method in subclasses for customized body data.
-        Only one of `_prepare_payload` and `_prepare_body` should return a value.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
         """
         return None
 
@@ -183,9 +184,9 @@ def _fetch_inference_provider_mapping(model: str) -> Dict:
     """
     Fetch provider mappings for a model from the Hub.
     """
-    from huggingface_hub.hf_api import model_info
+    from huggingface_hub.hf_api import HfApi
 
-    info = model_info(model, expand=["inferenceProviderMapping"])
+    info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
     provider_mapping = info.inference_provider_mapping
     if provider_mapping is None:
         raise ValueError(f"No provider mapping found for model {model}")
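The renames make the two extension points self-describing: a provider task returns either a JSON-serializable dict (`_prepare_payload_as_dict`) or a raw byte body (`_prepare_payload_as_bytes`), and `prepare_request` guards against both being set at once. A toy pair of subclasses, assuming the `TaskProviderHelper` base from `_common.py` above (constructor details omitted, so these are shown for shape only):

```python
# Toy illustration of the renamed extension points (assumes the
# TaskProviderHelper base class from _common.py above).
from typing import Any, Dict, Optional

from huggingface_hub.inference._providers._common import TaskProviderHelper


class JsonEchoTask(TaskProviderHelper):
    # JSON-style task: only the dict hook returns a value.
    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"inputs": inputs, **parameters}


class BinaryFileTask(TaskProviderHelper):
    # Binary task: only the bytes hook returns a value.
    def _prepare_payload_as_bytes(
        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
    ) -> Optional[bytes]:
        # `inputs` is assumed to be a local file path in this sketch.
        with open(inputs, "rb") as f:
            return f.read()
```

The companion change from the module-level `model_info` helper to `HfApi().model_info` presumably serves the same test-friendliness goal, routing the call through a patchable instance method.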

src/huggingface_hub/inference/_providers/fal_ai.py (4 additions, 4 deletions)

@@ -25,7 +25,7 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):
     def __init__(self):
         super().__init__("automatic-speech-recognition")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
             # If input is a URL, pass it directly
             audio_url = inputs
@@ -52,7 +52,7 @@ class FalAITextToImageTask(FalAITask):
     def __init__(self):
         super().__init__("text-to-image")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         parameters = filter_none(parameters)
         if "width" in parameters and "height" in parameters:
             parameters["image_size"] = {
@@ -70,7 +70,7 @@ class FalAITextToSpeechTask(FalAITask):
     def __init__(self):
         super().__init__("text-to-speech")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         return {"lyrics": inputs, **filter_none(parameters)}
 
     def get_response(self, response: Union[bytes, Dict]) -> Any:
@@ -82,7 +82,7 @@ class FalAITextToVideoTask(FalAITask):
     def __init__(self):
         super().__init__("text-to-video")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         return {"prompt": inputs, **filter_none(parameters)}
 
     def get_response(self, response: Union[bytes, Dict]) -> Any:
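For reference, the per-task payload shapes the Fal helpers produce: the text-to-speech and text-to-video dicts appear verbatim above, while the ASR key is only an assumption inferred from the `audio_url` variable in the truncated hunk:

```python
# Illustrative Fal payloads (values made up).
tts = {"lyrics": "Hello world"}                     # verbatim key from the diff
ttv = {"prompt": "a timelapse of a city at night"}  # verbatim key from the diff
asr = {"audio_url": "https://example.com/sample.flac"}  # key assumed from the variable name
```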

src/huggingface_hub/inference/_providers/fireworks_ai.py (1 addition, 1 deletion)

@@ -10,5 +10,5 @@ def __init__(self):
     def _prepare_route(self, mapped_model: str) -> str:
         return "/v1/chat/completions"
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         return {"messages": inputs, **filter_none(parameters), "model": mapped_model}

src/huggingface_hub/inference/_providers/hf_inference.py (6 additions, 3 deletions)

@@ -46,7 +46,7 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str:
             else f"{self.base_url}/models/{mapped_model}"
         )
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         if isinstance(inputs, bytes):
             raise ValueError(f"Unexpected binary input for task {self.task}.")
         if isinstance(inputs, Path):
@@ -55,7 +55,10 @@ def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) ->
 
 
 class HFInferenceBinaryInputTask(HFInferenceTask):
-    def _prepare_body(
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        return None
+
+    def _prepare_payload_as_bytes(
         self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
     ) -> Optional[bytes]:
         parameters = filter_none({k: v for k, v in parameters.items() if v is not None})
@@ -80,7 +83,7 @@ class HFInferenceConversational(HFInferenceTask):
     def __init__(self):
         super().__init__("text-generation")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         payload_model = "tgi" if mapped_model.startswith(("http://", "https://")) else mapped_model
         return {**filter_none(parameters), "model": payload_model, "messages": inputs}
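The notable addition here is the explicit `_prepare_payload_as_dict` returning `None`: it keeps `HFInferenceBinaryInputTask` off the JSON path, since the implementation inherited from `HFInferenceTask` raises on raw `bytes` input (see the hunk above). A quick check, assuming this version of `huggingface_hub`:

```python
# Quick check of the override, assuming this version of huggingface_hub.
from huggingface_hub.inference._providers.hf_inference import HFInferenceBinaryInputTask

task = HFInferenceBinaryInputTask("image-to-image")
# The dict hook is now a no-op for binary tasks, so raw bytes flow through
# _prepare_payload_as_bytes instead of hitting the parent's ValueError.
assert task._prepare_payload_as_dict(b"raw image bytes", {}, mapped_model="some/model") is None
```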

src/huggingface_hub/inference/_providers/new_provider.md (8 additions, 7 deletions)

@@ -6,7 +6,7 @@ Before adding a new provider to the `huggingface_hub` library, make sure it has
 
 Create a new file under `src/huggingface_hub/inference/_providers/{provider_name}.py` and copy-paste the following snippet.
 
-Implement the methods that require custom handling. Check out the base implementation to check default behavior. If you don't need to override a method, just remove it. At least one of `_prepare_payload` or `_prepare_body` must be overwritten.
+Implement the methods that require custom handling. Check out the base implementation to check default behavior. If you don't need to override a method, just remove it. At least one of `_prepare_payload_as_dict` or `_prepare_payload_as_bytes` must be overwritten.
 
 If the provider supports multiple tasks that require different implementations, create dedicated subclasses for each task, following the pattern shown in `fal_ai.py`.
 
@@ -42,23 +42,24 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper):
         """
         return super()._prepare_route(mapped_model)
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         """Return the payload to use for the request, as a dict.
 
         Override this method in subclasses for customized payloads.
-        Only one of `_prepare_payload` and `_prepare_body` should return a value.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
         """
-        return super()._prepare_payload(inputs, parameters, mapped_model)
+        return super()._prepare_payload_as_dict(inputs, parameters, mapped_model)
 
-    def _prepare_body(
+    def _prepare_payload_as_bytes(
         self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
     ) -> Optional[bytes]:
         """Return the body to use for the request, as bytes.
 
         Override this method in subclasses for customized body data.
-        Only one of `_prepare_payload` and `_prepare_body` should return a value.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
         """
-        return super()._prepare_body(inputs, parameters, mapped_model, extra_payload)
+        return super()._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)
+
 ```
 
 ### 2. Register the provider helper in `__init__.py`

src/huggingface_hub/inference/_providers/replicate.py (3 additions, 3 deletions)

@@ -19,7 +19,7 @@ def _prepare_route(self, mapped_model: str) -> str:
             return "/v1/predictions"
         return f"/v1/models/{mapped_model}/predictions"
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}}
         if ":" in mapped_model:
             version = mapped_model.split(":", 1)[1]
@@ -43,7 +43,7 @@ class ReplicateTextToSpeechTask(ReplicateTask):
     def __init__(self):
         super().__init__("text-to-speech")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
-        payload: Dict = super()._prepare_payload(inputs, parameters, mapped_model)  # type: ignore[assignment]
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, mapped_model)  # type: ignore[assignment]
         payload["input"]["text"] = payload["input"].pop("prompt")  # rename "prompt" to "text" for TTS
         return payload
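The TTS subclass reuses the generic Replicate payload, which nests everything under `"input"`, and then renames the key. The rename line below is copied verbatim from the diff; the values are illustrative:

```python
# Generic ReplicateTask payload shape (values illustrative):
payload = {"input": {"prompt": "Hello world"}}

# ReplicateTextToSpeechTask then applies the rename from the diff:
payload["input"]["text"] = payload["input"].pop("prompt")  # rename "prompt" to "text" for TTS

assert payload == {"input": {"text": "Hello world"}}
```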
