@@ -1,6 +1,8 @@
from unittest import mock
import pytest
+import re
import responses
+import httpx

from huggingface_hub import InferenceClient

@@ -76,6 +78,56 @@ def mock_hf_text_generation_api():
    yield rsps


+@pytest.fixture
+def mock_hf_text_generation_api_httpx(httpx_mock):
+    # type: (Any) -> Any
+    """Mock HuggingFace text generation API for httpx"""
+    model_name = "test-model"
+
+    # Mock the model info endpoint, allowing multiple calls; the URL is
+    # matched with a regex so that query parameters don't break the match
+    model_url_pattern = re.compile(
+        re.escape(MODEL_ENDPOINT.format(model_name=model_name)) + r"(\?.*)?$"
+    )
+
+    # The client calls the model info endpoint twice, so register exactly two responses
+    for _ in range(2):
+        httpx_mock.add_response(
+            method="GET",
+            url=model_url_pattern,
+            json={
+                "id": model_name,
+                "pipeline_tag": "text-generation",
+                "inferenceProviderMapping": {
+                    "hf-inference": {
+                        "status": "live",
+                        "providerId": model_name,
+                        "task": "text-generation",
+                    }
+                },
+            },
+            status_code=200,
+        )
+
+    # Mock text generation endpoint
+    httpx_mock.add_response(
+        method="POST",
+        url=INFERENCE_ENDPOINT.format(model_name=model_name),
+        json={
+            "generated_text": "[mocked] Hello! How can I help you?",
+            "details": {
+                "finish_reason": "length",
+                "generated_tokens": 10,
+                "prefill": [],
+                "tokens": [],
+            },
+        },
+        status_code=200,
+    )
+
+    return httpx_mock
+
+
@pytest.fixture
def mock_hf_api_with_errors():
    # type: () -> Any
@@ -431,6 +483,63 @@ def test_text_generation(
    assert "gen_ai.response.model" not in span["data"]


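+# The client may issue requests beyond the ones mocked in the fixture,
+# so don't fail the test when a request has no registered response: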
+@pytest.mark.httpx_mock(assert_all_requests_were_expected=False)
+def test_text_generation_httpx(
+    sentry_init,
+    capture_events,
+    mock_hf_text_generation_api_httpx,
+):
+    # type: (Any, Any, Any) -> None
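+    """Same as test_text_generation, but with the HTTP layer mocked via httpx."""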
+    sentry_init(
+        traces_sample_rate=1.0,
+        send_default_pii=True,
+        integrations=[HuggingfaceHubIntegration(include_prompts=True)],
+    )
+    events = capture_events()
+
+    client = InferenceClient(model="test-model")
+
+    with sentry_sdk.start_transaction(name="test"):
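+        # details=True requests the detailed response (finish_reason,
+        # token counts) that the fixture's mocked payload provides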
+        client.text_generation(
+            "Hello",
+            stream=False,
+            details=True,
+        )
+
+    (transaction,) = events
+
+    # Find the huggingface_hub span (there might be httpx spans too)
+    hf_spans = [
+        span for span in transaction["spans"] if span["op"] == "gen_ai.generate_text"
+    ]
+    assert len(hf_spans) == 1, (
+        f"Expected 1 huggingface span, got {len(hf_spans)}: {[s['op'] for s in transaction['spans']]}"
+    )
+    span = hf_spans[0]
+
+    assert span["op"] == "gen_ai.generate_text"
+    assert span["description"] == "generate_text test-model"
+    assert span["origin"] == "auto.ai.huggingface_hub"
+
+    expected_data = {
+        "gen_ai.operation.name": "generate_text",
+        "gen_ai.request.model": "test-model",
+        "gen_ai.response.finish_reasons": "length",
+        "gen_ai.response.streaming": False,
+        "gen_ai.usage.total_tokens": 10,
+        "thread.id": mock.ANY,
+        "thread.name": mock.ANY,
+    }
+
+    expected_data["gen_ai.request.messages"] = "Hello"
+    expected_data["gen_ai.response.text"] = "[mocked] Hello! How can I help you?"
+
+    assert span["data"] == expected_data
+
+    # text generation does not set the response model
+    assert "gen_ai.response.model" not in span["data"]
+
+
@pytest.mark.parametrize("send_default_pii", [True, False])
@pytest.mark.parametrize("include_prompts", [True, False])
def test_text_generation_streaming(