@@ -6,6 +6,7 @@
 from operator import itemgetter
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Dict,
     Iterator,
@@ -27,11 +28,15 @@
     BaseSchema,
     TextChatParameters,
 )
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     LangSmithParams,
+    agenerate_from_stream,
     generate_from_stream,
 )
 from langchain_core.messages import (
@@ -718,6 +723,27 @@ def _generate(
         )
         return self._create_chat_result(response)
 
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
+        updated_params = self._merge_params(params, kwargs)
+
+        response = await self.watsonx_model.achat(
+            messages=message_dicts, **(kwargs | {"params": updated_params})
+        )
+        return self._create_chat_result(response)
+
     def _stream(
         self,
         messages: List[BaseMessage],
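Note on the hunk above: the new `_agenerate` mirrors the sync `_generate`, delegating to `_astream` via `agenerate_from_stream` when `streaming=True` and otherwise awaiting `watsonx_model.achat` directly. A minimal sketch of how this surfaces to callers, assuming the enclosing class is `ChatWatsonx` from `langchain_ibm` (the class name is not visible in this diff) and using placeholder model id and credentials:

import asyncio

from langchain_ibm import ChatWatsonx  # assumed enclosing class of this diff


async def main() -> None:
    chat = ChatWatsonx(
        model_id="ibm/granite-3-8b-instruct",       # placeholder model id
        url="https://us-south.ml.cloud.ibm.com",    # placeholder endpoint
        apikey="...",                               # placeholder credential
        project_id="...",                           # placeholder project
    )
    # ainvoke() routes through the new _agenerate() path added above.
    reply = await chat.ainvoke("What is the capital of France?")
    print(reply.content)


asyncio.run(main())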
@@ -768,6 +794,62 @@ def _stream(
 
         yield generation_chunk
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
+        updated_params = self._merge_params(params, kwargs)
+
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+
+        is_first_tool_chunk = True
+        _prompt_tokens_included = False
+
+        response = await self.watsonx_model.achat_stream(
+            messages=message_dicts, **(kwargs | {"params": updated_params})
+        )
+        async for chunk in response:
+            if not isinstance(chunk, dict):
+                chunk = chunk.model_dump()
+            generation_chunk = _convert_chunk_to_generation_chunk(
+                chunk,
+                default_chunk_class,
+                is_first_tool_chunk,
+                _prompt_tokens_included,
+            )
+            if generation_chunk is None:
+                continue
+
+            if (
+                hasattr(generation_chunk.message, "usage_metadata")
+                and generation_chunk.message.usage_metadata
+            ):
+                _prompt_tokens_included = True
+            default_chunk_class = generation_chunk.message.__class__
+            logprobs = (generation_chunk.generation_info or {}).get("logprobs")
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    generation_chunk.text,
+                    chunk=generation_chunk,
+                    logprobs=logprobs,
+                )
+            if hasattr(generation_chunk.message, "tool_calls") and isinstance(
+                generation_chunk.message.tool_calls, list
+            ):
+                first_tool_call = (
+                    generation_chunk.message.tool_calls[0]
+                    if generation_chunk.message.tool_calls
+                    else None
+                )
+                if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
+                    is_first_tool_chunk = False
+
+            yield generation_chunk
+
     @staticmethod
     def _merge_params(params: dict, kwargs: dict) -> dict:
         param_updates = {}
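For the streaming path, `astream` on the model drives the new `_astream`, yielding message chunks as they arrive from `achat_stream`; the `is_first_tool_chunk` and `_prompt_tokens_included` flags let `_convert_chunk_to_generation_chunk` attach tool-call ids and prompt-token usage only once per response. A sketch under the same `ChatWatsonx` assumption as above:

from langchain_ibm import ChatWatsonx  # assumed, as above


async def stream_demo(chat: ChatWatsonx) -> None:
    # astream() routes through the new _astream() path; each item is an
    # AIMessageChunk carrying a slice of the response text.
    async for chunk in chat.astream("Tell me a joke"):
        print(chunk.content, end="", flush=True)
    print()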