Skip to content

Commit 4cc7fee

Browse files
hanouticelina, Wauplin, and github-actions[bot]
authored
[Inference] Remove default params values for text generation (#3192)
* remove default params values for text generation

* fix types

* Apply suggestions from code review

Co-authored-by: Lucain <[email protected]>

* Apply style fixes

---------

Co-authored-by: Lucain <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 99baddf commit 4cc7fee

File tree

5 files changed

+72
-72
lines changed

5 files changed

+72
-72
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,23 +1858,23 @@ def text_classification(
18581858
return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value]
18591859

18601860
@overload
1861-
def text_generation( # type: ignore
1861+
def text_generation(
18621862
self,
18631863
prompt: str,
18641864
*,
1865-
details: Literal[False] = ...,
1866-
stream: Literal[False] = ...,
1865+
details: Literal[True],
1866+
stream: Literal[True],
18671867
model: Optional[str] = None,
18681868
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
18691869
adapter_id: Optional[str] = None,
18701870
best_of: Optional[int] = None,
18711871
decoder_input_details: Optional[bool] = None,
1872-
do_sample: Optional[bool] = False, # Manual default value
1872+
do_sample: Optional[bool] = None,
18731873
frequency_penalty: Optional[float] = None,
18741874
grammar: Optional[TextGenerationInputGrammarType] = None,
18751875
max_new_tokens: Optional[int] = None,
18761876
repetition_penalty: Optional[float] = None,
1877-
return_full_text: Optional[bool] = False, # Manual default value
1877+
return_full_text: Optional[bool] = None,
18781878
seed: Optional[int] = None,
18791879
stop: Optional[List[str]] = None,
18801880
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
@@ -1885,26 +1885,26 @@ def text_generation( # type: ignore
18851885
truncate: Optional[int] = None,
18861886
typical_p: Optional[float] = None,
18871887
watermark: Optional[bool] = None,
1888-
) -> str: ...
1888+
) -> Iterable[TextGenerationStreamOutput]: ...
18891889

18901890
@overload
1891-
def text_generation( # type: ignore
1891+
def text_generation(
18921892
self,
18931893
prompt: str,
18941894
*,
1895-
details: Literal[True] = ...,
1896-
stream: Literal[False] = ...,
1895+
details: Literal[True],
1896+
stream: Optional[Literal[False]] = None,
18971897
model: Optional[str] = None,
18981898
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
18991899
adapter_id: Optional[str] = None,
19001900
best_of: Optional[int] = None,
19011901
decoder_input_details: Optional[bool] = None,
1902-
do_sample: Optional[bool] = False, # Manual default value
1902+
do_sample: Optional[bool] = None,
19031903
frequency_penalty: Optional[float] = None,
19041904
grammar: Optional[TextGenerationInputGrammarType] = None,
19051905
max_new_tokens: Optional[int] = None,
19061906
repetition_penalty: Optional[float] = None,
1907-
return_full_text: Optional[bool] = False, # Manual default value
1907+
return_full_text: Optional[bool] = None,
19081908
seed: Optional[int] = None,
19091909
stop: Optional[List[str]] = None,
19101910
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
@@ -1918,23 +1918,23 @@ def text_generation( # type: ignore
19181918
) -> TextGenerationOutput: ...
19191919

19201920
@overload
1921-
def text_generation( # type: ignore
1921+
def text_generation(
19221922
self,
19231923
prompt: str,
19241924
*,
1925-
details: Literal[False] = ...,
1926-
stream: Literal[True] = ...,
1925+
details: Optional[Literal[False]] = None,
1926+
stream: Literal[True],
19271927
model: Optional[str] = None,
19281928
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
19291929
adapter_id: Optional[str] = None,
19301930
best_of: Optional[int] = None,
19311931
decoder_input_details: Optional[bool] = None,
1932-
do_sample: Optional[bool] = False, # Manual default value
1932+
do_sample: Optional[bool] = None,
19331933
frequency_penalty: Optional[float] = None,
19341934
grammar: Optional[TextGenerationInputGrammarType] = None,
19351935
max_new_tokens: Optional[int] = None,
19361936
repetition_penalty: Optional[float] = None,
1937-
return_full_text: Optional[bool] = False, # Manual default value
1937+
return_full_text: Optional[bool] = None, # Manual default value
19381938
seed: Optional[int] = None,
19391939
stop: Optional[List[str]] = None,
19401940
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
@@ -1948,23 +1948,23 @@ def text_generation( # type: ignore
19481948
) -> Iterable[str]: ...
19491949

19501950
@overload
1951-
def text_generation( # type: ignore
1951+
def text_generation(
19521952
self,
19531953
prompt: str,
19541954
*,
1955-
details: Literal[True] = ...,
1956-
stream: Literal[True] = ...,
1955+
details: Optional[Literal[False]] = None,
1956+
stream: Optional[Literal[False]] = None,
19571957
model: Optional[str] = None,
19581958
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
19591959
adapter_id: Optional[str] = None,
19601960
best_of: Optional[int] = None,
19611961
decoder_input_details: Optional[bool] = None,
1962-
do_sample: Optional[bool] = False, # Manual default value
1962+
do_sample: Optional[bool] = None,
19631963
frequency_penalty: Optional[float] = None,
19641964
grammar: Optional[TextGenerationInputGrammarType] = None,
19651965
max_new_tokens: Optional[int] = None,
19661966
repetition_penalty: Optional[float] = None,
1967-
return_full_text: Optional[bool] = False, # Manual default value
1967+
return_full_text: Optional[bool] = None,
19681968
seed: Optional[int] = None,
19691969
stop: Optional[List[str]] = None,
19701970
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
@@ -1975,26 +1975,26 @@ def text_generation( # type: ignore
19751975
truncate: Optional[int] = None,
19761976
typical_p: Optional[float] = None,
19771977
watermark: Optional[bool] = None,
1978-
) -> Iterable[TextGenerationStreamOutput]: ...
1978+
) -> str: ...
19791979

19801980
@overload
19811981
def text_generation(
19821982
self,
19831983
prompt: str,
19841984
*,
1985-
details: Literal[True] = ...,
1986-
stream: bool = ...,
1985+
details: Optional[bool] = None,
1986+
stream: Optional[bool] = None,
19871987
model: Optional[str] = None,
19881988
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
19891989
adapter_id: Optional[str] = None,
19901990
best_of: Optional[int] = None,
19911991
decoder_input_details: Optional[bool] = None,
1992-
do_sample: Optional[bool] = False, # Manual default value
1992+
do_sample: Optional[bool] = None,
19931993
frequency_penalty: Optional[float] = None,
19941994
grammar: Optional[TextGenerationInputGrammarType] = None,
19951995
max_new_tokens: Optional[int] = None,
19961996
repetition_penalty: Optional[float] = None,
1997-
return_full_text: Optional[bool] = False, # Manual default value
1997+
return_full_text: Optional[bool] = None,
19981998
seed: Optional[int] = None,
19991999
stop: Optional[List[str]] = None,
20002000
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
@@ -2005,25 +2005,25 @@ def text_generation(
20052005
truncate: Optional[int] = None,
20062006
typical_p: Optional[float] = None,
20072007
watermark: Optional[bool] = None,
2008-
) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ...
2008+
) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]: ...
20092009

20102010
def text_generation(
20112011
self,
20122012
prompt: str,
20132013
*,
2014-
details: bool = False,
2015-
stream: bool = False,
2014+
details: Optional[bool] = None,
2015+
stream: Optional[bool] = None,
20162016
model: Optional[str] = None,
20172017
# Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
20182018
adapter_id: Optional[str] = None,
20192019
best_of: Optional[int] = None,
20202020
decoder_input_details: Optional[bool] = None,
2021-
do_sample: Optional[bool] = False, # Manual default value
2021+
do_sample: Optional[bool] = None,
20222022
frequency_penalty: Optional[float] = None,
20232023
grammar: Optional[TextGenerationInputGrammarType] = None,
20242024
max_new_tokens: Optional[int] = None,
20252025
repetition_penalty: Optional[float] = None,
2026-
return_full_text: Optional[bool] = False, # Manual default value
2026+
return_full_text: Optional[bool] = None,
20272027
seed: Optional[int] = None,
20282028
stop: Optional[List[str]] = None,
20292029
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
@@ -2101,7 +2101,7 @@ def text_generation(
21012101
typical_p (`float`, *optional`):
21022102
Typical Decoding mass
21032103
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
2104-
watermark (`bool`, *optional`):
2104+
watermark (`bool`, *optional*):
21052105
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
21062106
21072107
Returns:
@@ -2251,7 +2251,7 @@ def text_generation(
22512251
"repetition_penalty": repetition_penalty,
22522252
"return_full_text": return_full_text,
22532253
"seed": seed,
2254-
"stop": stop if stop is not None else [],
2254+
"stop": stop,
22552255
"temperature": temperature,
22562256
"top_k": top_k,
22572257
"top_n_tokens": top_n_tokens,
@@ -2305,7 +2305,7 @@ def text_generation(
23052305

23062306
# Handle errors separately for more precise error messages
23072307
try:
2308-
bytes_output = self._inner_post(request_parameters, stream=stream)
2308+
bytes_output = self._inner_post(request_parameters, stream=stream or False)
23092309
except HTTPError as e:
23102310
match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e))
23112311
if isinstance(e, BadRequestError) and match:

0 commit comments

Comments
 (0)