15
15
16
16
from ibm_watsonx_ai import APIClient , Credentials # type: ignore
17
17
from ibm_watsonx_ai .foundation_models import Model , ModelInference # type: ignore
18
+ from ibm_watsonx_ai .gateway import Gateway # type: ignore
18
19
from ibm_watsonx_ai .metanames import GenTextParamsMetaNames # type: ignore
19
20
from langchain_core .callbacks import (
20
21
AsyncCallbackManagerForLLMRun ,
26
27
from pydantic import ConfigDict , Field , SecretStr , model_validator
27
28
from typing_extensions import Self
28
29
29
- from langchain_ibm .utils import check_for_attribute , extract_params
30
+ from langchain_ibm .utils import (
31
+ async_gateway_error_handler ,
32
+ check_for_attribute ,
33
+ extract_params ,
34
+ gateway_error_handler ,
35
+ )
30
36
31
37
logger = logging .getLogger (__name__ )
32
38
textgen_valid_params = [
@@ -69,6 +75,18 @@ class WatsonxLLM(BaseLLM):
69
75
model_id : Optional [str ] = None
70
76
"""Type of model to use."""
71
77
78
+ model : Optional [str ] = None
79
+ """
80
+ Name or alias of the foundation model to use.
81
+ When using IBM’s watsonx.ai Model Gateway (public preview), you can specify any
82
+ supported third-party model—OpenAI, Anthropic, NVIDIA, Cerebras, or IBM’s own
83
+ Granite series—via a single, OpenAI-compatible interface. Models must be explicitly
84
+ provisioned (opt-in) through the Gateway to ensure secure, vendor-agnostic access
85
+ and easy switch-over without reconfiguration.
86
+
87
+ For more details on configuration and usage, see IBM watsonx Model Gateway docs: https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-model-gateway.html?context=wx&audience=wdp
88
+ """
89
+
72
90
deployment_id : Optional [str ] = None
73
91
"""Type of deployed model to use."""
74
92
@@ -130,6 +148,10 @@ class WatsonxLLM(BaseLLM):
130
148
131
149
watsonx_model : ModelInference = Field (default = None , exclude = True ) #: :meta private:
132
150
151
+ watsonx_model_gateway : Gateway = Field (
152
+ default = None , exclude = True
153
+ ) #: :meta private:
154
+
133
155
watsonx_client : Optional [APIClient ] = Field (default = None )
134
156
135
157
model_config = ConfigDict (
@@ -166,6 +188,12 @@ def lc_secrets(self) -> Dict[str, str]:
166
188
@model_validator (mode = "after" )
167
189
def validate_environment (self ) -> Self :
168
190
"""Validate that credentials and python package exists in environment."""
191
+ if self .watsonx_model_gateway is not None :
192
+ raise NotImplementedError (
193
+ "Passing the 'watsonx_model_gateway' parameter to the WatsonxLLM "
194
+ "constructor is not supported yet."
195
+ )
196
+
169
197
if isinstance (self .watsonx_model , (ModelInference , Model )):
170
198
self .model_id = getattr (self .watsonx_model , "model_id" )
171
199
self .deployment_id = getattr (self .watsonx_model , "deployment_id" , "" )
@@ -179,18 +207,38 @@ def validate_environment(self) -> Self:
179
207
self .params = getattr (self .watsonx_model , "params" )
180
208
181
209
elif isinstance (self .watsonx_client , APIClient ):
182
- watsonx_model = ModelInference (
183
- model_id = self .model_id ,
184
- deployment_id = self .deployment_id ,
185
- params = self .params ,
186
- api_client = self .watsonx_client ,
187
- project_id = self .project_id ,
188
- space_id = self .space_id ,
189
- verify = self .verify ,
190
- )
191
- self .watsonx_model = watsonx_model
210
+ if sum (map (bool , (self .model , self .model_id , self .deployment_id ))) != 1 :
211
+ raise ValueError (
212
+ "The parameters 'model', 'model_id' and 'deployment_id' are "
213
+ "mutually exclusive. Please specify exactly one of these "
214
+ "parameters when initializing WatsonxLLM."
215
+ )
216
+ if self .model is not None :
217
+ watsonx_model_gateway = Gateway (
218
+ api_client = self .watsonx_client ,
219
+ verify = self .verify ,
220
+ )
221
+ self .watsonx_model_gateway = watsonx_model_gateway
222
+ else :
223
+ watsonx_model = ModelInference (
224
+ model_id = self .model_id ,
225
+ deployment_id = self .deployment_id ,
226
+ params = self .params ,
227
+ api_client = self .watsonx_client ,
228
+ project_id = self .project_id ,
229
+ space_id = self .space_id ,
230
+ verify = self .verify ,
231
+ )
232
+ self .watsonx_model = watsonx_model
192
233
193
234
else :
235
+ if sum (map (bool , (self .model , self .model_id , self .deployment_id ))) != 1 :
236
+ raise ValueError (
237
+ "The parameters 'model', 'model_id' and 'deployment_id' are "
238
+ "mutually exclusive. Please specify exactly one of these "
239
+ "parameters when initializing WatsonxLLM."
240
+ )
241
+
194
242
check_for_attribute (self .url , "url" , "WATSONX_URL" )
195
243
196
244
if "cloud.ibm.com" in self .url .get_secret_value ():
@@ -239,19 +287,39 @@ def validate_environment(self) -> Self:
239
287
version = self .version .get_secret_value () if self .version else None ,
240
288
verify = self .verify ,
241
289
)
242
-
243
- watsonx_model = ModelInference (
244
- model_id = self .model_id ,
245
- deployment_id = self .deployment_id ,
246
- credentials = credentials ,
247
- params = self .params ,
248
- project_id = self .project_id ,
249
- space_id = self .space_id ,
250
- )
251
- self .watsonx_model = watsonx_model
290
+ if self .model is not None :
291
+ watsonx_model_gateway = Gateway (
292
+ credentials = credentials ,
293
+ verify = self .verify ,
294
+ )
295
+ self .watsonx_model_gateway = watsonx_model_gateway
296
+ else :
297
+ watsonx_model = ModelInference (
298
+ model_id = self .model_id ,
299
+ deployment_id = self .deployment_id ,
300
+ credentials = credentials ,
301
+ params = self .params ,
302
+ project_id = self .project_id ,
303
+ space_id = self .space_id ,
304
+ )
305
+ self .watsonx_model = watsonx_model
252
306
253
307
return self
254
308
309
@gateway_error_handler
def _call_model_gateway(
    self, *, model: str, prompt: "str | List[str]", **params: Any
) -> Any:
    """Send a text-completion request through the Model Gateway client.

    Args:
        model: Name or alias of the gateway model to run.
        prompt: Either a list of prompts (as passed by ``_generate``) or the
            single ``prompt`` value forwarded by the streaming code path.
        **params: Extra keyword arguments forwarded unchanged to
            ``completions.create``. Streaming callers add ``stream=True``
            here and iterate over the returned value.

    Returns:
        Whatever ``self.watsonx_model_gateway.completions.create`` returns;
        the value is handed back without inspection.
    """
    # The ``prompt`` annotation is quoted on purpose: it documents both
    # accepted shapes for type checkers without being evaluated at runtime.
    completions = self.watsonx_model_gateway.completions
    return completions.create(model=model, prompt=prompt, **params)
314
+
315
@async_gateway_error_handler
async def _acall_model_gateway(
    self, *, model: str, prompt: "str | List[str]", **params: Any
) -> Any:
    """Asynchronously send a text-completion request through the Model Gateway.

    Args:
        model: Name or alias of the gateway model to run.
        prompt: Either a list of prompts (as passed by ``_agenerate``) or the
            single ``prompt`` value forwarded by the async streaming path.
        **params: Extra keyword arguments forwarded unchanged to
            ``completions.acreate``. Streaming callers add ``stream=True``
            here and iterate asynchronously over the returned value.

    Returns:
        The awaited result of
        ``self.watsonx_model_gateway.completions.acreate``, handed back
        without inspection.
    """
    # The ``prompt`` annotation is quoted on purpose: it documents both
    # accepted shapes for type checkers without being evaluated at runtime.
    completions = self.watsonx_model_gateway.completions
    return await completions.acreate(model=model, prompt=prompt, **params)
322
+
255
323
@property
256
324
def _identifying_params (self ) -> Mapping [str , Any ]:
257
325
"""Get the identifying parameters."""
@@ -361,6 +429,30 @@ def _create_llm_result(self, response: List[dict]) -> LLMResult:
361
429
}
362
430
return LLMResult (generations = generations , llm_output = llm_output )
363
431
432
def _create_llm_gateway_result(self, response: dict) -> LLMResult:
    """Build an ``LLMResult`` from a Model Gateway completions response.

    Args:
        response: Response mapping. It must provide a ``"choices"`` sequence
            whose items each carry a ``"text"`` entry, and
            ``response["usage"]["total_tokens"]``.

    Returns:
        An ``LLMResult`` holding one single-element generation list per
        choice. Each generation records the choice's ``finish_reason`` and
        ``logprobs`` (``None`` when the choice omits them).

    Raises:
        KeyError: If ``"choices"``, a choice's ``"text"``, ``"usage"`` or
            ``"total_tokens"`` is missing from ``response``.
    """
    generations = [
        [
            Generation(
                text=choice["text"],
                generation_info={
                    "finish_reason": choice.get("finish_reason"),
                    "logprobs": choice.get("logprobs"),
                },
            )
        ]
        for choice in response["choices"]
    ]

    llm_output = {
        "token_usage": response["usage"]["total_tokens"],
        # The gateway client is only created when ``model`` is set, and the
        # validator makes ``model`` mutually exclusive with ``model_id`` and
        # ``deployment_id``; report ``model`` so the output still says which
        # model produced the text. The two older keys are kept for callers
        # that already read them.
        "model": self.model,
        "model_id": self.model_id,
        "deployment_id": self.deployment_id,
    }
    return LLMResult(generations=generations, llm_output=llm_output)
455
+
364
456
def _stream_response_to_generation_chunk (
365
457
self ,
366
458
stream_response : Dict [str , Any ],
@@ -470,10 +562,17 @@ def _generate(
470
562
return LLMResult (generations = [[generation ]], llm_output = llm_output )
471
563
return LLMResult (generations = [[generation ]])
472
564
else :
473
- response = self .watsonx_model .generate (
474
- prompt = prompts , params = params , ** kwargs
475
- )
476
- return self ._create_llm_result (response )
565
+ if self .watsonx_model_gateway is not None :
566
+ call_kwargs = {** kwargs , ** params }
567
+ response = self ._call_model_gateway (
568
+ model = self .model , prompt = prompts , ** call_kwargs
569
+ )
570
+ return self ._create_llm_gateway_result (response )
571
+ else :
572
+ response = self .watsonx_model .generate (
573
+ prompt = prompts , params = params , ** kwargs
574
+ )
575
+ return self ._create_llm_result (response )
477
576
478
577
async def _agenerate (
479
578
self ,
@@ -491,14 +590,21 @@ async def _agenerate(
491
590
prompts = prompts , stop = stop , run_manager = run_manager , ** kwargs
492
591
)
493
592
else :
494
- responses = [
495
- await self .watsonx_model .agenerate (
496
- prompt = prompt , params = params , ** kwargs
593
+ if self .watsonx_model_gateway is not None :
594
+ call_kwargs = {** kwargs , ** params }
595
+ responses = await self ._acall_model_gateway (
596
+ model = self .model , prompt = prompts , ** call_kwargs
497
597
)
498
- for prompt in prompts
499
- ]
598
+ return self ._create_llm_gateway_result (responses )
599
+ else :
600
+ responses = [
601
+ await self .watsonx_model .agenerate (
602
+ prompt = prompt , params = params , ** kwargs
603
+ )
604
+ for prompt in prompts
605
+ ]
500
606
501
- return self ._create_llm_result (responses )
607
+ return self ._create_llm_result (responses )
502
608
503
609
def _stream (
504
610
self ,
@@ -523,9 +629,16 @@ def _stream(
523
629
"""
524
630
params , kwargs = self ._get_chat_params (stop = stop , ** kwargs )
525
631
params = self ._validate_chat_params (params )
526
- for stream_resp in self .watsonx_model .generate_text_stream (
527
- prompt = prompt , params = params , ** (kwargs | {"raw_response" : True })
528
- ):
632
+ if self .watsonx_model_gateway is not None :
633
+ call_kwargs = {** kwargs , ** params , "stream" : True }
634
+ chunk_iter = self ._call_model_gateway (
635
+ model = self .model , prompt = prompt , ** call_kwargs
636
+ )
637
+ else :
638
+ chunk_iter = self .watsonx_model .generate_text_stream (
639
+ prompt = prompt , params = params , ** (kwargs | {"raw_response" : True })
640
+ )
641
+ for stream_resp in chunk_iter :
529
642
if not isinstance (stream_resp , dict ):
530
643
stream_resp = stream_resp .dict ()
531
644
chunk = self ._stream_response_to_generation_chunk (stream_resp )
@@ -543,9 +656,17 @@ async def _astream(
543
656
) -> AsyncIterator [GenerationChunk ]:
544
657
params , kwargs = self ._get_chat_params (stop = stop , ** kwargs )
545
658
params = self ._validate_chat_params (params )
546
- async for stream_resp in await self .watsonx_model .agenerate_stream (
547
- prompt = prompt , params = params
548
- ):
659
+
660
+ if self .watsonx_model_gateway is not None :
661
+ call_kwargs = {** kwargs , ** params , "stream" : True }
662
+ chunk_iter = await self ._acall_model_gateway (
663
+ model = self .model , prompt = prompt , ** call_kwargs
664
+ )
665
+ else :
666
+ chunk_iter = await self .watsonx_model .agenerate_stream (
667
+ prompt = prompt , params = params
668
+ )
669
+ async for stream_resp in chunk_iter :
549
670
if not isinstance (stream_resp , dict ):
550
671
stream_resp = stream_resp .dict ()
551
672
chunk = self ._stream_response_to_generation_chunk (stream_resp )
@@ -555,7 +676,12 @@ async def _astream(
555
676
yield chunk
556
677
557
678
def get_num_tokens(self, text: str) -> int:
    """Return the token count the model's tokenizer reports for ``text``.

    Args:
        text: The string to tokenize.

    Returns:
        The ``token_count`` value found under ``"result"`` in the response
        of ``self.watsonx_model.tokenize``.

    Raises:
        NotImplementedError: If this instance is backed by a Model Gateway
            client (``self.watsonx_model_gateway`` is set); no tokenize call
            is attempted in that case.
    """
    # Guard first: with a gateway client there is no ``watsonx_model`` to ask.
    if self.watsonx_model_gateway is not None:
        raise NotImplementedError(
            "Tokenize endpoint is not supported by IBM Model Gateway endpoint."
        )

    # ``return_tokens=False`` asks only for the count, not the token list.
    response = self.watsonx_model.tokenize(text, return_tokens=False)
    return response["result"]["token_count"]
560
686
561
687
def get_token_ids (self , text : str ) -> List [int ]:
0 commit comments