Skip to content

Commit d598166

Browse files
Added support for Gateway in langchain-ibm (#79)
* Added draft version of support model gateway in langchain * updated gitignore * Added support for providing ModelInference and Gateway object when initializing ChatWatsonx, added unit tests * Added support init ChatWatsonx with Gateway and Credentials * Added support for async method in model gateway * poetry update * remove passing Gateway as watsonx_model * poetry update * distinguishing watsonx_model and watsonx_model_gateway attribute * fix test_watsonxllm_stream test * fix * patch version * simplify implementation with map
1 parent 14db57b commit d598166

File tree

7 files changed

+433
-123
lines changed

7 files changed

+433
-123
lines changed

libs/ibm/langchain_ibm/chat_models.py

Lines changed: 116 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
BaseSchema,
2929
TextChatParameters,
3030
)
31+
from ibm_watsonx_ai.gateway import Gateway # type: ignore
3132
from langchain_core.callbacks import (
3233
AsyncCallbackManagerForLLMRun,
3334
CallbackManagerForLLMRun,
@@ -428,6 +429,9 @@ class ChatWatsonx(BaseChatModel):
428429
model_id: Optional[str] = None
429430
"""Type of model to use."""
430431

432+
model: Optional[str] = None
433+
"""Name of model for given provider or alias."""
434+
431435
deployment_id: Optional[str] = None
432436
"""Type of deployed model to use."""
433437

@@ -558,6 +562,10 @@ class ChatWatsonx(BaseChatModel):
558562

559563
watsonx_model: ModelInference = Field(default=None, exclude=True) #: :meta private:
560564

565+
watsonx_model_gateway: Gateway = Field(
566+
default=None, exclude=True
567+
) #: :meta private:
568+
561569
watsonx_client: Optional[APIClient] = Field(default=None, exclude=True)
562570

563571
model_config = ConfigDict(populate_by_name=True)
@@ -624,21 +632,58 @@ def validate_environment(self) -> Self:
624632
if v is not None
625633
}
626634
)
635+
if self.watsonx_model_gateway is not None:
636+
raise NotImplementedError(
637+
"Passing the 'watsonx_model_gateway' parameter to the ChatWatsonx "
638+
"constructor is not supported yet."
639+
)
627640

628-
if isinstance(self.watsonx_client, APIClient):
629-
watsonx_model = ModelInference(
630-
model_id=self.model_id,
631-
deployment_id=self.deployment_id,
632-
params=self.params,
633-
api_client=self.watsonx_client,
634-
project_id=self.project_id,
635-
space_id=self.space_id,
636-
verify=self.verify,
637-
validate=self.validate_model,
641+
if isinstance(self.watsonx_model, ModelInference):
642+
self.model_id = getattr(self.watsonx_model, "model_id")
643+
self.deployment_id = getattr(self.watsonx_model, "deployment_id", "")
644+
self.project_id = getattr(
645+
getattr(self.watsonx_model, "_client"),
646+
"default_project_id",
647+
)
648+
self.space_id = getattr(
649+
getattr(self.watsonx_model, "_client"), "default_space_id"
638650
)
639-
self.watsonx_model = watsonx_model
651+
self.params = getattr(self.watsonx_model, "params")
652+
self.watsonx_client = getattr(self.watsonx_model, "_client")
640653

654+
elif isinstance(self.watsonx_client, APIClient):
655+
if sum(map(bool, (self.model, self.model_id, self.deployment_id))) != 1:
656+
raise ValueError(
657+
"The parameters 'model', 'model_id' and 'deployment_id' are "
658+
"mutually exclusive. Please specify exactly one of these "
659+
"parameters when initializing ChatWatsonx."
660+
)
661+
if self.model is not None:
662+
watsonx_model_gateway = Gateway(
663+
api_client=self.watsonx_client,
664+
verify=self.verify,
665+
)
666+
self.watsonx_model_gateway = watsonx_model_gateway
667+
else:
668+
watsonx_model = ModelInference(
669+
model_id=self.model_id,
670+
deployment_id=self.deployment_id,
671+
params=self.params,
672+
api_client=self.watsonx_client,
673+
project_id=self.project_id,
674+
space_id=self.space_id,
675+
verify=self.verify,
676+
validate=self.validate_model,
677+
)
678+
self.watsonx_model = watsonx_model
641679
else:
680+
if sum(map(bool, (self.model, self.model_id, self.deployment_id))) != 1:
681+
raise ValueError(
682+
"The parameters 'model', 'model_id' and 'deployment_id' are "
683+
"mutually exclusive. Please specify exactly one of these "
684+
"parameters when initializing ChatWatsonx."
685+
)
686+
642687
check_for_attribute(self.url, "url", "WATSONX_URL")
643688

644689
if "cloud.ibm.com" in self.url.get_secret_value():
@@ -687,18 +732,24 @@ def validate_environment(self) -> Self:
687732
version=self.version.get_secret_value() if self.version else None,
688733
verify=self.verify,
689734
)
690-
691-
watsonx_chat = ModelInference(
692-
model_id=self.model_id,
693-
deployment_id=self.deployment_id,
694-
credentials=credentials,
695-
params=self.params,
696-
project_id=self.project_id,
697-
space_id=self.space_id,
698-
verify=self.verify,
699-
validate=self.validate_model,
700-
)
701-
self.watsonx_model = watsonx_chat
735+
if self.model is not None:
736+
watsonx_model_gateway = Gateway(
737+
credentials=credentials,
738+
verify=self.verify,
739+
)
740+
self.watsonx_model_gateway = watsonx_model_gateway
741+
else:
742+
watsonx_model = ModelInference(
743+
model_id=self.model_id,
744+
deployment_id=self.deployment_id,
745+
credentials=credentials,
746+
params=self.params,
747+
project_id=self.project_id,
748+
space_id=self.space_id,
749+
verify=self.verify,
750+
validate=self.validate_model,
751+
)
752+
self.watsonx_model = watsonx_model
702753

703754
return self
704755

@@ -717,10 +768,14 @@ def _generate(
717768

718769
message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
719770
updated_params = self._merge_params(params, kwargs)
720-
721-
response = self.watsonx_model.chat(
722-
messages=message_dicts, **(kwargs | {"params": updated_params})
723-
)
771+
if self.watsonx_model_gateway is not None:
772+
response = self.watsonx_model_gateway.chat.completions.create(
773+
model=self.model, messages=message_dicts, **(kwargs | updated_params)
774+
)
775+
else:
776+
response = self.watsonx_model.chat(
777+
messages=message_dicts, **(kwargs | {"params": updated_params})
778+
)
724779
return self._create_chat_result(response)
725780

726781
async def _agenerate(
@@ -738,10 +793,14 @@ async def _agenerate(
738793

739794
message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
740795
updated_params = self._merge_params(params, kwargs)
741-
742-
response = await self.watsonx_model.achat(
743-
messages=message_dicts, **(kwargs | {"params": updated_params})
744-
)
796+
if self.watsonx_model_gateway is not None:
797+
response = await self.watsonx_model_gateway.chat.completions.acreate(
798+
model=self.model, messages=message_dicts, **(kwargs | updated_params)
799+
)
800+
else:
801+
response = await self.watsonx_model.achat(
802+
messages=message_dicts, **(kwargs | {"params": updated_params})
803+
)
745804
return self._create_chat_result(response)
746805

747806
def _stream(
@@ -754,16 +813,23 @@ def _stream(
754813
message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
755814
updated_params = self._merge_params(params, kwargs)
756815

757-
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
816+
if self.watsonx_model_gateway is not None:
817+
call_kwargs = {**kwargs, **updated_params, "stream": True}
818+
chunk_iter = self.watsonx_model_gateway.chat.completions.create(
819+
model=self.model, messages=message_dicts, **call_kwargs
820+
)
821+
else:
822+
call_kwargs = {**kwargs, "params": updated_params}
823+
chunk_iter = self.watsonx_model.chat_stream(
824+
messages=message_dicts, **call_kwargs
825+
)
758826

827+
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
759828
is_first_tool_chunk = True
760829
_prompt_tokens_included = False
761830

762-
for chunk in self.watsonx_model.chat_stream(
763-
messages=message_dicts, **(kwargs | {"params": updated_params})
764-
):
765-
if not isinstance(chunk, dict):
766-
chunk = chunk.model_dump()
831+
for chunk in chunk_iter:
832+
chunk = chunk if isinstance(chunk, dict) else chunk.model_dump()
767833
generation_chunk = _convert_chunk_to_generation_chunk(
768834
chunk, default_chunk_class, is_first_tool_chunk, _prompt_tokens_included
769835
)
@@ -804,17 +870,23 @@ async def _astream(
804870
message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
805871
updated_params = self._merge_params(params, kwargs)
806872

807-
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
873+
if self.watsonx_model_gateway is not None:
874+
call_kwargs = {**kwargs, **updated_params, "stream": True}
875+
chunk_iter = await self.watsonx_model_gateway.chat.completions.acreate(
876+
model=self.model, messages=message_dicts, **call_kwargs
877+
)
878+
else:
879+
call_kwargs = {**kwargs, "params": updated_params}
880+
chunk_iter = await self.watsonx_model.achat_stream(
881+
messages=message_dicts, **call_kwargs
882+
)
808883

884+
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
809885
is_first_tool_chunk = True
810886
_prompt_tokens_included = False
811887

812-
response = await self.watsonx_model.achat_stream(
813-
messages=message_dicts, **(kwargs | {"params": updated_params})
814-
)
815-
async for chunk in response:
816-
if not isinstance(chunk, dict):
817-
chunk = chunk.model_dump()
888+
async for chunk in chunk_iter:
889+
chunk = chunk if isinstance(chunk, dict) else chunk.model_dump()
818890
generation_chunk = _convert_chunk_to_generation_chunk(
819891
chunk,
820892
default_chunk_class,

libs/ibm/poetry.lock

Lines changed: 24 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libs/ibm/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "langchain-ibm"
3-
version = "0.3.13"
3+
version = "0.3.14"
44
description = "An integration package connecting IBM watsonx.ai and LangChain"
55
authors = ["IBM"]
66
readme = "README.md"
@@ -13,7 +13,7 @@ license = "MIT"
1313
[tool.poetry.dependencies]
1414
python = ">=3.10,<3.14"
1515
langchain-core = "^0.3.39"
16-
ibm-watsonx-ai = "^1.3.18"
16+
ibm-watsonx-ai = "^1.3.28"
1717

1818
[tool.poetry.group.test]
1919
optional = true

libs/ibm/tests/integration_tests/test_chat_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
MODEL_ID = "ibm/granite-34b-code-instruct"
2828
MODEL_ID_TOOL = "mistralai/mistral-large"
29+
MODEL_ID_TOOL_2 = "meta-llama/llama-3-3-70b-instruct"
2930

3031
PARAMS_WITH_MAX_TOKENS = {"max_tokens": 20}
3132

@@ -484,7 +485,7 @@ class Person(BaseModel):
484485
def test_chat_bind_tools_list_tool_choice_dict() -> None:
485486
"""Test that tool choice is respected just passing in True."""
486487
chat = ChatWatsonx(
487-
model_id=MODEL_ID_TOOL,
488+
model_id=MODEL_ID_TOOL_2,
488489
url=URL, # type: ignore[arg-type]
489490
project_id=WX_PROJECT_ID,
490491
params={"temperature": 0},

0 commit comments

Comments (0)