Skip to content

Commit 08c96ba

Browse files
committed
fix huggingface hub tests
1 parent b102cec commit 08c96ba

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

tests/integrations/huggingface_hub/test_huggingface_hub.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
from unittest import mock # python 3.3 and above
1313

1414

15+
def mock_client_post(client, post_mock):
16+
# huggingface-hub==0.28.0 deprecates the `post` method
17+
# so patch `_inner_post` instead
18+
client.post = post_mock
19+
client._inner_post = post_mock
20+
21+
1522
@pytest.mark.parametrize(
1623
"send_default_pii, include_prompts, details_arg",
1724
itertools.product([True, False], repeat=3),
@@ -28,7 +35,7 @@ def test_nonstreaming_chat_completion(
2835

2936
client = InferenceClient("some-model")
3037
if details_arg:
31-
client.post = mock.Mock(
38+
post_mock = mock.Mock(
3239
return_value=b"""[{
3340
"generated_text": "the model response",
3441
"details": {
@@ -40,9 +47,11 @@ def test_nonstreaming_chat_completion(
4047
}]"""
4148
)
4249
else:
43-
client.post = mock.Mock(
50+
post_mock = mock.Mock(
4451
return_value=b'[{"generated_text": "the model response"}]'
4552
)
53+
mock_client_post(client, post_mock)
54+
4655
with start_transaction(name="huggingface_hub tx"):
4756
response = client.text_generation(
4857
prompt="hello",
@@ -84,7 +93,8 @@ def test_streaming_chat_completion(
8493
events = capture_events()
8594

8695
client = InferenceClient("some-model")
87-
client.post = mock.Mock(
96+
97+
post_mock = mock.Mock(
8898
return_value=[
8999
b"""data:{
90100
"token":{"id":1, "special": false, "text": "the model "}
@@ -95,6 +105,8 @@ def test_streaming_chat_completion(
95105
}""",
96106
]
97107
)
108+
mock_client_post(client, post_mock)
109+
98110
with start_transaction(name="huggingface_hub tx"):
99111
response = list(
100112
client.text_generation(
@@ -131,7 +143,9 @@ def test_bad_chat_completion(sentry_init, capture_events):
131143
events = capture_events()
132144

133145
client = InferenceClient("some-model")
134-
client.post = mock.Mock(side_effect=OverloadedError("The server is overloaded"))
146+
post_mock = mock.Mock(side_effect=OverloadedError("The server is overloaded"))
147+
mock_client_post(client, post_mock)
148+
135149
with pytest.raises(OverloadedError):
136150
client.text_generation(prompt="hello")
137151

@@ -147,13 +161,15 @@ def test_span_origin(sentry_init, capture_events):
147161
events = capture_events()
148162

149163
client = InferenceClient("some-model")
150-
client.post = mock.Mock(
164+
post_mock = mock.Mock(
151165
return_value=[
152166
b"""data:{
153167
"token":{"id":1, "special": false, "text": "the model "}
154168
}""",
155169
]
156170
)
171+
mock_client_post(client, post_mock)
172+
157173
with start_transaction(name="huggingface_hub tx"):
158174
list(
159175
client.text_generation(

0 commit comments

Comments
 (0)