Skip to content

Commit a8467cd

Browse files
hanouticelinaWauplin
authored andcommitted
fix text generation (#2982)
1 parent c9f9ad2 commit a8467cd

File tree

5 files changed

+34
-4
lines changed

5 files changed

+34
-4
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,8 +2400,8 @@ def text_generation(
24002400
# Data can be a single element (dict) or an iterable of dicts where we select the first element of.
24012401
if isinstance(data, list):
24022402
data = data[0]
2403-
2404-
return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]
2403+
response = provider_helper.get_response(data, request_parameters)
2404+
return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"]
24052405

24062406
def text_to_image(
24072407
self,

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,8 +2456,8 @@ async def text_generation(
24562456
# Data can be a single element (dict) or an iterable of dicts where we select the first element of.
24572457
if isinstance(data, list):
24582458
data = data[0]
2459-
2460-
return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"]
2459+
response = provider_helper.get_response(data, request_parameters)
2460+
return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"]
24612461

24622462
async def text_to_image(
24632463
self,

src/huggingface_hub/inference/_providers/nebius.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ class NebiusTextGenerationTask(BaseTextGenerationTask):
1414
def __init__(self):
1515
super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai")
1616

17+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
18+
output = _as_dict(response)["choices"][0]
19+
return {
20+
"generated_text": output["text"],
21+
"details": {
22+
"finish_reason": output.get("finish_reason"),
23+
"seed": output.get("seed"),
24+
},
25+
}
26+
1727

1828
class NebiusConversationalTask(BaseConversationalTask):
1929
def __init__(self):

src/huggingface_hub/inference/_providers/novita.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
2222
# there is no v1/ route for novita
2323
return "/v3/openai/completions"
2424

25+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
26+
output = _as_dict(response)["choices"][0]
27+
return {
28+
"generated_text": output["text"],
29+
"details": {
30+
"finish_reason": output.get("finish_reason"),
31+
"seed": output.get("seed"),
32+
},
33+
}
34+
2535

2636
class NovitaConversationalTask(BaseConversationalTask):
2737
def __init__(self):

src/huggingface_hub/inference/_providers/together.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ class TogetherTextGenerationTask(BaseTextGenerationTask):
3535
def __init__(self):
3636
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
3737

38+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
39+
output = _as_dict(response)["choices"][0]
40+
return {
41+
"generated_text": output["text"],
42+
"details": {
43+
"finish_reason": output.get("finish_reason"),
44+
"seed": output.get("seed"),
45+
},
46+
}
47+
3848

3949
class TogetherConversationalTask(BaseConversationalTask):
4050
def __init__(self):

0 commit comments

Comments
 (0)