
Commit 4c9d08a

davidsbatista, Amnah199, sjrl, anakin87, and HaystackBot authored
feat: async support for the HuggingFaceLocalChatGenerator (#8981)
* adding async run method
* passing an optional ThreadPoolExecutor
* adding tests
* adding release notes
* nit: license
* fixing linting
* Update releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml
  Co-authored-by: Amna Mubashar <[email protected]>
* Use Phi instead (#8982)
* build: drop Python 3.8 support (#8978)
* draft
* readd typing_extensions
* small fix + release note
* remove ruff target-version
* Update releasenotes/notes/drop-python-3.8-868710963e794c83.yaml
  Co-authored-by: David S. Batista <[email protected]>
  ---------
  Co-authored-by: David S. Batista <[email protected]>
* Update unstable version to 2.12.0-rc0 (#8983)
  Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* fix: allow support for `include_usage` in streaming using OpenAIChatGenerator (#8968)
* fix error in handling usage completion chunk
* ci: improve release notes format checking (#8984)
* chore: fix invalid release note
* try improving relnote linting
* add relnotes path
* fix bad release note
* improve reno config
* fix: handle async tests in `HuggingFaceAPIChatGenerator` to prevent error (#8986)
* add missing asyncio
* explicitly close connection in the test
* Fix tests (#8990)
* docs: Update docstrings of `BranchJoiner` (#8988)
* Update docstrings
* Add a bit more explanatory text
* Add reno
* Update haystack/components/joiners/branch.py
  Co-authored-by: Daria Fokina <[email protected]>
* Update haystack/components/joiners/branch.py
  Co-authored-by: Daria Fokina <[email protected]>
* Update haystack/components/joiners/branch.py
  Co-authored-by: Daria Fokina <[email protected]>
* Update haystack/components/joiners/branch.py
  Co-authored-by: Daria Fokina <[email protected]>
* Fix formatting
  ---------
  Co-authored-by: Daria Fokina <[email protected]>
* PR comments
* destroying ThreadPoolExecutor when the generator instance is being destroyed, only if it was not passed externally
* fixing bug in streaming_callback
* PR comments

---------

Co-authored-by: Amna Mubashar <[email protected]>
Co-authored-by: Sebastian Husch Lee <[email protected]>
Co-authored-by: Stefano Fiorucci <[email protected]>
Co-authored-by: Haystack Bot <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Daria Fokina <[email protected]>
1 parent c4fafd9 commit 4c9d08a
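A minimal usage sketch of the asynchronous entry point this commit adds; the model name and prompt are illustrative placeholders, and actually running it requires a local transformers model and an event loop:

# Sketch only: model name and prompt are placeholders, not part of this commit.
import asyncio

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage


async def main():
    generator = HuggingFaceLocalChatGenerator(model="microsoft/Phi-3.5-mini-instruct")
    generator.warm_up()  # load the local pipeline before calling run() or run_async()
    result = await generator.run_async(messages=[ChatMessage.from_user("Explain asyncio in one sentence.")])
    print(result["replies"][0].text)


asyncio.run(main())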

File tree

3 files changed: +260 -6 lines changed


haystack/components/generators/chat/hugging_face_local.py

Lines changed: 176 additions & 3 deletions
@@ -2,13 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import asyncio
 import json
 import re
 import sys
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils import (
@@ -123,6 +125,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         tools: Optional[List[Tool]] = None,
         tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
+        async_executor: Optional[ThreadPoolExecutor] = None,
     ):
         """
         Initializes the HuggingFaceLocalChatGenerator component.
@@ -165,6 +168,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
         :param tool_parsing_function:
             A callable that takes a string and returns a list of ToolCall objects or None.
             If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
+        :param async_executor:
+            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
+            initialized and used
         """
         torch_and_transformers_import.check()
 
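A short sketch of how the `async_executor` parameter documented above might be used to share one caller-managed executor across several generator instances; the model names and worker count are illustrative assumptions:

# Sketch: a caller-owned executor shared by two generators; the caller shuts it down.
from concurrent.futures import ThreadPoolExecutor

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

shared_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="hf-local-generators")

summarizer = HuggingFaceLocalChatGenerator(model="some-summarization-model", async_executor=shared_executor)
translator = HuggingFaceLocalChatGenerator(model="some-translation-model", async_executor=shared_executor)

# ... warm_up() and run_async() calls ...

shared_executor.shutdown(wait=True)  # generators never shut down an executor they did not create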
@@ -223,6 +229,27 @@ def __init__( # pylint: disable=too-many-positional-arguments
         self.pipeline = None
         self.tools = tools
 
+        self._owns_executor = async_executor is None
+        self.executor = (
+            ThreadPoolExecutor(thread_name_prefix=f"async-HFLocalChatGenerator-executor-{id(self)}", max_workers=1)
+            if async_executor is None
+            else async_executor
+        )
+
+    def __del__(self):
+        """
+        Cleanup when the instance is being destroyed.
+        """
+        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
+            self.executor.shutdown(wait=True)
+
+    def shutdown(self):
+        """
+        Explicitly shutdown the executor if we own it.
+        """
+        if self._owns_executor:
+            self.executor.shutdown(wait=True)
+
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
         Data that is sent to Posthog for usage analytics.
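Because `__del__` only runs when the instance is finally garbage-collected, the new `shutdown()` method allows more deterministic cleanup of the internally created executor; a minimal sketch of that pattern (the model name is a placeholder):

# Sketch: explicit cleanup of the generator's own single-worker executor.
generator = HuggingFaceLocalChatGenerator(model="some-local-model")
generator.warm_up()
try:
    ...  # run() / run_async() calls
finally:
    generator.shutdown()  # only shuts the executor down if the generator created it itself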
@@ -332,7 +359,9 @@ def run(
         if stop_words_criteria:
             generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
 
-        streaming_callback = streaming_callback or self.streaming_callback
+        streaming_callback = select_streaming_callback(
+            self.streaming_callback, streaming_callback, requires_async=False
+        )
         if streaming_callback:
             num_responses = generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
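The switch to `select_streaming_callback` with `requires_async=False` means the synchronous `run` keeps accepting a plain callable, while the async path introduced further down validates for an async-compatible callback. A small sketch of the two callback shapes under that assumption (the callback bodies are illustrative):

# Sketch: a synchronous callback for run() and a coroutine callback for run_async().
from haystack.dataclasses import StreamingChunk


def on_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)


async def on_chunk_async(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)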
@@ -427,7 +456,8 @@ def create_message( # pylint: disable=too-many-positional-arguments
         # If tool calls are detected, don't include the text content since it contains the raw tool call format
         return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)
 
-    def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
+    @staticmethod
+    def _validate_stop_words(stop_words: Optional[List[str]]) -> Optional[List[str]]:
         """
         Validates the provided stop words.
 
@@ -443,3 +473,146 @@ def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List
             return None
 
         return list(set(stop_words or []))
+
+    @component.output_types(replies=List[ChatMessage])
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
+        """
+        Asynchronously invokes text generation inference based on the provided messages and generation parameters.
+
+        This is the asynchronous version of the `run` method. It has the same parameters
+        and return values but can be used with `await` in an async code.
+
+        :param messages: A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs: Additional keyword arguments for text generation.
+        :param streaming_callback: An optional callable for handling streaming responses.
+        :param tools: A list of tools for which the model can prepare calls.
+        :returns: A dictionary with the following keys:
+            - `replies`: A list containing the generated responses as ChatMessage instances.
+        """
+        if self.pipeline is None:
+            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")
+
+        tools = tools or self.tools
+        if tools and streaming_callback is not None:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)
+
+        tokenizer = self.pipeline.tokenizer
+
+        # Check and update generation parameters
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
+        stop_words = self._validate_stop_words(stop_words)
+
+        # Set up stop words criteria if stop words exist
+        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
+        if stop_words_criteria:
+            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])
+
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
+
+        if streaming_callback:
+            return await self._run_streaming_async(
+                messages, tokenizer, generation_kwargs, stop_words, streaming_callback
+            )
+
+        return await self._run_non_streaming_async(messages, tokenizer, generation_kwargs, stop_words, tools)
+
+    async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
+        self,
+        messages: List[ChatMessage],
+        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
+        generation_kwargs: Dict[str, Any],
+        stop_words: Optional[List[str]],
+        streaming_callback: Callable[[StreamingChunk], None],
+    ):
+        """
+        Handles async streaming generation of responses.
+        """
+        # convert messages to HF format
+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+        prepared_prompt = tokenizer.apply_chat_template(
+            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
+        )
+
+        # Avoid some unnecessary warnings in the generation pipeline call
+        generation_kwargs["pad_token_id"] = (
+            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
+        )
+
+        # Set up streaming handler
+        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+
+        # Generate responses asynchronously
+        output = await asyncio.get_running_loop().run_in_executor(
+            self.executor,
+            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore  # if self.executor was not passed it was initialized with max_workers=1 in init
+        )
+
+        replies = [o.get("generated_text", "") for o in output]
+
+        # Remove stop words from replies if present
+        for stop_word in stop_words or []:
+            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+
+        chat_messages = [
+            self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False)
+            for r_index, reply in enumerate(replies)
+        ]
+
+        return {"replies": chat_messages}
+
+    async def _run_non_streaming_async(  # pylint: disable=too-many-positional-arguments
+        self,
+        messages: List[ChatMessage],
+        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
+        generation_kwargs: Dict[str, Any],
+        stop_words: Optional[List[str]],
+        tools: Optional[List[Tool]] = None,
+    ):
+        """
+        Handles async non-streaming generation of responses.
+        """
+        # convert messages to HF format
+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+        prepared_prompt = tokenizer.apply_chat_template(
+            hf_messages,
+            tokenize=False,
+            chat_template=self.chat_template,
+            add_generation_prompt=True,
+            tools=[tc.tool_spec for tc in tools] if tools else None,
+        )
+
+        # Avoid some unnecessary warnings in the generation pipeline call
+        generation_kwargs["pad_token_id"] = (
+            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
+        )
+
+        # Generate responses asynchronously
+        output = await asyncio.get_running_loop().run_in_executor(
+            self.executor,
+            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore  # if self.executor was not passed it was initialized with max_workers=1 in init
+        )
+
+        replies = [o.get("generated_text", "") for o in output]
+
+        # Remove stop words from replies if present
+        for stop_word in stop_words or []:
+            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]
+
+        chat_messages = [
+            self.create_message(
+                reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
+            )
+            for r_index, reply in enumerate(replies)
+        ]
+
+        return {"replies": chat_messages}
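Since both async helpers offload the blocking pipeline call with `run_in_executor`, several `run_async` calls can be awaited concurrently; a sketch mirroring the concurrency test further down (the prompts are placeholders):

# Sketch: fan out several generations over one already warmed-up generator.
import asyncio

from haystack.dataclasses import ChatMessage


async def generate_many(generator, prompts):
    tasks = [generator.run_async(messages=[ChatMessage.from_user(p)]) for p in prompts]
    results = await asyncio.gather(*tasks)
    return [r["replies"][0].text for r in results]

With the default internally created executor (one worker), the generations themselves still run one at a time on a single thread; awaiting them concurrently mainly keeps the event loop responsive rather than speeding up generation.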
releasenotes/notes/adding-async-huggingface-local-chat-generator-962512f52282d12d.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    Add `run_async` method to HuggingFaceLocalChatGenerator. This method internally uses ThreadPoolExecutor to return coroutines
+    that can be awaited.
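A minimal sketch of the pattern the release note describes, i.e. wrapping a blocking call in a ThreadPoolExecutor so it can be awaited; the blocking function here is a stand-in, not the actual Hugging Face pipeline call:

# Sketch: turning a blocking call into an awaitable via an executor.
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)


def blocking_generate(prompt: str) -> str:
    time.sleep(0.1)  # stands in for the blocking text-generation pipeline
    return f"reply to: {prompt}"


async def generate(prompt: str) -> str:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, blocking_generate, prompt)


print(asyncio.run(generate("hello")))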

test/components/generators/chat/test_hugging_face_local.py

Lines changed: 79 additions & 3 deletions
@@ -1,18 +1,21 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
 #
 # SPDX-License-Identifier: Apache-2.0
-from unittest.mock import Mock, patch
+
+import asyncio
+import gc
 from typing import Optional, List
+from unittest.mock import Mock, patch
 
-from haystack.dataclasses.streaming_chunk import StreamingChunk
 import pytest
 from transformers import PreTrainedTokenizer
 
 from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
 from haystack.dataclasses import ChatMessage, ChatRole, ToolCall
+from haystack.dataclasses.streaming_chunk import StreamingChunk
+from haystack.tools import Tool
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
-from haystack.tools import Tool
 
 
 # used to test serialization of streaming_callback
@@ -474,3 +477,76 @@ def test_default_tool_parser(self, model_info_mock, tools):
         assert len(results["replies"][0].tool_calls) == 1
         assert results["replies"][0].tool_calls[0].tool_name == "weather"
         assert results["replies"][0].tool_calls[0].arguments == {"city": "Berlin"}
+
+    # Async tests
+
+    async def test_run_async(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
+        """Test basic async functionality"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+        generator.pipeline = mock_pipeline_tokenizer
+
+        results = await generator.run_async(messages=chat_messages)
+
+        assert "replies" in results
+        assert isinstance(results["replies"][0], ChatMessage)
+        chat_message = results["replies"][0]
+        assert chat_message.is_from(ChatRole.ASSISTANT)
+        assert chat_message.text == "Berlin is cool"
+
+    async def test_run_async_with_tools(self, model_info_mock, mock_pipeline_tokenizer, tools):
+        """Test async functionality with tools"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model", tools=tools)
+        generator.pipeline = mock_pipeline_tokenizer
+        # Mock the pipeline to return a tool call format
+        generator.pipeline.return_value = [{"generated_text": '{"name": "weather", "arguments": {"city": "Berlin"}}'}]
+
+        messages = [ChatMessage.from_user("What's the weather in Berlin?")]
+        results = await generator.run_async(messages=messages)
+
+        assert len(results["replies"]) == 1
+        message = results["replies"][0]
+        assert message.tool_calls
+        tool_call = message.tool_calls[0]
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert tool_call.arguments == {"city": "Berlin"}
+
+    async def test_concurrent_async_requests(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
+        """Test handling of multiple concurrent async requests"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+        generator.pipeline = mock_pipeline_tokenizer
+
+        # Create multiple concurrent requests
+        tasks = [generator.run_async(messages=chat_messages) for _ in range(5)]
+        results = await asyncio.gather(*tasks)
+
+        for result in results:
+            assert "replies" in result
+            assert isinstance(result["replies"][0], ChatMessage)
+            assert result["replies"][0].text == "Berlin is cool"
+
+    async def test_async_error_handling(self, model_info_mock, mock_pipeline_tokenizer):
+        """Test error handling in async context"""
+        generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+
+        # Test without warm_up
+        with pytest.raises(RuntimeError, match="The generation model has not been loaded"):
+            await generator.run_async(messages=[ChatMessage.from_user("test")])
+
+        # Test with invalid streaming callback
+        generator.pipeline = mock_pipeline_tokenizer
+        with pytest.raises(ValueError, match="Using tools and streaming at the same time is not supported"):
+            await generator.run_async(
+                messages=[ChatMessage.from_user("test")],
+                streaming_callback=lambda x: None,
+                tools=[Tool(name="test", description="test", parameters={}, function=lambda: None)],
+            )
+
+    def test_executor_shutdown(self, model_info_mock, mock_pipeline_tokenizer):
+        with patch("haystack.components.generators.chat.hugging_face_local.pipeline") as mock_pipeline:
+            generator = HuggingFaceLocalChatGenerator(model="mocked-model")
+            executor = generator.executor
+            with patch.object(executor, "shutdown", wraps=executor.shutdown) as mock_shutdown:
+                del generator
+                gc.collect()
+                mock_shutdown.assert_called_once_with(wait=True)

0 commit comments
