Skip to content

Commit 953e021

Browse files
[langchain-ibm/feat]: Added support for Model Gateway (Embeddings), Added async methods (#88)
* Added support for Model Gateway Embeddings and added async methods * Added a more informative docstring for the model parameter * Added a warning when the user gets an ApiRequestFailure * Added a warning for watsonx_embed_gateway
1 parent 75682d3 commit 953e021

File tree

9 files changed

+535
-161
lines changed

9 files changed

+535
-161
lines changed

libs/ibm/langchain_ibm/chat_models.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,11 @@
8282
from typing_extensions import Self
8383

8484
from langchain_ibm.utils import (
85+
async_gateway_error_handler,
8586
check_duplicate_chat_params,
8687
check_for_attribute,
8788
extract_chat_params,
89+
gateway_error_handler,
8890
)
8991

9092
logger = logging.getLogger(__name__)
@@ -430,7 +432,16 @@ class ChatWatsonx(BaseChatModel):
430432
"""Type of model to use."""
431433

432434
model: Optional[str] = None
433-
"""Name of model for given provider or alias."""
435+
"""
436+
Name or alias of the foundation model to use.
437+
When using IBM’s watsonx.ai Model Gateway (public preview), you can specify any
438+
supported third-party model—OpenAI, Anthropic, NVIDIA, Cerebras, or IBM’s own
439+
Granite series—via a single, OpenAI-compatible interface. Models must be explicitly
440+
provisioned (opt-in) through the Gateway to ensure secure, vendor-agnostic access
441+
and easy switch-over without reconfiguration.
442+
443+
For more details on configuration and usage, see IBM watsonx Model Gateway docs: https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-model-gateway.html?context=wx&audience=wdp
444+
"""
434445

435446
deployment_id: Optional[str] = None
436447
"""Type of deployed model to use."""
@@ -753,6 +764,20 @@ def validate_environment(self) -> Self:
753764

754765
return self
755766

767+
@gateway_error_handler
def _call_model_gateway(self, *, model: str, messages: list, **params: Any) -> Any:
    """Issue a synchronous chat-completions request via the Model Gateway.

    Args:
        model: Model name or alias forwarded to the gateway.
        messages: Message payload forwarded to the gateway.
        **params: Additional keyword arguments forwarded unchanged.

    Returns:
        Whatever ``chat.completions.create`` on the gateway client returns.
    """
    completions = self.watsonx_model_gateway.chat.completions
    return completions.create(model=model, messages=messages, **params)
772+
773+
@async_gateway_error_handler
async def _acall_model_gateway(
    self, *, model: str, messages: list, **params: Any
) -> Any:
    """Issue an asynchronous chat-completions request via the Model Gateway.

    Args:
        model: Model name or alias forwarded to the gateway.
        messages: Message payload forwarded to the gateway.
        **params: Additional keyword arguments forwarded unchanged.

    Returns:
        The awaited result of ``chat.completions.acreate`` on the gateway
        client.
    """
    completions = self.watsonx_model_gateway.chat.completions
    result = await completions.acreate(model=model, messages=messages, **params)
    return result
780+
756781
def _generate(
757782
self,
758783
messages: List[BaseMessage],
@@ -769,8 +794,9 @@ def _generate(
769794
message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
770795
updated_params = self._merge_params(params, kwargs)
771796
if self.watsonx_model_gateway is not None:
772-
response = self.watsonx_model_gateway.chat.completions.create(
773-
model=self.model, messages=message_dicts, **(kwargs | updated_params)
797+
call_kwargs = {**kwargs, **updated_params}
798+
response = self._call_model_gateway(
799+
model=self.model, messages=message_dicts, **call_kwargs
774800
)
775801
else:
776802
response = self.watsonx_model.chat(
@@ -794,8 +820,9 @@ async def _agenerate(
794820
message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
795821
updated_params = self._merge_params(params, kwargs)
796822
if self.watsonx_model_gateway is not None:
797-
response = await self.watsonx_model_gateway.chat.completions.acreate(
798-
model=self.model, messages=message_dicts, **(kwargs | updated_params)
823+
call_kwargs = {**kwargs, **updated_params}
824+
response = await self._acall_model_gateway(
825+
model=self.model, messages=message_dicts, **call_kwargs
799826
)
800827
else:
801828
response = await self.watsonx_model.achat(
@@ -815,7 +842,7 @@ def _stream(
815842

816843
if self.watsonx_model_gateway is not None:
817844
call_kwargs = {**kwargs, **updated_params, "stream": True}
818-
chunk_iter = self.watsonx_model_gateway.chat.completions.create(
845+
chunk_iter = self._call_model_gateway(
819846
model=self.model, messages=message_dicts, **call_kwargs
820847
)
821848
else:
@@ -872,7 +899,7 @@ async def _astream(
872899

873900
if self.watsonx_model_gateway is not None:
874901
call_kwargs = {**kwargs, **updated_params, "stream": True}
875-
chunk_iter = await self.watsonx_model_gateway.chat.completions.acreate(
902+
chunk_iter = await self._acall_model_gateway(
876903
model=self.model, messages=message_dicts, **call_kwargs
877904
)
878905
else:

libs/ibm/langchain_ibm/embeddings.py

Lines changed: 121 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33

44
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
55
from ibm_watsonx_ai.foundation_models.embeddings import Embeddings # type: ignore
6+
from ibm_watsonx_ai.gateway import Gateway # type: ignore
67
from langchain_core.embeddings import Embeddings as LangChainEmbeddings
78
from langchain_core.utils.utils import secret_from_env
89
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
910
from typing_extensions import Self
1011

11-
from langchain_ibm.utils import check_for_attribute, extract_params
12+
from langchain_ibm.utils import (
13+
async_gateway_error_handler,
14+
check_for_attribute,
15+
extract_params,
16+
gateway_error_handler,
17+
)
1218

1319
logger = logging.getLogger(__name__)
1420

@@ -19,6 +25,18 @@ class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
1925
model_id: Optional[str] = None
2026
"""Type of model to use."""
2127

28+
model: Optional[str] = None
29+
"""
30+
Name or alias of the foundation model to use.
31+
When using IBM’s watsonx.ai Model Gateway (public preview), you can specify any
32+
supported third-party model—OpenAI, Anthropic, NVIDIA, Cerebras, or IBM’s own
33+
Granite series—via a single, OpenAI-compatible interface. Models must be explicitly
34+
provisioned (opt-in) through the Gateway to ensure secure, vendor-agnostic access
35+
and easy switch-over without reconfiguration.
36+
37+
For more details on configuration and usage, see IBM watsonx Model Gateway docs: https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-model-gateway.html?context=wx&audience=wdp
38+
"""
39+
2240
project_id: Optional[str] = None
2341
"""ID of the Watson Studio project."""
2442

@@ -70,10 +88,15 @@ class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
7088
* the path to a CA_BUNDLE file
7189
* the path of directory with certificates of trusted CAs
7290
* True - default path to truststore will be taken
73-
* False - no verification will be made"""
91+
* False - no verification will be made
92+
"""
7493

7594
watsonx_embed: Embeddings = Field(default=None) #: :meta private:
7695

96+
watsonx_embed_gateway: Gateway = Field(
97+
default=None, exclude=True
98+
) #: :meta private:
99+
77100
watsonx_client: Optional[APIClient] = Field(default=None) #: :meta private:
78101

79102
model_config = ConfigDict(
@@ -85,6 +108,12 @@ class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
85108
@model_validator(mode="after")
86109
def validate_environment(self) -> Self:
87110
"""Validate that credentials and python package exists in environment."""
111+
if self.watsonx_embed_gateway is not None:
112+
raise NotImplementedError(
113+
"Passing the 'watsonx_embed_gateway' parameter to the "
114+
"WatsonxEmbeddings constructor is not supported yet."
115+
)
116+
88117
if isinstance(self.watsonx_embed, Embeddings):
89118
self.model_id = getattr(self.watsonx_embed, "model_id")
90119
self.project_id = getattr(
@@ -98,17 +127,36 @@ def validate_environment(self) -> Self:
98127
self.params = getattr(self.watsonx_embed, "params")
99128

100129
elif isinstance(self.watsonx_client, APIClient):
101-
watsonx_embed = Embeddings(
102-
model_id=self.model_id,
103-
params=self.params,
104-
api_client=self.watsonx_client,
105-
project_id=self.project_id,
106-
space_id=self.space_id,
107-
verify=self.verify,
108-
)
109-
self.watsonx_embed = watsonx_embed
130+
if sum(map(bool, (self.model, self.model_id))) != 1:
131+
raise ValueError(
132+
"The parameters 'model' and 'model_id' are mutually exclusive. "
133+
"Please specify exactly one of these parameters when "
134+
"initializing WatsonxEmbeddings."
135+
)
136+
if self.model is not None:
137+
watsonx_embed_gateway = Gateway(
138+
api_client=self.watsonx_client,
139+
verify=self.verify,
140+
)
141+
self.watsonx_embed_gateway = watsonx_embed_gateway
142+
else:
143+
watsonx_embed = Embeddings(
144+
model_id=self.model_id,
145+
params=self.params,
146+
api_client=self.watsonx_client,
147+
project_id=self.project_id,
148+
space_id=self.space_id,
149+
verify=self.verify,
150+
)
151+
self.watsonx_embed = watsonx_embed
110152

111153
else:
154+
if sum(map(bool, (self.model, self.model_id))) != 1:
155+
raise ValueError(
156+
"The parameters 'model' and 'model_id' are mutually exclusive. "
157+
"Please specify exactly one of these parameters when "
158+
"initializing WatsonxEmbeddings."
159+
)
112160
check_for_attribute(self.url, "url", "WATSONX_URL")
113161

114162
if "cloud.ibm.com" in self.url.get_secret_value():
@@ -157,26 +205,77 @@ def validate_environment(self) -> Self:
157205
version=self.version.get_secret_value() if self.version else None,
158206
verify=self.verify,
159207
)
208+
if self.model is not None:
209+
watsonx_embed_gateway = Gateway(
210+
credentials=credentials,
211+
verify=self.verify,
212+
)
213+
self.watsonx_embed_gateway = watsonx_embed_gateway
160214

161-
watsonx_embed = Embeddings(
162-
model_id=self.model_id,
163-
params=self.params,
164-
credentials=credentials,
165-
project_id=self.project_id,
166-
space_id=self.space_id,
167-
)
215+
else:
216+
watsonx_embed = Embeddings(
217+
model_id=self.model_id,
218+
params=self.params,
219+
credentials=credentials,
220+
project_id=self.project_id,
221+
space_id=self.space_id,
222+
)
168223

169-
self.watsonx_embed = watsonx_embed
224+
self.watsonx_embed = watsonx_embed
170225

171226
return self
172227

228+
@gateway_error_handler
def _call_model_gateway(
    self, *, model: str, texts: List[str], **params: Any
) -> Any:
    """Issue a synchronous embeddings request via the Model Gateway.

    Args:
        model: Model name or alias forwarded to the gateway.
        texts: Texts to embed; passed to the gateway as ``input``.
        **params: Additional keyword arguments forwarded unchanged.

    Returns:
        Whatever ``embeddings.create`` on the gateway client returns.
    """
    embeddings_api = self.watsonx_embed_gateway.embeddings
    return embeddings_api.create(model=model, input=texts, **params)
235+
236+
@async_gateway_error_handler
async def _acall_model_gateway(
    self, *, model: str, texts: List[str], **params: Any
) -> Any:
    """Issue an asynchronous embeddings request via the Model Gateway.

    Args:
        model: Model name or alias forwarded to the gateway.
        texts: Texts to embed; passed to the gateway as ``input``.
        **params: Additional keyword arguments forwarded unchanged.

    Returns:
        The awaited result of ``embeddings.acreate`` on the gateway client.
    """
    embeddings_api = self.watsonx_embed_gateway.embeddings
    result = await embeddings_api.acreate(model=model, input=texts, **params)
    return result
243+
173244
def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
    """Embed a list of documents.

    Uses the Model Gateway client when ``watsonx_embed_gateway`` is set;
    otherwise delegates to ``watsonx_embed.embed_documents``.

    Args:
        texts: Documents to embed.
        **kwargs: Extra options; parameters are resolved from them together
            with ``self.params`` by ``extract_params``.

    Returns:
        One embedding vector per input text.
    """
    params = extract_params(kwargs, self.params)

    # No gateway configured: defer to the watsonx.ai Embeddings client.
    if self.watsonx_embed_gateway is None:
        delegate_kwargs = kwargs | {"params": params}
        return self.watsonx_embed.embed_documents(texts=texts, **delegate_kwargs)

    # Gateway path: parameters are flattened into the call's keyword arguments.
    response = self._call_model_gateway(
        model=self.model, texts=texts, **{**kwargs, **params}
    )
    return [item["embedding"] for item in response["data"]]
257+
258+
async def aembed_documents(
    self, texts: List[str], **kwargs: Any
) -> List[List[float]]:
    """Asynchronously embed a list of documents.

    Uses the Model Gateway client when ``watsonx_embed_gateway`` is set;
    otherwise delegates to ``watsonx_embed.aembed_documents``.

    Args:
        texts: Documents to embed.
        **kwargs: Extra options; parameters are resolved from them together
            with ``self.params`` by ``extract_params``.

    Returns:
        One embedding vector per input text.
    """
    params = extract_params(kwargs, self.params)

    # No gateway configured: defer to the watsonx.ai Embeddings client.
    if self.watsonx_embed_gateway is None:
        delegate_kwargs = kwargs | {"params": params}
        return await self.watsonx_embed.aembed_documents(
            texts=texts, **delegate_kwargs
        )

    # Gateway path: parameters are flattened into the call's keyword arguments.
    response = await self._acall_model_gateway(
        model=self.model, texts=texts, **{**kwargs, **params}
    )
    return [item["embedding"] for item in response["data"]]
179273

180274
def embed_query(self, text: str, **kwargs: Any) -> List[float]:
    """Embed a single query string.

    The text is wrapped in a one-element list, passed to
    ``embed_documents`` along with ``kwargs``, and the first vector of the
    result is returned.
    """
    vectors = self.embed_documents([text], **kwargs)
    return vectors[0]
277+
278+
async def aembed_query(self, text: str, **kwargs: Any) -> List[float]:
    """Asynchronously embed a single query string.

    The text is wrapped in a one-element list, passed to
    ``aembed_documents`` along with ``kwargs``, and the first vector of the
    awaited result is returned.
    """
    vectors = await self.aembed_documents([text], **kwargs)
    first_vector = vectors[0]
    return first_vector

libs/ibm/langchain_ibm/utils.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
import functools
2+
import logging
13
from copy import deepcopy
2-
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
4+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
35

46
from ibm_watsonx_ai.foundation_models.schema import BaseSchema # type: ignore
7+
from ibm_watsonx_ai.wml_client_error import ApiRequestFailure # type: ignore
58
from pydantic import SecretStr
69

710
if TYPE_CHECKING:
811
from langchain_ibm.toolkit import WatsonxTool
912

13+
logger = logging.getLogger(__name__)
14+
1015

1116
def check_for_attribute(value: SecretStr | None, key: str, env_key: str) -> None:
1217
if not value or not value.get_secret_value():
@@ -135,3 +140,48 @@ def parse_parameters(input_schema: dict | None) -> dict:
135140
},
136141
}
137142
return watsonx_tool
143+
144+
145+
def gateway_error_handler(func: Callable) -> Callable:
    """Decorate a synchronous Model Gateway call so failures log a hint.

    The wrapped method is invoked unchanged. If it raises
    ``ApiRequestFailure`` and the instance has a gateway client set
    (``watsonx_model_gateway`` or ``watsonx_embed_gateway`` is not ``None``),
    a warning about ``model`` versus ``model_id`` is logged. The original
    exception is always re-raised.

    Args:
        func: Method taking ``self`` as its first positional argument.

    Returns:
        The wrapping function, carrying ``func``'s metadata.
    """

    @functools.wraps(func)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        try:
            return func(self, *args, **kwargs)
        except ApiRequestFailure:
            # Test the attribute's value rather than its mere presence: a
            # gateway field that exists but is None means no gateway client
            # is in use, so the hint would be misleading. This also matches
            # the not-None check used by the async handler.
            gateway_in_use = any(
                getattr(self, attr, None) is not None
                for attr in ("watsonx_model_gateway", "watsonx_embed_gateway")
            )
            if gateway_in_use:
                logger.warning(
                    "You are calling the Model Gateway endpoint using the 'model' "
                    "parameter. Please ensure this model is registered with the "
                    "Gateway. If you intend to use a watsonx.ai–hosted model, pass "
                    "'model_id' instead of 'model'."
                )
            raise

    return wrapper
167+
168+
169+
def async_gateway_error_handler(func: Callable) -> Callable:
    """Decorate an asynchronous Model Gateway call so failures log a hint.

    The wrapped coroutine function is awaited unchanged. If it raises
    ``ApiRequestFailure`` and the instance has a gateway client set
    (``watsonx_model_gateway`` or ``watsonx_embed_gateway`` is not ``None``),
    a warning about ``model`` versus ``model_id`` is logged. The original
    exception is always re-raised.

    Args:
        func: Coroutine function taking ``self`` as its first positional
            argument.

    Returns:
        The wrapping coroutine function, carrying ``func``'s metadata.
    """

    @functools.wraps(func)
    async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        try:
            return await func(self, *args, **kwargs)
        except ApiRequestFailure:
            # Check both gateway attributes, as the sync handler does.
            # Looking only at ``watsonx_model_gateway`` meant the warning was
            # never logged for objects that hold ``watsonx_embed_gateway``.
            gateway_in_use = any(
                getattr(self, attr, None) is not None
                for attr in ("watsonx_model_gateway", "watsonx_embed_gateway")
            )
            if gateway_in_use:
                logger.warning(
                    "You are calling the Model Gateway endpoint using the 'model' "
                    "parameter. Please ensure this model is registered with the "
                    "Gateway. If you intend to use a watsonx.ai–hosted model, pass "
                    "'model_id' instead of 'model'."
                )
            raise

    return wrapper

0 commit comments

Comments
 (0)