Commit 5783272

[langchain-ibm/feat]: Added support for Model Gateway (WatsonxLLM) (#90)
* Added support for WatsonxLLM with Model Gateway
* Changed the error raised by get_num_tokens when using Model Gateway
1 parent 8646e0f commit 5783272

4 files changed · +375 −39 lines changed

libs/ibm/langchain_ibm/llms.py (164 additions, 38 deletions)
@@ -15,6 +15,7 @@
 
 from ibm_watsonx_ai import APIClient, Credentials  # type: ignore
 from ibm_watsonx_ai.foundation_models import Model, ModelInference  # type: ignore
+from ibm_watsonx_ai.gateway import Gateway  # type: ignore
 from ibm_watsonx_ai.metanames import GenTextParamsMetaNames  # type: ignore
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -26,7 +27,12 @@
 from pydantic import ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self
 
-from langchain_ibm.utils import check_for_attribute, extract_params
+from langchain_ibm.utils import (
+    async_gateway_error_handler,
+    check_for_attribute,
+    extract_params,
+    gateway_error_handler,
+)
 
 logger = logging.getLogger(__name__)
 textgen_valid_params = [
@@ -69,6 +75,18 @@ class WatsonxLLM(BaseLLM):
     model_id: Optional[str] = None
     """Type of model to use."""
 
+    model: Optional[str] = None
+    """
+    Name or alias of the foundation model to use.
+    When using IBM's watsonx.ai Model Gateway (public preview), you can specify any
+    supported third-party model (OpenAI, Anthropic, NVIDIA, Cerebras, or IBM's own
+    Granite series) via a single, OpenAI-compatible interface. Models must be
+    explicitly provisioned (opt-in) through the Gateway to ensure secure,
+    vendor-agnostic access and easy switch-over without reconfiguration.
+
+    For more details on configuration and usage, see the IBM watsonx Model Gateway docs: https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-model-gateway.html?context=wx&audience=wdp
+    """
+
     deployment_id: Optional[str] = None
     """Type of deployed model to use."""
 
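Together with the validator changes below, the new field enables Gateway-backed initialization. A minimal usage sketch, assuming a Gateway-provisioned alias (the alias and credentials below are hypothetical placeholders):

    from langchain_ibm import WatsonxLLM

    # Assumes the alias has been provisioned (opt-in) in your Model Gateway;
    # "openai/gpt-4o-mini" is a hypothetical placeholder.
    llm = WatsonxLLM(
        model="openai/gpt-4o-mini",
        url="https://us-south.ml.cloud.ibm.com",
        apikey="***",
    )
    print(llm.invoke("Tell me a joke."))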
@@ -130,6 +148,10 @@ class WatsonxLLM(BaseLLM):
 
     watsonx_model: ModelInference = Field(default=None, exclude=True)  #: :meta private:
 
+    watsonx_model_gateway: Gateway = Field(
+        default=None, exclude=True
+    )  #: :meta private:
+
     watsonx_client: Optional[APIClient] = Field(default=None)
 
     model_config = ConfigDict(
@@ -166,6 +188,12 @@ def lc_secrets(self) -> Dict[str, str]:
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that credentials and python package exists in environment."""
+        if self.watsonx_model_gateway is not None:
+            raise NotImplementedError(
+                "Passing the 'watsonx_model_gateway' parameter to the WatsonxLLM "
+                "constructor is not supported yet."
+            )
+
         if isinstance(self.watsonx_model, (ModelInference, Model)):
             self.model_id = getattr(self.watsonx_model, "model_id")
             self.deployment_id = getattr(self.watsonx_model, "deployment_id", "")
@@ -179,18 +207,38 @@ def validate_environment(self) -> Self:
             self.params = getattr(self.watsonx_model, "params")
 
         elif isinstance(self.watsonx_client, APIClient):
-            watsonx_model = ModelInference(
-                model_id=self.model_id,
-                deployment_id=self.deployment_id,
-                params=self.params,
-                api_client=self.watsonx_client,
-                project_id=self.project_id,
-                space_id=self.space_id,
-                verify=self.verify,
-            )
-            self.watsonx_model = watsonx_model
+            if sum(map(bool, (self.model, self.model_id, self.deployment_id))) != 1:
+                raise ValueError(
+                    "The parameters 'model', 'model_id' and 'deployment_id' are "
+                    "mutually exclusive. Please specify exactly one of these "
+                    "parameters when initializing WatsonxLLM."
+                )
+            if self.model is not None:
+                watsonx_model_gateway = Gateway(
+                    api_client=self.watsonx_client,
+                    verify=self.verify,
+                )
+                self.watsonx_model_gateway = watsonx_model_gateway
+            else:
+                watsonx_model = ModelInference(
+                    model_id=self.model_id,
+                    deployment_id=self.deployment_id,
+                    params=self.params,
+                    api_client=self.watsonx_client,
+                    project_id=self.project_id,
+                    space_id=self.space_id,
+                    verify=self.verify,
+                )
+                self.watsonx_model = watsonx_model
 
         else:
+            if sum(map(bool, (self.model, self.model_id, self.deployment_id))) != 1:
+                raise ValueError(
+                    "The parameters 'model', 'model_id' and 'deployment_id' are "
+                    "mutually exclusive. Please specify exactly one of these "
+                    "parameters when initializing WatsonxLLM."
+                )
+
             check_for_attribute(self.url, "url", "WATSONX_URL")
 
             if "cloud.ibm.com" in self.url.get_secret_value():
@@ -239,19 +287,39 @@ def validate_environment(self) -> Self:
                 version=self.version.get_secret_value() if self.version else None,
                 verify=self.verify,
             )
-
-            watsonx_model = ModelInference(
-                model_id=self.model_id,
-                deployment_id=self.deployment_id,
-                credentials=credentials,
-                params=self.params,
-                project_id=self.project_id,
-                space_id=self.space_id,
-            )
-            self.watsonx_model = watsonx_model
+            if self.model is not None:
+                watsonx_model_gateway = Gateway(
+                    credentials=credentials,
+                    verify=self.verify,
+                )
+                self.watsonx_model_gateway = watsonx_model_gateway
+            else:
+                watsonx_model = ModelInference(
+                    model_id=self.model_id,
+                    deployment_id=self.deployment_id,
+                    credentials=credentials,
+                    params=self.params,
+                    project_id=self.project_id,
+                    space_id=self.space_id,
+                )
+                self.watsonx_model = watsonx_model
 
         return self
 
+    @gateway_error_handler
+    def _call_model_gateway(self, *, model: str, prompt: list, **params: Any) -> Any:
+        return self.watsonx_model_gateway.completions.create(
+            model=model, prompt=prompt, **params
+        )
+
+    @async_gateway_error_handler
+    async def _acall_model_gateway(
+        self, *, model: str, prompt: list, **params: Any
+    ) -> Any:
+        return await self.watsonx_model_gateway.completions.acreate(
+            model=model, prompt=prompt, **params
+        )
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -361,6 +429,30 @@ def _create_llm_result(self, response: List[dict]) -> LLMResult:
         }
         return LLMResult(generations=generations, llm_output=llm_output)
 
+    def _create_llm_gateway_result(self, response: dict) -> LLMResult:
+        """Create the LLMResult from the choices and prompts."""
+        choices = response["choices"]
+
+        generations = [
+            [
+                Generation(
+                    text=choice["text"],
+                    generation_info=dict(
+                        finish_reason=choice.get("finish_reason"),
+                        logprobs=choice.get("logprobs"),
+                    ),
+                )
+            ]
+            for choice in choices
+        ]
+
+        llm_output = {
+            "token_usage": response["usage"]["total_tokens"],
+            "model_id": self.model_id,
+            "deployment_id": self.deployment_id,
+        }
+        return LLMResult(generations=generations, llm_output=llm_output)
+
     def _stream_response_to_generation_chunk(
         self,
         stream_response: Dict[str, Any],
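_create_llm_gateway_result consumes an OpenAI-style completions payload: one Generation per entry in choices, with usage["total_tokens"] recorded as llm_output["token_usage"]. A minimal example of the expected shape (field values invented for illustration):

    response = {
        "choices": [
            {"text": "Hello!", "finish_reason": "stop", "logprobs": None},
        ],
        "usage": {"prompt_tokens": 4, "completion_tokens": 2, "total_tokens": 6},
    }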
@@ -470,10 +562,17 @@ def _generate(
                 return LLMResult(generations=[[generation]], llm_output=llm_output)
             return LLMResult(generations=[[generation]])
         else:
-            response = self.watsonx_model.generate(
-                prompt=prompts, params=params, **kwargs
-            )
-            return self._create_llm_result(response)
+            if self.watsonx_model_gateway is not None:
+                call_kwargs = {**kwargs, **params}
+                response = self._call_model_gateway(
+                    model=self.model, prompt=prompts, **call_kwargs
+                )
+                return self._create_llm_gateway_result(response)
+            else:
+                response = self.watsonx_model.generate(
+                    prompt=prompts, params=params, **kwargs
+                )
+                return self._create_llm_result(response)
 
     async def _agenerate(
         self,
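In the Gateway branch, {**kwargs, **params} merges both dicts into a single OpenAI-style keyword set, with params taking precedence on any duplicate key:

    kwargs = {"timeout": 30}
    params = {"temperature": 0.2, "max_tokens": 64}
    call_kwargs = {**kwargs, **params}
    # -> {"timeout": 30, "temperature": 0.2, "max_tokens": 64}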
@@ -491,14 +590,21 @@ async def _agenerate(
                 prompts=prompts, stop=stop, run_manager=run_manager, **kwargs
             )
         else:
-            responses = [
-                await self.watsonx_model.agenerate(
-                    prompt=prompt, params=params, **kwargs
-                )
-                for prompt in prompts
-            ]
-
-            return self._create_llm_result(responses)
+            if self.watsonx_model_gateway is not None:
+                call_kwargs = {**kwargs, **params}
+                responses = await self._acall_model_gateway(
+                    model=self.model, prompt=prompts, **call_kwargs
+                )
+                return self._create_llm_gateway_result(responses)
+            else:
+                responses = [
+                    await self.watsonx_model.agenerate(
+                        prompt=prompt, params=params, **kwargs
+                    )
+                    for prompt in prompts
+                ]
+
+                return self._create_llm_result(responses)
 
     def _stream(
         self,
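Note the behavioral difference: the Gateway branch submits all prompts in one completions call, while the ModelInference branch awaits one agenerate call per prompt. A sketch of driving the async path, reusing the hypothetical alias from above:

    import asyncio

    from langchain_ibm import WatsonxLLM

    async def main() -> None:
        llm = WatsonxLLM(
            model="openai/gpt-4o-mini",  # hypothetical Gateway alias
            url="https://us-south.ml.cloud.ibm.com",
            apikey="***",
        )
        result = await llm.agenerate(["What is 2 + 2?", "Name a prime number."])
        print(result.generations)

    asyncio.run(main())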
@@ -523,9 +629,16 @@ def _stream(
         """
         params, kwargs = self._get_chat_params(stop=stop, **kwargs)
         params = self._validate_chat_params(params)
-        for stream_resp in self.watsonx_model.generate_text_stream(
-            prompt=prompt, params=params, **(kwargs | {"raw_response": True})
-        ):
+        if self.watsonx_model_gateway is not None:
+            call_kwargs = {**kwargs, **params, "stream": True}
+            chunk_iter = self._call_model_gateway(
+                model=self.model, prompt=prompt, **call_kwargs
+            )
+        else:
+            chunk_iter = self.watsonx_model.generate_text_stream(
+                prompt=prompt, params=params, **(kwargs | {"raw_response": True})
+            )
+        for stream_resp in chunk_iter:
             if not isinstance(stream_resp, dict):
                 stream_resp = stream_resp.dict()
             chunk = self._stream_response_to_generation_chunk(stream_resp)
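With "stream": True the Gateway call yields OpenAI-style chunks, so the existing chunk-parsing loop serves both back ends. Consumed through the standard LangChain surface (same hypothetical setup as above):

    for chunk in llm.stream("Write a haiku about rivers."):
        print(chunk, end="", flush=True)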
@@ -543,9 +656,17 @@ async def _astream(
     ) -> AsyncIterator[GenerationChunk]:
         params, kwargs = self._get_chat_params(stop=stop, **kwargs)
         params = self._validate_chat_params(params)
-        async for stream_resp in await self.watsonx_model.agenerate_stream(
-            prompt=prompt, params=params
-        ):
+
+        if self.watsonx_model_gateway is not None:
+            call_kwargs = {**kwargs, **params, "stream": True}
+            chunk_iter = await self._acall_model_gateway(
+                model=self.model, prompt=prompt, **call_kwargs
+            )
+        else:
+            chunk_iter = await self.watsonx_model.agenerate_stream(
+                prompt=prompt, params=params
+            )
+        async for stream_resp in chunk_iter:
             if not isinstance(stream_resp, dict):
                 stream_resp = stream_resp.dict()
             chunk = self._stream_response_to_generation_chunk(stream_resp)
@@ -555,7 +676,12 @@ async def _astream(
             yield chunk
 
     def get_num_tokens(self, text: str) -> int:
-        response = self.watsonx_model.tokenize(text, return_tokens=False)
+        if self.watsonx_model_gateway is not None:
+            raise NotImplementedError(
+                "Tokenize endpoint is not supported by IBM Model Gateway endpoint."
+            )
+        else:
+            response = self.watsonx_model.tokenize(text, return_tokens=False)
         return response["result"]["token_count"]
 
     def get_token_ids(self, text: str) -> List[int]: