
Commit 28db039

feat: add run_async to HuggingfaceAPIChatGenerator (#8943)
* add run_async
* add release notes
* add integration test
1 parent 1b2053b · commit 28db039

File tree

4 files changed: +368 −4 lines changed

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 152 additions & 3 deletions
```diff
@@ -3,10 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@@ -15,6 +15,7 @@
 
 with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import (
+        AsyncInferenceClient,
         ChatCompletionInputFunctionDefinition,
         ChatCompletionInputTool,
         ChatCompletionOutput,
@@ -181,6 +182,7 @@ def __init__(  # pylint: disable=too-many-positional-arguments
         self.generation_kwargs = generation_kwargs
         self.streaming_callback = streaming_callback
         self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
         self.tools = tools
 
     def to_dict(self) -> Dict[str, Any]:
@@ -250,7 +252,11 @@ def run(
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
         _check_duplicate_tool_names(tools)
 
-        streaming_callback = streaming_callback or self.streaming_callback
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(
+            self.streaming_callback, streaming_callback, requires_async=False
+        )  # type: ignore
+
         if streaming_callback:
             return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
 
@@ -267,6 +273,63 @@ def run(
             ]
         return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
 
+    @component.output_types(replies=List[ChatMessage])
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
+        """
+        Asynchronously invokes the text generation inference based on the provided messages and generation parameters.
+
+        This is the asynchronous version of the `run` method. It has the same parameters
+        and return values but can be used with `await` in async code.
+
+        :param messages:
+            A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs:
+            Additional keyword arguments for text generation.
+        :param tools:
+            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+            during component initialization.
+        :param streaming_callback:
+            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
+            parameter set during component initialization.
+        :returns: A dictionary with the following keys:
+            - `replies`: A list containing the generated responses as ChatMessage objects.
+        """
+
+        # update generation kwargs by merging with the default ones
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        formatted_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        tools = tools or self.tools
+        if tools and self.streaming_callback:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)
+
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)  # type: ignore
+
+        if streaming_callback:
+            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
+
+        hf_tools = None
+        if tools:
+            hf_tools = [
+                ChatCompletionInputTool(
+                    function=ChatCompletionInputFunctionDefinition(
+                        name=tool.name, description=tool.description, arguments=tool.parameters
+                    ),
+                    type="function",
+                )
+                for tool in tools
+            ]
+        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
+
     def _run_streaming(
         self,
         messages: List[Dict[str, str]],
@@ -359,3 +422,89 @@ def _run_non_streaming(
 
         message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
         return {"replies": [message]}
+
+    async def _run_streaming_async(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        streaming_callback: Callable[[StreamingChunk], None],
+    ):
+        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
+            messages, stream=True, **generation_kwargs
+        )
+
+        generated_text = ""
+        first_chunk_time = None
+
+        async for chunk in api_output:
+            choice = chunk.choices[0]
+
+            text = choice.delta.content or ""
+            generated_text += text
+
+            finish_reason = choice.finish_reason
+
+            meta: Dict[str, Any] = {}
+            if finish_reason:
+                meta["finish_reason"] = finish_reason
+
+            if first_chunk_time is None:
+                first_chunk_time = datetime.now().isoformat()
+
+            stream_chunk = StreamingChunk(text, meta)
+            await streaming_callback(stream_chunk)  # type: ignore
+
+        meta.update(
+            {
+                "model": self._async_client.model,
+                "finish_reason": finish_reason,
+                "index": 0,
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
+                "completion_start_time": first_chunk_time,
+            }
+        )
+
+        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
+        return {"replies": [message]}
+
+    async def _run_non_streaming_async(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        tools: Optional[List["ChatCompletionInputTool"]] = None,
+    ) -> Dict[str, List[ChatMessage]]:
+        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
+            messages=messages, tools=tools, **generation_kwargs
+        )
+
+        if len(api_chat_output.choices) == 0:
+            return {"replies": []}
+
+        choice = api_chat_output.choices[0]
+
+        text = choice.message.content
+        tool_calls = []
+
+        if hfapi_tool_calls := choice.message.tool_calls:
+            for hfapi_tc in hfapi_tool_calls:
+                tool_call = ToolCall(
+                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
+                )
+                tool_calls.append(tool_call)
+
+        meta: Dict[str, Any] = {
+            "model": self._async_client.model,
+            "finish_reason": choice.finish_reason,
+            "index": choice.index,
+        }
+
+        usage = {"prompt_tokens": 0, "completion_tokens": 0}
+        if api_chat_output.usage:
+            usage = {
+                "prompt_tokens": api_chat_output.usage.prompt_tokens,
+                "completion_tokens": api_chat_output.usage.completion_tokens,
+            }
+        meta["usage"] = usage
+
+        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+        return {"replies": [message]}
```
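
The async path mirrors the sync one, except that `select_streaming_callback(..., requires_async=True)` insists the streaming callback be a coroutine function. Below is a minimal usage sketch, not part of the commit: the model name, prompt, and import paths (assuming the Haystack 2.x layout, e.g. `HFGenerationAPIType` in `haystack.utils.hf`) are illustrative placeholders.

```python
import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.hf import HFGenerationAPIType


async def print_chunk(chunk: StreamingChunk) -> None:
    # run_async validates with requires_async=True, so the callback must be async.
    print(chunk.content, end="", flush=True)


async def main() -> None:
    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # placeholder model
    )
    result = await generator.run_async(
        messages=[ChatMessage.from_user("What is the capital of France?")],
        streaming_callback=print_chunk,
    )
    print(result["replies"][0].text)


asyncio.run(main())
```

As in the sync `run`, tools and streaming remain mutually exclusive, so the sketch passes only a streaming callback.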

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -278,6 +278,7 @@ markers = [
 ]
 log_cli = true
 asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "class"
 
 [tool.mypy]
 warn_return_any = false
```
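
This setting complements the existing `asyncio_mode = "auto"`: auto mode lets pytest-asyncio collect plain `async def` tests without an explicit marker, while the class-scoped default keeps async fixtures within a test class on a single event loop. A sketch of the test shape this enables — the class, test, and model names are illustrative, not the ones added in this commit:

```python
import pytest

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils.hf import HFGenerationAPIType


class TestRunAsync:
    @pytest.mark.integration
    async def test_run_async_returns_replies(self):
        # Collected as an async test because asyncio_mode = "auto"; no marker needed.
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # placeholder model
        )
        result = await generator.run_async(messages=[ChatMessage.from_user("Say hello.")])
        assert isinstance(result["replies"][0], ChatMessage)
```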
Release note (new file) — Lines changed: 6 additions & 0 deletions
```diff
@@ -0,0 +1,6 @@
+---
+features:
+  - |
+    Add `run_async` method to HuggingFaceAPIChatGenerator. This method relies internally on the `AsyncInferenceClient`
+    from Hugging Face to generate chat completions and supports the same parameters as the `run` method. It returns
+    a coroutine that can be awaited.
```
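
Because `run_async` returns a coroutine, several generations can be awaited concurrently on one event loop, which is the practical payoff of switching to `AsyncInferenceClient`. A sketch under the same placeholder assumptions as above:

```python
import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils.hf import HFGenerationAPIType

generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # placeholder model
)


async def main() -> None:
    # Both requests are in flight at once, sharing the component's AsyncInferenceClient.
    results = await asyncio.gather(
        generator.run_async(messages=[ChatMessage.from_user("Summarize asyncio in one sentence.")]),
        generator.run_async(messages=[ChatMessage.from_user("Name three HTTP methods.")]),
    )
    for result in results:
        print(result["replies"][0].text)


asyncio.run(main())
```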
