1212from unittest import mock # python 3.3 and above
1313
1414
15+ def mock_client_post (client , post_mock ):
16+ # huggingface-hub==0.28.0 deprecates the `post` method
17+ # so patch `_inner_post` instead
18+ client .post = post_mock
19+ client ._inner_post = post_mock
20+
21+
1522@pytest .mark .parametrize (
1623 "send_default_pii, include_prompts, details_arg" ,
1724 itertools .product ([True , False ], repeat = 3 ),
@@ -28,7 +35,7 @@ def test_nonstreaming_chat_completion(
2835
2936 client = InferenceClient ("some-model" )
3037 if details_arg :
31- client . post = mock .Mock (
38+ post_mock = mock .Mock (
3239 return_value = b"""[{
3340 "generated_text": "the model response",
3441 "details": {
@@ -40,9 +47,11 @@ def test_nonstreaming_chat_completion(
4047 }]"""
4148 )
4249 else :
43- client . post = mock .Mock (
50+ post_mock = mock .Mock (
4451 return_value = b'[{"generated_text": "the model response"}]'
4552 )
53+ mock_client_post (client , post_mock )
54+
4655 with start_transaction (name = "huggingface_hub tx" ):
4756 response = client .text_generation (
4857 prompt = "hello" ,
@@ -84,7 +93,8 @@ def test_streaming_chat_completion(
8493 events = capture_events ()
8594
8695 client = InferenceClient ("some-model" )
87- client .post = mock .Mock (
96+
97+ post_mock = mock .Mock (
8898 return_value = [
8999 b"""data:{
90100 "token":{"id":1, "special": false, "text": "the model "}
@@ -95,6 +105,8 @@ def test_streaming_chat_completion(
95105 }""" ,
96106 ]
97107 )
108+ mock_client_post (client , post_mock )
109+
98110 with start_transaction (name = "huggingface_hub tx" ):
99111 response = list (
100112 client .text_generation (
@@ -131,7 +143,9 @@ def test_bad_chat_completion(sentry_init, capture_events):
131143 events = capture_events ()
132144
133145 client = InferenceClient ("some-model" )
134- client .post = mock .Mock (side_effect = OverloadedError ("The server is overloaded" ))
146+ post_mock = mock .Mock (side_effect = OverloadedError ("The server is overloaded" ))
147+ mock_client_post (client , post_mock )
148+
135149 with pytest .raises (OverloadedError ):
136150 client .text_generation (prompt = "hello" )
137151
@@ -147,13 +161,15 @@ def test_span_origin(sentry_init, capture_events):
147161 events = capture_events ()
148162
149163 client = InferenceClient ("some-model" )
150- client . post = mock .Mock (
164+ post_mock = mock .Mock (
151165 return_value = [
152166 b"""data:{
153167 "token":{"id":1, "special": false, "text": "the model "}
154168 }""" ,
155169 ]
156170 )
171+ mock_client_post (client , post_mock )
172+
157173 with start_transaction (name = "huggingface_hub tx" ):
158174 list (
159175 client .text_generation (
0 commit comments