Skip to content

Commit 7728349

Browse files
Fix payload model name when model id is a URL (#2911)
* fix default model name when model id is a URL
* better
* Update test

Co-authored-by: Lucain <[email protected]>

---------

Co-authored-by: Lucain <[email protected]>
1 parent 12f81c1 commit 7728349

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,11 @@ def __init__(self):
8484
super().__init__("text-generation")
8585

8686
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
    """Assemble the JSON body for a chat-completion request.

    The ``"model"`` field is resolved in this order:

    1. a truthy ``parameters["model"]`` supplied by the caller;
    2. otherwise ``mapped_model``;
    3. if the value chosen above is ``None`` or starts with ``http://`` /
       ``https://``, the literal placeholder ``"dummy"`` is sent instead.

    Args:
        inputs: Chat messages; stored unchanged under the ``"messages"`` key.
        parameters: Extra request parameters. They are passed through
            ``filter_none`` before being merged into the payload
            (presumably dropping ``None`` values - confirm against the helper).
        mapped_model: Model id, or the endpoint URL when the client targets a
            URL directly.

    Returns:
        The payload dict. ``"model"`` and ``"messages"`` are written after the
        filtered parameters, so they override same-named keys in ``parameters``.
    """
    # Earlier revision: URL endpoints always sent "tgi" and a caller-supplied
    # ``parameters["model"]`` was not consulted when picking the name.
    requested = parameters.get("model")
    candidate = requested if requested else mapped_model

    if candidate is not None and not candidate.startswith(("http://", "https://")):
        model_name = candidate
    else:
        # A URL is not a usable model name; send a placeholder instead.
        # NOTE(review): presumably the server behind the URL ignores this field - confirm.
        model_name = "dummy"

    return {**filter_none(parameters), "model": model_name, "messages": inputs}
8993

9094
def _prepare_url(self, api_key: str, mapped_model: str) -> str:

tests/test_inference_providers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,53 @@ def test_prepare_request_conversational(self):
324324
"messages": [{"role": "user", "content": "dummy text input"}],
325325
}
326326

327+
@pytest.mark.parametrize(
    "mapped_model,parameters,expected_model",
    [
        # Plain repo id and no explicit model - the repo id is forwarded as-is
        ("username/repo_name", {}, "username/repo_name"),
        # HTTP URL endpoint with an explicit model in parameters - the explicit model wins
        ("http://localhost:8000/v1/chat/completions", {"model": "username/repo_name"}, "username/repo_name"),
        # HTTP URL endpoint and no model at all - falls back to "dummy"
        ("http://localhost:8000/v1/chat/completions", {}, "dummy"),
        # HTTPS URL endpoint with an explicit model in parameters - the explicit model wins
        ("https://api.example.com/v1/chat/completions", {"model": "username/repo_name"}, "username/repo_name"),
        # URL endpoint with unrelated parameters only - still falls back to "dummy"
        ("http://localhost:8000/v1/chat/completions", {"temperature": 0.7, "max_tokens": 100}, "dummy"),
    ],
)
def test_prepare_payload_as_dict_conversational(self, mapped_model, parameters, expected_model):
    """Check the ``model`` and ``messages`` fields of the conversational payload."""
    conversation = [{"role": "user", "content": "Hello!"}]
    provider = HFInferenceConversational()

    result = provider._prepare_payload_as_dict(
        inputs=conversation,
        parameters=parameters,
        mapped_model=mapped_model,
    )

    # The resolved model name depends on the (mapped_model, parameters) pair.
    assert result["model"] == expected_model
    # Messages must come back exactly as they were passed in.
    assert result["messages"] == conversation
373+
327374

328375
class TestHyperbolicProvider:
329376
def test_prepare_route(self):

0 commit comments

Comments
 (0)