Skip to content

Commit f741df8

Browse files
authored
fix: Update flaky HugginFace Generator tests to use more reliable model and add instruction tokens (#8980)
* Fix test * Make other HF tests more reliable * Add back test
1 parent ec97f4d commit f741df8

File tree

3 files changed

+43
-14
lines changed

3 files changed

+43
-14
lines changed

test/components/generators/chat/test_hugging_face_api.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -570,11 +570,15 @@ def test_run_with_tools(self, mock_check_valid_model, tools):
570570
def test_live_run_serverless(self):
571571
generator = HuggingFaceAPIChatGenerator(
572572
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
573-
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
573+
api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"},
574574
generation_kwargs={"max_tokens": 20},
575575
)
576576

577-
messages = [ChatMessage.from_user("What is the capital of France?")]
577+
# No need for instruction tokens here since we use the chat_completion endpoint which handles the chat
578+
# templating for us.
579+
messages = [
580+
ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.")
581+
]
578582
response = generator.run(messages=messages)
579583

580584
assert "replies" in response
@@ -594,12 +598,16 @@ def test_live_run_serverless(self):
594598
def test_live_run_serverless_streaming(self):
595599
generator = HuggingFaceAPIChatGenerator(
596600
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
597-
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
601+
api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"},
598602
generation_kwargs={"max_tokens": 20},
599603
streaming_callback=streaming_callback_handler,
600604
)
601605

602-
messages = [ChatMessage.from_user("What is the capital of France?")]
606+
# No need for instruction tokens here since we use the chat_completion endpoint which handles the chat
607+
# templating for us.
608+
messages = [
609+
ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.")
610+
]
603611
response = generator.run(messages=messages)
604612

605613
assert "replies" in response
@@ -817,11 +825,15 @@ async def test_run_async_with_tools(self, tools, mock_check_valid_model):
817825
async def test_live_run_async_serverless(self):
818826
generator = HuggingFaceAPIChatGenerator(
819827
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
820-
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
828+
api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"},
821829
generation_kwargs={"max_tokens": 20},
822830
)
823831

824-
messages = [ChatMessage.from_user("What is the capital of France?")]
832+
# No need for instruction tokens here since we use the chat_completion endpoint which handles the chat
833+
# templating for us.
834+
messages = [
835+
ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.")
836+
]
825837
response = await generator.run_async(messages=messages)
826838

827839
assert "replies" in response

test/components/generators/test_hugging_face_api.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -298,15 +298,20 @@ def mock_iter(self):
298298
def test_run_serverless(self):
299299
generator = HuggingFaceAPIGenerator(
300300
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
301-
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
301+
api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"},
302302
generation_kwargs={"max_new_tokens": 20},
303303
)
304304

305-
response = generator.run("How are you?")
305+
# You must include the instruction tokens in the prompt. HF does not add them automatically.
306+
# Without them the model will behave erratically.
307+
response = generator.run(
308+
"<s>[INST] What is the capital of France? Be concise only provide the capital, nothing else.[/INST]"
309+
)
310+
306311
# Assert that the response contains the generated replies
307312
assert "replies" in response
308313
assert isinstance(response["replies"], list)
309-
assert len(response["replies"]) > 0
314+
assert len(response["replies"]) == 1
310315
assert [isinstance(reply, str) for reply in response["replies"]]
311316

312317
# Assert that the response contains the metadata
@@ -317,7 +322,10 @@ def test_run_serverless(self):
317322

318323
@pytest.mark.flaky(reruns=5, reruns_delay=5)
319324
@pytest.mark.integration
320-
@pytest.mark.skip(reason="Temporarily skipped due to weird responses from the selected model.")
325+
@pytest.mark.skipif(
326+
not os.environ.get("HF_API_TOKEN", None),
327+
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
328+
)
321329
def test_live_run_streaming_check_completion_start_time(self):
322330
generator = HuggingFaceAPIGenerator(
323331
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
@@ -328,10 +336,13 @@ def test_live_run_streaming_check_completion_start_time(self):
328336

329337
results = generator.run("You are a helpful agent that answers questions. What is the capital of France?")
330338

339+
# Assert that the response contains the generated replies
340+
assert "replies" in results
341+
assert isinstance(results["replies"], list)
331342
assert len(results["replies"]) == 1
332-
assert "Paris" in results["replies"][0]
343+
assert [isinstance(reply, str) for reply in results["replies"]]
333344

334345
# Verify completion start time in final metadata
335346
assert "completion_start_time" in results["meta"][0]
336347
completion_start = datetime.fromisoformat(results["meta"][0]["completion_start_time"])
337-
assert completion_start <= datetime.now()
348+
assert completion_start is not None

test/components/generators/test_hugging_face_local_generator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88
import torch
9-
from transformers import PreTrainedTokenizerFast
9+
from transformers import AutoTokenizer, PreTrainedTokenizerFast
1010

1111
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
1212
from haystack.utils import ComponentDevice
@@ -472,7 +472,13 @@ def test_live_run(self, monkeypatch):
472472
llm = HuggingFaceLocalGenerator(model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50})
473473
llm.warm_up()
474474

475-
result = llm.run(prompt="Please create a summary about the following topic: Climate change")
475+
# You must use the `apply_chat_template` method to add the generation prompt to properly include the instruction
476+
# tokens in the prompt. Otherwise, the model will not generate the expected output.
477+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
478+
messages = [{"role": "user", "content": "Please repeat the phrase 'climate change' and nothing else"}]
479+
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
480+
481+
result = llm.run(prompt=prompt)
476482

477483
assert "replies" in result
478484
assert isinstance(result["replies"][0], str)

0 commit comments

Comments
 (0)