4747 hf_hub_download ,
4848)
4949from huggingface_hub .constants import ALL_INFERENCE_API_FRAMEWORKS , MAIN_INFERENCE_API_FRAMEWORKS
50- from huggingface_hub .errors import HfHubHTTPError
50+ from huggingface_hub .errors import HfHubHTTPError , ValidationError
5151from huggingface_hub .inference ._client import _open_as_binary
5252from huggingface_hub .inference ._common import (
5353 _stream_chat_completion_response ,
@@ -919,7 +919,14 @@ def test_model_and_base_url_mutually_exclusive(self):
919919 InferenceClient (model = "meta-llama/Meta-Llama-3-8B-Instruct" , base_url = "http://127.0.0.1:8000" )
920920
921921
922- @pytest .mark .parametrize ("stop_signal" , [b"data: [DONE]" , b"data: [DONE]\n " , b"data: [DONE] " ])
922+ @pytest .mark .parametrize (
923+ "stop_signal" ,
924+ [
925+ b"data: [DONE]" ,
926+ b"data: [DONE]\n " ,
927+ b"data: [DONE] " ,
928+ ],
929+ )
923930def test_stream_text_generation_response (stop_signal : bytes ):
924931 data = [
925932 b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}' ,
@@ -935,7 +942,14 @@ def test_stream_text_generation_response(stop_signal: bytes):
935942 assert output == [" trying" , " to" ]
936943
937944
938- @pytest .mark .parametrize ("stop_signal" , [b"data: [DONE]" , b"data: [DONE]\n " , b"data: [DONE] " ])
945+ @pytest .mark .parametrize (
946+ "stop_signal" ,
947+ [
948+ b"data: [DONE]" ,
949+ b"data: [DONE]\n " ,
950+ b"data: [DONE] " ,
951+ ],
952+ )
939953def test_stream_chat_completion_response (stop_signal : bytes ):
940954 data = [
941955 b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}' ,
@@ -952,6 +966,20 @@ def test_stream_chat_completion_response(stop_signal: bytes):
952966 assert output [1 ].choices [0 ].delta .content == " Rust"
953967
954968
def test_chat_completion_error_in_stream():
    """
    Regression test for https://github.com/huggingface/huggingface_hub/issues/2514.

    When an error payload is encountered mid-stream, `_stream_chat_completion_response`
    should raise the matching TextGenerationError subclass (here a ValidationError,
    since the payload's "error_type" is "validation") rather than choking on the chunk.
    """
    data = [
        # A normal, well-formed chunk: parsing must succeed before the error arrives.
        b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
        # An error chunk as emitted by TGI when input validation fails.
        b'data: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 6 `inputs` tokens and 4091 `max_new_tokens`","error_type":"validation"}',
    ]
    with pytest.raises(ValidationError):
        # Consume the stream; the error chunk must surface as an exception.
        # `_` instead of `token`: the yielded value is intentionally unused.
        for _ in _stream_chat_completion_response(data):
            pass
982+
# Endpoint fixtures used by the URL-resolution tests in this module.
INFERENCE_API_URL = "https://api-inference.huggingface.co/models"  # serverless Inference API
INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud"  # example
LOCAL_TGI_URL = "http://0.0.0.0:8080"  # locally running TGI server
0 commit comments