Commit ea97f56 (parent 8a8bc5e)

    fix up test_text_generation

1 file changed: tests/integrations/huggingface_hub/test_huggingface_hub.py (+77, -142)
@@ -35,97 +35,78 @@
 
 
 @pytest.fixture
-def mock_hf_text_generation_api():
+def mock_hf_text_generation_api(httpx_mock):
     # type: () -> Any
     """Mock HuggingFace text generation API"""
-    with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
-        model_name = "test-model"
-
-        # Mock model info endpoint
-        rsps.add(
-            responses.GET,
-            MODEL_ENDPOINT.format(model_name=model_name),
-            json={
-                "id": model_name,
-                "pipeline_tag": "text-generation",
-                "inferenceProviderMapping": {
-                    "hf-inference": {
-                        "status": "live",
-                        "providerId": model_name,
-                        "task": "text-generation",
-                    }
-                },
-            },
-            status=200,
-        )
-
-        # Mock text generation endpoint
-        rsps.add(
-            responses.POST,
-            INFERENCE_ENDPOINT.format(model_name=model_name),
-            json={
-                "generated_text": "[mocked] Hello! How can i help you?",
-                "details": {
-                    "finish_reason": "length",
-                    "generated_tokens": 10,
-                    "prefill": [],
-                    "tokens": [],
-                },
-            },
-            status=200,
-        )
-
-        yield rsps
-
-
-@pytest.fixture
-def mock_hf_text_generation_api_httpx(httpx_mock):
-    # type: (Any) -> Any
-    """Mock HuggingFace text generation API for httpx"""
-    model_name = "test-model"
-
-    # Mock model info endpoint (with query parameters) - allow multiple calls
-    # Using pattern matching to handle query parameters
-    model_url_pattern = re.compile(
-        re.escape(MODEL_ENDPOINT.format(model_name=model_name)) + r"(\?.*)?$"
-    )
-
-    # Add exactly the number of responses we expect (2 calls observed from debug output)
-    for _ in range(2):  # Allow exactly 2 calls to the model info endpoint
-        httpx_mock.add_response(
-            method="GET",
-            url=model_url_pattern,
-            json={
-                "id": model_name,
-                "pipeline_tag": "text-generation",
-                "inferenceProviderMapping": {
-                    "hf-inference": {
-                        "status": "live",
-                        "providerId": model_name,
-                        "task": "text-generation",
-                    }
-                },
-            },
-            status_code=200,
-        )
-
-        # Mock text generation endpoint
-        httpx_mock.add_response(
-            method="POST",
-            url=INFERENCE_ENDPOINT.format(model_name=model_name),
-            json={
-                "generated_text": "[mocked] Hello! How can i help you?",
-                "details": {
-                    "finish_reason": "length",
-                    "generated_tokens": 10,
-                    "prefill": [],
-                    "tokens": [],
-                },
-            },
-            status_code=200,
-        )
-
-    return httpx_mock
+    model_name = "test-model"
+
+    model_endpoint_response = {
+        "id": model_name,
+        "pipeline_tag": "text-generation",
+        "inferenceProviderMapping": {
+            "hf-inference": {
+                "status": "live",
+                "providerId": model_name,
+                "task": "text-generation",
+            }
+        },
+    }
+    inference_endpoint_response = {
+        "generated_text": "[mocked] Hello! How can i help you?",
+        "details": {
+            "finish_reason": "length",
+            "generated_tokens": 10,
+            "prefill": [],
+            "tokens": [],
+        },
+    }
+
+    if HF_VERSION >= (1, 0, 0):
+        # Mock model info endpoint
+        httpx_mock.add_response(
+            method="GET",
+            url=re.compile(
+                MODEL_ENDPOINT.format(model_name=model_name)
+                + r"(\?expand=inferenceProviderMapping)?"
+            ),
+            json=model_endpoint_response,
+            status_code=200,
+            is_optional=True,
+            is_reusable=True,
+        )
+
+        # Mock text generation endpoint
+        httpx_mock.add_response(
+            method="POST",
+            url=INFERENCE_ENDPOINT.format(model_name=model_name),
+            json=inference_endpoint_response,
+            status_code=200,
+            is_optional=True,
+            is_reusable=True,
+        )
+
+        yield httpx_mock
+
+    else:
+        # Older version of huggingface_hub, we need to mock requests
+        with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+            # Mock model info endpoint
+            rsps.add(
+                responses.GET,
+                MODEL_ENDPOINT.format(model_name=model_name),
+                json=model_endpoint_response,
+                status=200,
+            )
+
+            # Mock text generation endpoint
+            rsps.add(
+                responses.POST,
+                INFERENCE_ENDPOINT.format(model_name=model_name),
+                json=inference_endpoint_response,
+                status=200,
+            )
+
+            yield rsps
 
 
 @pytest.fixture
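
Note on the pytest-httpx flags the unified fixture relies on: is_reusable=True lets a single registered response match any number of requests, and is_optional=True keeps teardown from failing when a registered response is never requested. A minimal sketch, assuming pytest-httpx is installed; the URL and payload are made up for illustration:

import httpx

def test_reusable_optional_response(httpx_mock):
    httpx_mock.add_response(
        method="GET",
        url="https://example.test/model",  # hypothetical URL, not from the commit
        json={"id": "test-model"},
        is_optional=True,   # no failure if this mock is never hit
        is_reusable=True,   # may serve more than one matching request
    )

    with httpx.Client() as client:
        # Both calls are served by the single registered response.
        assert client.get("https://example.test/model").json()["id"] == "test-model"
        assert client.get("https://example.test/model").json()["id"] == "test-model"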
@@ -426,6 +407,7 @@ def mock_hf_chat_completion_api_streaming_tools():
         yield rsps
 
 
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
 @pytest.mark.parametrize("send_default_pii", [True, False])
 @pytest.mark.parametrize("include_prompts", [True, False])
 def test_text_generation(
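
The new marker loosens pytest-httpx's teardown checks for this test. A sketch of the effect, under the assumption (true for current pytest-httpx) that an unmatched request raises httpx.TimeoutException at call time; without the marker, even a caught unmatched request would additionally fail the test at teardown. The URLs are invented for illustration:

import httpx
import pytest

@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
def test_unmatched_request_tolerated(httpx_mock):
    httpx_mock.add_response(url="https://example.test/known", json={"ok": True})

    with httpx.Client() as client:
        assert client.get("https://example.test/known").json() == {"ok": True}
        # Unmatched requests still raise, but the marker stops pytest-httpx
        # from failing the test afterwards because of them.
        with pytest.raises(httpx.TimeoutException):
            client.get("https://example.test/unknown")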
@@ -453,11 +435,21 @@ def test_text_generation(
         )
 
     (transaction,) = events
-    (span,) = transaction["spans"]
 
-    assert span["op"] == "gen_ai.generate_text"
-    assert span["description"] == "generate_text test-model"
-    assert span["origin"] == "auto.ai.huggingface_hub"
+    gen_ai_span = None
+    for span in transaction["spans"]:
+        if span["op"] == "gen_ai.generate_text":
+            gen_ai_span = span
+        else:
+            # there should be no other spans, just the gen_ai.generate_text span
+            # and optionally some http.client spans from talking to the hf api
+            assert span["op"] == "http.client"
+
+    assert gen_ai_span is not None
+
+    assert gen_ai_span["op"] == "gen_ai.generate_text"
+    assert gen_ai_span["description"] == "generate_text test-model"
+    assert gen_ai_span["origin"] == "auto.ai.huggingface_hub"
 
     expected_data = {
         "gen_ai.operation.name": "generate_text",
@@ -477,67 +469,10 @@ def test_text_generation(
     assert "gen_ai.request.messages" not in expected_data
     assert "gen_ai.response.text" not in expected_data
 
-    assert span["data"] == expected_data
-
-    # text generation does not set the response model
-    assert "gen_ai.response.model" not in span["data"]
-
-
-@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
-def test_text_generation_TODO(
-    sentry_init,
-    capture_events,
-    mock_hf_text_generation_api_httpx,
-):
-    # type: (Any, Any, Any) -> None
-    sentry_init(
-        traces_sample_rate=1.0,
-        send_default_pii=True,
-        integrations=[HuggingfaceHubIntegration(include_prompts=True)],
-    )
-    events = capture_events()
-
-    client = InferenceClient(model="test-model")
-
-    with sentry_sdk.start_transaction(name="test"):
-        client.text_generation(
-            "Hello",
-            stream=False,
-            details=True,
-        )
-
-    (transaction,) = events
-
-    # Find the huggingface_hub span (there might be httpx spans too)
-    hf_spans = [
-        span for span in transaction["spans"] if span["op"] == "gen_ai.generate_text"
-    ]
-    assert len(hf_spans) == 1, (
-        f"Expected 1 huggingface span, got {len(hf_spans)}: {[s['op'] for s in transaction['spans']]}"
-    )
-    span = hf_spans[0]
-
-    assert span["op"] == "gen_ai.generate_text"
-    assert span["description"] == "generate_text test-model"
-    assert span["origin"] == "auto.ai.huggingface_hub"
-
-    expected_data = {
-        "gen_ai.operation.name": "generate_text",
-        "gen_ai.request.model": "test-model",
-        "gen_ai.response.finish_reasons": "length",
-        "gen_ai.response.streaming": False,
-        "gen_ai.usage.total_tokens": 10,
-        "thread.id": mock.ANY,
-        "thread.name": mock.ANY,
-    }
-
-    expected_data["gen_ai.request.messages"] = "Hello"
-    expected_data["gen_ai.response.text"] = "[mocked] Hello! How can i help you?"
-
-    assert span["data"] == expected_data
+    assert gen_ai_span["data"] == expected_data
 
     # text generation does not set the response model
-    assert "gen_ai.response.model" not in span["data"]
+    assert "gen_ai.response.model" not in gen_ai_span["data"]
 
 
 @pytest.mark.parametrize("send_default_pii", [True, False])
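
On the HF_VERSION gate used by the fixture: per the fixture's own comment, huggingface_hub 1.x clients issue httpx requests while older releases use the requests library, so the right mocking backend depends on the installed version. A hedged sketch of how such a gate can be derived; the parsing below is illustrative, not the test suite's actual helper:

from huggingface_hub import __version__ as hf_version_str

# Turn "1.0.0" (or "0.34.4") into a comparable tuple like (1, 0, 0),
# ignoring non-numeric suffixes such as "dev0".
HF_VERSION = tuple(int(part) for part in hf_version_str.split(".")[:3] if part.isdigit())

if HF_VERSION >= (1, 0, 0):
    # 1.x clients go through httpx, so pytest-httpx is the right mock.
    print("mock with pytest-httpx")
else:
    # pre-1.0 clients go through requests, so the responses library applies.
    print("mock with the responses library")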
