3535
3636
@pytest.fixture
def mock_hf_text_generation_api(httpx_mock):
    # type: (Any) -> Any
    """Mock the HuggingFace text generation API.

    On huggingface_hub >= 1.0.0 the client talks HTTP via httpx, so the
    endpoints are mocked with pytest-httpx's ``httpx_mock``; on older
    versions the client uses ``requests``, so the ``responses`` library is
    used instead. Yields the active mock so tests can inspect recorded calls.
    """
    model_name = "test-model"

    # Payload returned by the model-info endpoint (GET).
    model_endpoint_response = {
        "id": model_name,
        "pipeline_tag": "text-generation",
        "inferenceProviderMapping": {
            "hf-inference": {
                "status": "live",
                "providerId": model_name,
                "task": "text-generation",
            }
        },
    }
    # Payload returned by the text-generation inference endpoint (POST).
    inference_endpoint_response = {
        "generated_text": "[mocked] Hello! How can i help you?",
        "details": {
            "finish_reason": "length",
            "generated_tokens": 10,
            "prefill": [],
            "tokens": [],
        },
    }

    if HF_VERSION >= (1, 0, 0):
        # Mock model info endpoint. The client may append
        # "?expand=inferenceProviderMapping", so match the URL with an
        # optional query suffix; the base URL is escaped so "." and other
        # regex metacharacters in it match literally.
        httpx_mock.add_response(
            method="GET",
            url=re.compile(
                re.escape(MODEL_ENDPOINT.format(model_name=model_name))
                + r"(\?expand=inferenceProviderMapping)?"
            ),
            json=model_endpoint_response,
            status_code=200,
            # Tests may hit these endpoints zero or multiple times.
            is_optional=True,
            is_reusable=True,
        )

        # Mock text generation endpoint
        httpx_mock.add_response(
            method="POST",
            url=INFERENCE_ENDPOINT.format(model_name=model_name),
            json=inference_endpoint_response,
            status_code=200,
            is_optional=True,
            is_reusable=True,
        )

        yield httpx_mock

    else:
        # Older version of huggingface_hub, we need to mock requests
        with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
            # Mock model info endpoint
            rsps.add(
                responses.GET,
                MODEL_ENDPOINT.format(model_name=model_name),
                json=model_endpoint_response,
                status=200,
            )

            # Mock text generation endpoint
            rsps.add(
                responses.POST,
                INFERENCE_ENDPOINT.format(model_name=model_name),
                json=inference_endpoint_response,
                status=200,
            )

            yield rsps
130111
131112@pytest .fixture
@@ -426,6 +407,7 @@ def mock_hf_chat_completion_api_streaming_tools():
426407 yield rsps
427408
428409
410+ @pytest .mark .httpx_mock (assert_all_requests_were_expected = False )
429411@pytest .mark .parametrize ("send_default_pii" , [True , False ])
430412@pytest .mark .parametrize ("include_prompts" , [True , False ])
431413def test_text_generation (
@@ -453,11 +435,21 @@ def test_text_generation(
453435 )
454436
455437 (transaction ,) = events
456- (span ,) = transaction ["spans" ]
457438
458- assert span ["op" ] == "gen_ai.generate_text"
459- assert span ["description" ] == "generate_text test-model"
460- assert span ["origin" ] == "auto.ai.huggingface_hub"
439+ gen_ai_span = None
440+ for span in transaction ["spans" ]:
441+ if span ["op" ] == "gen_ai.generate_text" :
442+ gen_ai_span = span
443+ else :
444+ # there should be no other spans, just the gen_ai.generate_text span
445+ # and optionally some http.client spans from talking to the hf api
446+ assert span ["op" ] == "http.client"
447+
448+ assert gen_ai_span is not None
449+
450+ assert gen_ai_span ["op" ] == "gen_ai.generate_text"
451+ assert gen_ai_span ["description" ] == "generate_text test-model"
452+ assert gen_ai_span ["origin" ] == "auto.ai.huggingface_hub"
461453
462454 expected_data = {
463455 "gen_ai.operation.name" : "generate_text" ,
@@ -477,67 +469,10 @@ def test_text_generation(
477469 assert "gen_ai.request.messages" not in expected_data
478470 assert "gen_ai.response.text" not in expected_data
479471
480- assert span ["data" ] == expected_data
481-
482- # text generation does not set the response model
483- assert "gen_ai.response.model" not in span ["data" ]
484-
485-
486- @pytest .mark .httpx_mock (assert_all_requests_were_expected = False )
487- def test_text_generation_TODO (
488- sentry_init ,
489- capture_events ,
490- mock_hf_text_generation_api_httpx ,
491- ):
492- # type: (Any, Any, Any) -> None
493- sentry_init (
494- traces_sample_rate = 1.0 ,
495- send_default_pii = True ,
496- integrations = [HuggingfaceHubIntegration (include_prompts = True )],
497- )
498- events = capture_events ()
499-
500- client = InferenceClient (model = "test-model" )
501-
502- with sentry_sdk .start_transaction (name = "test" ):
503- client .text_generation (
504- "Hello" ,
505- stream = False ,
506- details = True ,
507- )
508-
509- (transaction ,) = events
510-
511- # Find the huggingface_hub span (there might be httpx spans too)
512- hf_spans = [
513- span for span in transaction ["spans" ] if span ["op" ] == "gen_ai.generate_text"
514- ]
515- assert len (hf_spans ) == 1 , (
516- f"Expected 1 huggingface span, got { len (hf_spans )} : { [s ['op' ] for s in transaction ['spans' ]]} "
517- )
518- span = hf_spans [0 ]
519-
520- assert span ["op" ] == "gen_ai.generate_text"
521- assert span ["description" ] == "generate_text test-model"
522- assert span ["origin" ] == "auto.ai.huggingface_hub"
523-
524- expected_data = {
525- "gen_ai.operation.name" : "generate_text" ,
526- "gen_ai.request.model" : "test-model" ,
527- "gen_ai.response.finish_reasons" : "length" ,
528- "gen_ai.response.streaming" : False ,
529- "gen_ai.usage.total_tokens" : 10 ,
530- "thread.id" : mock .ANY ,
531- "thread.name" : mock .ANY ,
532- }
533-
534- expected_data ["gen_ai.request.messages" ] = "Hello"
535- expected_data ["gen_ai.response.text" ] = "[mocked] Hello! How can i help you?"
536-
537- assert span ["data" ] == expected_data
472+ assert gen_ai_span ["data" ] == expected_data
538473
539474 # text generation does not set the response model
540- assert "gen_ai.response.model" not in span ["data" ]
475+ assert "gen_ai.response.model" not in gen_ai_span ["data" ]
541476
542477
543478@pytest .mark .parametrize ("send_default_pii" , [True , False ])
0 commit comments