
Commit c23b7d7

Added native SDK async method for ChatWatsonx (#76)

* Added native SDK async method for ChatWatsonx
* Fixed unit tests
* Fixed integration tests
* Added integration tests
* Added _astream method to the LLM

1 parent 5b9a79b commit c23b7d7
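For context, a minimal usage sketch of what this commit enables: the standard LangChain async entry points on ChatWatsonx (ainvoke, astream) now route through the native watsonx.ai SDK async calls added below instead of a sync fallback. The model_id, url, and project_id values are placeholders, and credentials are assumed to come from environment variables (e.g. WATSONX_APIKEY).

import asyncio

from langchain_ibm import ChatWatsonx


async def main() -> None:
    # Placeholder model/endpoint/project; substitute real values.
    chat = ChatWatsonx(
        model_id="ibm/granite-3-8b-instruct",
        url="https://us-south.ml.cloud.ibm.com",
        project_id="YOUR_PROJECT_ID",
    )

    # ainvoke -> _agenerate -> watsonx_model.achat
    reply = await chat.ainvoke("Say hello in one word.")
    print(reply.content)

    # astream -> _astream -> watsonx_model.achat_stream
    async for chunk in chat.astream("Count to three."):
        print(chunk.content, end="", flush=True)


asyncio.run(main())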

File tree

11 files changed: +444 -244 lines changed

libs/ibm/langchain_ibm/chat_models.py

Lines changed: 83 additions & 1 deletion
@@ -6,6 +6,7 @@
 from operator import itemgetter
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Dict,
     Iterator,
@@ -27,11 +28,15 @@
     BaseSchema,
     TextChatParameters,
 )
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     LangSmithParams,
+    agenerate_from_stream,
     generate_from_stream,
 )
 from langchain_core.messages import (
@@ -718,6 +723,27 @@ def _generate(
         )
         return self._create_chat_result(response)
 
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
+        updated_params = self._merge_params(params, kwargs)
+
+        response = await self.watsonx_model.achat(
+            messages=message_dicts, **(kwargs | {"params": updated_params})
+        )
+        return self._create_chat_result(response)
+
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -768,6 +794,62 @@ def _stream(
 
             yield generation_chunk
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
+        updated_params = self._merge_params(params, kwargs)
+
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+
+        is_first_tool_chunk = True
+        _prompt_tokens_included = False
+
+        response = await self.watsonx_model.achat_stream(
+            messages=message_dicts, **(kwargs | {"params": updated_params})
+        )
+        async for chunk in response:
+            if not isinstance(chunk, dict):
+                chunk = chunk.model_dump()
+            generation_chunk = _convert_chunk_to_generation_chunk(
+                chunk,
+                default_chunk_class,
+                is_first_tool_chunk,
+                _prompt_tokens_included,
+            )
+            if generation_chunk is None:
+                continue
+
+            if (
+                hasattr(generation_chunk.message, "usage_metadata")
+                and generation_chunk.message.usage_metadata
+            ):
+                _prompt_tokens_included = True
+            default_chunk_class = generation_chunk.message.__class__
+            logprobs = (generation_chunk.generation_info or {}).get("logprobs")
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    generation_chunk.text,
+                    chunk=generation_chunk,
+                    logprobs=logprobs,
+                )
+            if hasattr(generation_chunk.message, "tool_calls") and isinstance(
+                generation_chunk.message.tool_calls, list
+            ):
+                first_tool_call = (
+                    generation_chunk.message.tool_calls[0]
+                    if generation_chunk.message.tool_calls
+                    else None
+                )
+                if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
+                    is_first_tool_chunk = False
+
+            yield generation_chunk
+
     @staticmethod
     def _merge_params(params: dict, kwargs: dict) -> dict:
         param_updates = {}
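A hedged sketch of how callers consume the new _astream: each yielded item is a message chunk built from watsonx_model.achat_stream, and with streaming=True the new _agenerate aggregates these same chunks via agenerate_from_stream. The helper below is illustrative only; usage_metadata is read defensively because only the chunk carrying token counts populates it.

from langchain_ibm import ChatWatsonx


async def stream_reply(chat: ChatWatsonx, prompt: str) -> None:
    # chat.astream drives ChatWatsonx._astream under the hood.
    async for chunk in chat.astream(prompt):
        print(chunk.content, end="", flush=True)
        # Token accounting, when present, arrives on the chunk's usage_metadata.
        usage = getattr(chunk, "usage_metadata", None)
        if usage:
            print(f"\n[usage: {usage}]")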

libs/ibm/langchain_ibm/llms.py

Lines changed: 31 additions & 1 deletion
@@ -1,7 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from ibm_watsonx_ai import APIClient, Credentials  # type: ignore
 from ibm_watsonx_ai.foundation_models import Model, ModelInference  # type: ignore
@@ -524,6 +534,26 @@ def _stream(
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
 
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        params, kwargs = self._get_chat_params(stop=stop, **kwargs)
+        params = self._validate_chat_params(params)
+        async for stream_resp in await self.watsonx_model.agenerate_stream(
+            prompt=prompt, params=params
+        ):
+            if not isinstance(stream_resp, dict):
+                stream_resp = stream_resp.dict()
+            chunk = self._stream_response_to_generation_chunk(stream_resp)
+
+            if run_manager:
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk
+
     def get_num_tokens(self, text: str) -> int:
         response = self.watsonx_model.tokenize(text, return_tokens=False)
         return response["result"]["token_count"]
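Similarly, a minimal sketch of the LLM-level async streaming that the new WatsonxLLM._astream provides; model_id, url, and project_id are placeholders, and credentials are assumed to come from the environment.

import asyncio

from langchain_ibm import WatsonxLLM


async def main() -> None:
    llm = WatsonxLLM(
        model_id="ibm/granite-3-8b-instruct",     # placeholder
        url="https://us-south.ml.cloud.ibm.com",  # placeholder endpoint
        project_id="YOUR_PROJECT_ID",             # placeholder
    )
    # llm.astream now feeds from watsonx_model.agenerate_stream natively,
    # so no sync-to-async executor fallback is needed for streaming.
    async for token in llm.astream("Explain event loops in one sentence."):
        print(token, end="", flush=True)


asyncio.run(main())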

0 commit comments