# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
+        AsyncInferenceClient,
        ChatCompletionInputFunctionDefinition,
        ChatCompletionInputTool,
        ChatCompletionOutput,
@@ -181,6 +182,7 @@ def __init__(  # pylint: disable=too-many-positional-arguments
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self.tools = tools

    def to_dict(self) -> Dict[str, Any]:
@@ -250,7 +252,11 @@ def run(
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(tools)

-        streaming_callback = streaming_callback or self.streaming_callback
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(
+            self.streaming_callback, streaming_callback, requires_async=False
+        )  # type: ignore
+
        if streaming_callback:
            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

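Note on the hunk above: `select_streaming_callback` (now imported from `haystack.dataclasses`) replaces the plain `or` fallback. Inferred from its two call sites in this diff, and not the actual Haystack implementation, its contract is roughly the following sketch: the runtime callback takes precedence over the init-time one, and the result is validated against the execution mode.

# Rough sketch of the helper's contract, inferred from the call sites in this diff.
import inspect
from typing import Callable, Optional

def select_streaming_callback_sketch(
    init_callback: Optional[Callable],
    runtime_callback: Optional[Callable],
    requires_async: bool,
) -> Optional[Callable]:
    callback = runtime_callback or init_callback  # runtime callback overrides the init-time one
    if callback is not None:
        if requires_async and not inspect.iscoroutinefunction(callback):
            raise ValueError("The streaming callback must be a coroutine function when running asynchronously.")
        if not requires_async and inspect.iscoroutinefunction(callback):
            raise ValueError("An async streaming callback cannot be used with the synchronous run method.")
    return callback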
@@ -267,6 +273,63 @@ def run(
            ]
        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

+    @component.output_types(replies=List[ChatMessage])
+    async def run_async(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
284+ """
285+ Asynchronously invokes the text generation inference based on the provided messages and generation parameters.
286+
287+ This is the asynchronous version of the `run` method. It has the same parameters
288+ and return values but can be used with `await` in an async code.
289+
290+ :param messages:
291+ A list of ChatMessage objects representing the input messages.
292+ :param generation_kwargs:
293+ Additional keyword arguments for text generation.
294+ :param tools:
295+ A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
296+ during component initialization.
297+ :param streaming_callback:
298+ An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
299+ parameter set during component initialization.
300+ :returns: A dictionary with the following keys:
301+ - `replies`: A list containing the generated responses as ChatMessage objects.
302+ """
+
+        # update generation kwargs by merging with the default ones
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        formatted_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        tools = tools or self.tools
+        if tools and self.streaming_callback:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)
+
+        # validate and select the streaming callback
+        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)  # type: ignore
+
+        if streaming_callback:
+            return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
+
+        hf_tools = None
+        if tools:
+            hf_tools = [
+                ChatCompletionInputTool(
+                    function=ChatCompletionInputFunctionDefinition(
+                        name=tool.name, description=tool.description, arguments=tool.parameters
+                    ),
+                    type="function",
+                )
+                for tool in tools
+            ]
+        return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)
+
    def _run_streaming(
        self,
        messages: List[Dict[str, str]],
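For context on `_run_streaming_async` below: with `huggingface_hub`, awaiting `AsyncInferenceClient.chat_completion(..., stream=True)` returns an async iterator of `ChatCompletionStreamOutput` chunks rather than a complete response, which is why the method awaits the call once and then consumes it with `async for`. A minimal standalone sketch (the model name is illustrative):

import asyncio

from huggingface_hub import AsyncInferenceClient

async def demo():
    client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")  # illustrative model
    stream = await client.chat_completion(
        messages=[{"role": "user", "content": "Say hello"}], stream=True, max_tokens=32
    )
    async for chunk in stream:  # each chunk is a ChatCompletionStreamOutput
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(demo())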
@@ -359,3 +422,89 @@ def _run_non_streaming(

        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
        return {"replies": [message]}
+
+    async def _run_streaming_async(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        streaming_callback: Callable[[StreamingChunk], None],
+    ):
+        api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
+            messages, stream=True, **generation_kwargs
+        )
+
+        generated_text = ""
+        first_chunk_time = None
+
+        async for chunk in api_output:
+            choice = chunk.choices[0]
+
+            text = choice.delta.content or ""
+            generated_text += text
+
+            finish_reason = choice.finish_reason
+
+            meta: Dict[str, Any] = {}
+            if finish_reason:
+                meta["finish_reason"] = finish_reason
+
+            if first_chunk_time is None:
+                first_chunk_time = datetime.now().isoformat()
+
+            stream_chunk = StreamingChunk(text, meta)
+            await streaming_callback(stream_chunk)  # type: ignore
+
+        meta.update(
+            {
+                "model": self._async_client.model,
+                "finish_reason": finish_reason,
+                "index": 0,
+                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
+                "completion_start_time": first_chunk_time,
+            }
+        )
+
+        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
+        return {"replies": [message]}
+
+    async def _run_non_streaming_async(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        tools: Optional[List["ChatCompletionInputTool"]] = None,
+    ) -> Dict[str, List[ChatMessage]]:
+        api_chat_output: ChatCompletionOutput = await self._async_client.chat_completion(
+            messages=messages, tools=tools, **generation_kwargs
+        )
+
+        if len(api_chat_output.choices) == 0:
+            return {"replies": []}
+
+        choice = api_chat_output.choices[0]
+
+        text = choice.message.content
+        tool_calls = []
+
+        if hfapi_tool_calls := choice.message.tool_calls:
+            for hfapi_tc in hfapi_tool_calls:
+                tool_call = ToolCall(
+                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
+                )
+                tool_calls.append(tool_call)
+
+        meta: Dict[str, Any] = {
+            "model": self._async_client.model,
+            "finish_reason": choice.finish_reason,
+            "index": choice.index,
+        }
+
+        usage = {"prompt_tokens": 0, "completion_tokens": 0}
+        if api_chat_output.usage:
+            usage = {
+                "prompt_tokens": api_chat_output.usage.prompt_tokens,
+                "completion_tokens": api_chat_output.usage.completion_tokens,
+            }
+        meta["usage"] = usage
+
+        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+        return {"replies": [message]}
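Putting the pieces together, the new async path can be exercised as in the sketch below. The host component and model are assumptions, since this diff only shows the generator internals; the streaming callback must be a coroutine because `run_async` selects it with `requires_async=True`.

import asyncio

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator  # assumed host class
from haystack.dataclasses import ChatMessage, StreamingChunk

async def main():
    generator = HuggingFaceAPIChatGenerator(
        api_type="serverless_inference_api",
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model
    )

    # run_async requires an async (coroutine) callback for streaming.
    async def on_chunk(chunk: StreamingChunk) -> None:
        print(chunk.content, end="", flush=True)

    result = await generator.run_async(
        messages=[ChatMessage.from_user("What is the capital of France?")],
        streaming_callback=on_chunk,
    )
    print(result["replies"][0].text)

asyncio.run(main())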