Skip to content

Commit 83ba43c

Browse files
[inference] Necessary breaking change: nest task-specific route inside of model route (#3044)
* nest task-specific route inside of model route * add unit tests --------- Co-authored-by: Celina Hanouti <[email protected]>
1 parent 557576d commit 83ba43c

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str:
4242
return mapped_model
4343
return (
4444
# Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
45-
f"{self.base_url}/pipeline/{self.task}/{mapped_model}"
45+
f"{self.base_url}/models/{mapped_model}/pipeline/{self.task}"
4646
if self.task in ("feature-extraction", "sentence-similarity")
4747
# Otherwise, we use the default endpoint
4848
else f"{self.base_url}/models/{mapped_model}"

tests/test_inference_providers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,20 @@ def test_prepare_url(self):
435435

436436
assert helper._prepare_url("hf_test_token", "https://any-url.com") == "https://any-url.com"
437437

438+
def test_prepare_url_feature_extraction(self):
439+
helper = HFInferenceTask("feature-extraction")
440+
assert (
441+
helper._prepare_url("hf_test_token", "username/repo_name")
442+
== "https://router.huggingface.co/hf-inference/models/username/repo_name/pipeline/feature-extraction"
443+
)
444+
445+
def test_prepare_url_sentence_similarity(self):
446+
helper = HFInferenceTask("sentence-similarity")
447+
assert (
448+
helper._prepare_url("hf_test_token", "username/repo_name")
449+
== "https://router.huggingface.co/hf-inference/models/username/repo_name/pipeline/sentence-similarity"
450+
)
451+
438452
def test_prepare_payload_as_dict(self):
439453
helper = HFInferenceTask("text-classification")
440454
mapping_info = InferenceProviderMapping(

0 commit comments

Comments
 (0)