28
28
BaseSchema ,
29
29
TextChatParameters ,
30
30
)
31
+ from ibm_watsonx_ai .gateway import Gateway # type: ignore
31
32
from langchain_core .callbacks import (
32
33
AsyncCallbackManagerForLLMRun ,
33
34
CallbackManagerForLLMRun ,
@@ -428,6 +429,9 @@ class ChatWatsonx(BaseChatModel):
428
429
model_id : Optional [str ] = None
429
430
"""Type of model to use."""
430
431
432
+ model : Optional [str ] = None
433
+ """Name of model for given provider or alias."""
434
+
431
435
deployment_id : Optional [str ] = None
432
436
"""Type of deployed model to use."""
433
437
@@ -558,6 +562,10 @@ class ChatWatsonx(BaseChatModel):
558
562
559
563
watsonx_model : ModelInference = Field (default = None , exclude = True ) #: :meta private:
560
564
565
+ watsonx_model_gateway : Gateway = Field (
566
+ default = None , exclude = True
567
+ ) #: :meta private:
568
+
561
569
watsonx_client : Optional [APIClient ] = Field (default = None , exclude = True )
562
570
563
571
model_config = ConfigDict (populate_by_name = True )
@@ -624,21 +632,58 @@ def validate_environment(self) -> Self:
624
632
if v is not None
625
633
}
626
634
)
635
+ if self .watsonx_model_gateway is not None :
636
+ raise NotImplementedError (
637
+ "Passing the 'watsonx_model_gateway' parameter to the ChatWatsonx "
638
+ "constructor is not supported yet."
639
+ )
627
640
628
- if isinstance (self .watsonx_client , APIClient ):
629
- watsonx_model = ModelInference (
630
- model_id = self .model_id ,
631
- deployment_id = self .deployment_id ,
632
- params = self .params ,
633
- api_client = self .watsonx_client ,
634
- project_id = self .project_id ,
635
- space_id = self .space_id ,
636
- verify = self .verify ,
637
- validate = self .validate_model ,
641
+ if isinstance (self .watsonx_model , ModelInference ):
642
+ self .model_id = getattr (self .watsonx_model , "model_id" )
643
+ self .deployment_id = getattr (self .watsonx_model , "deployment_id" , "" )
644
+ self .project_id = getattr (
645
+ getattr (self .watsonx_model , "_client" ),
646
+ "default_project_id" ,
647
+ )
648
+ self .space_id = getattr (
649
+ getattr (self .watsonx_model , "_client" ), "default_space_id"
638
650
)
639
- self .watsonx_model = watsonx_model
651
+ self .params = getattr (self .watsonx_model , "params" )
652
+ self .watsonx_client = getattr (self .watsonx_model , "_client" )
640
653
654
+ elif isinstance (self .watsonx_client , APIClient ):
655
+ if sum (map (bool , (self .model , self .model_id , self .deployment_id ))) != 1 :
656
+ raise ValueError (
657
+ "The parameters 'model', 'model_id' and 'deployment_id' are "
658
+ "mutually exclusive. Please specify exactly one of these "
659
+ "parameters when initializing ChatWatsonx."
660
+ )
661
+ if self .model is not None :
662
+ watsonx_model_gateway = Gateway (
663
+ api_client = self .watsonx_client ,
664
+ verify = self .verify ,
665
+ )
666
+ self .watsonx_model_gateway = watsonx_model_gateway
667
+ else :
668
+ watsonx_model = ModelInference (
669
+ model_id = self .model_id ,
670
+ deployment_id = self .deployment_id ,
671
+ params = self .params ,
672
+ api_client = self .watsonx_client ,
673
+ project_id = self .project_id ,
674
+ space_id = self .space_id ,
675
+ verify = self .verify ,
676
+ validate = self .validate_model ,
677
+ )
678
+ self .watsonx_model = watsonx_model
641
679
else :
680
+ if sum (map (bool , (self .model , self .model_id , self .deployment_id ))) != 1 :
681
+ raise ValueError (
682
+ "The parameters 'model', 'model_id' and 'deployment_id' are "
683
+ "mutually exclusive. Please specify exactly one of these "
684
+ "parameters when initializing ChatWatsonx."
685
+ )
686
+
642
687
check_for_attribute (self .url , "url" , "WATSONX_URL" )
643
688
644
689
if "cloud.ibm.com" in self .url .get_secret_value ():
@@ -687,18 +732,24 @@ def validate_environment(self) -> Self:
687
732
version = self .version .get_secret_value () if self .version else None ,
688
733
verify = self .verify ,
689
734
)
690
-
691
- watsonx_chat = ModelInference (
692
- model_id = self .model_id ,
693
- deployment_id = self .deployment_id ,
694
- credentials = credentials ,
695
- params = self .params ,
696
- project_id = self .project_id ,
697
- space_id = self .space_id ,
698
- verify = self .verify ,
699
- validate = self .validate_model ,
700
- )
701
- self .watsonx_model = watsonx_chat
735
+ if self .model is not None :
736
+ watsonx_model_gateway = Gateway (
737
+ credentials = credentials ,
738
+ verify = self .verify ,
739
+ )
740
+ self .watsonx_model_gateway = watsonx_model_gateway
741
+ else :
742
+ watsonx_model = ModelInference (
743
+ model_id = self .model_id ,
744
+ deployment_id = self .deployment_id ,
745
+ credentials = credentials ,
746
+ params = self .params ,
747
+ project_id = self .project_id ,
748
+ space_id = self .space_id ,
749
+ verify = self .verify ,
750
+ validate = self .validate_model ,
751
+ )
752
+ self .watsonx_model = watsonx_model
702
753
703
754
return self
704
755
@@ -717,10 +768,14 @@ def _generate(
717
768
718
769
message_dicts , params = self ._create_message_dicts (messages , stop , ** kwargs )
719
770
updated_params = self ._merge_params (params , kwargs )
720
-
721
- response = self .watsonx_model .chat (
722
- messages = message_dicts , ** (kwargs | {"params" : updated_params })
723
- )
771
+ if self .watsonx_model_gateway is not None :
772
+ response = self .watsonx_model_gateway .chat .completions .create (
773
+ model = self .model , messages = message_dicts , ** (kwargs | updated_params )
774
+ )
775
+ else :
776
+ response = self .watsonx_model .chat (
777
+ messages = message_dicts , ** (kwargs | {"params" : updated_params })
778
+ )
724
779
return self ._create_chat_result (response )
725
780
726
781
async def _agenerate (
@@ -738,10 +793,14 @@ async def _agenerate(
738
793
739
794
message_dicts , params = self ._create_message_dicts (messages , stop , ** kwargs )
740
795
updated_params = self ._merge_params (params , kwargs )
741
-
742
- response = await self .watsonx_model .achat (
743
- messages = message_dicts , ** (kwargs | {"params" : updated_params })
744
- )
796
+ if self .watsonx_model_gateway is not None :
797
+ response = await self .watsonx_model_gateway .chat .completions .acreate (
798
+ model = self .model , messages = message_dicts , ** (kwargs | updated_params )
799
+ )
800
+ else :
801
+ response = await self .watsonx_model .achat (
802
+ messages = message_dicts , ** (kwargs | {"params" : updated_params })
803
+ )
745
804
return self ._create_chat_result (response )
746
805
747
806
def _stream (
@@ -754,16 +813,23 @@ def _stream(
754
813
message_dicts , params = self ._create_message_dicts (messages , stop , ** kwargs )
755
814
updated_params = self ._merge_params (params , kwargs )
756
815
757
- default_chunk_class : Type [BaseMessageChunk ] = AIMessageChunk
816
+ if self .watsonx_model_gateway is not None :
817
+ call_kwargs = {** kwargs , ** updated_params , "stream" : True }
818
+ chunk_iter = self .watsonx_model_gateway .chat .completions .create (
819
+ model = self .model , messages = message_dicts , ** call_kwargs
820
+ )
821
+ else :
822
+ call_kwargs = {** kwargs , "params" : updated_params }
823
+ chunk_iter = self .watsonx_model .chat_stream (
824
+ messages = message_dicts , ** call_kwargs
825
+ )
758
826
827
+ default_chunk_class : Type [BaseMessageChunk ] = AIMessageChunk
759
828
is_first_tool_chunk = True
760
829
_prompt_tokens_included = False
761
830
762
- for chunk in self .watsonx_model .chat_stream (
763
- messages = message_dicts , ** (kwargs | {"params" : updated_params })
764
- ):
765
- if not isinstance (chunk , dict ):
766
- chunk = chunk .model_dump ()
831
+ for chunk in chunk_iter :
832
+ chunk = chunk if isinstance (chunk , dict ) else chunk .model_dump ()
767
833
generation_chunk = _convert_chunk_to_generation_chunk (
768
834
chunk , default_chunk_class , is_first_tool_chunk , _prompt_tokens_included
769
835
)
@@ -804,17 +870,23 @@ async def _astream(
804
870
message_dicts , params = self ._create_message_dicts (messages , stop , ** kwargs )
805
871
updated_params = self ._merge_params (params , kwargs )
806
872
807
- default_chunk_class : Type [BaseMessageChunk ] = AIMessageChunk
873
+ if self .watsonx_model_gateway is not None :
874
+ call_kwargs = {** kwargs , ** updated_params , "stream" : True }
875
+ chunk_iter = await self .watsonx_model_gateway .chat .completions .acreate (
876
+ model = self .model , messages = message_dicts , ** call_kwargs
877
+ )
878
+ else :
879
+ call_kwargs = {** kwargs , "params" : updated_params }
880
+ chunk_iter = await self .watsonx_model .achat_stream (
881
+ messages = message_dicts , ** call_kwargs
882
+ )
808
883
884
+ default_chunk_class : Type [BaseMessageChunk ] = AIMessageChunk
809
885
is_first_tool_chunk = True
810
886
_prompt_tokens_included = False
811
887
812
- response = await self .watsonx_model .achat_stream (
813
- messages = message_dicts , ** (kwargs | {"params" : updated_params })
814
- )
815
- async for chunk in response :
816
- if not isinstance (chunk , dict ):
817
- chunk = chunk .model_dump ()
888
+ async for chunk in chunk_iter :
889
+ chunk = chunk if isinstance (chunk , dict ) else chunk .model_dump ()
818
890
generation_chunk = _convert_chunk_to_generation_chunk (
819
891
chunk ,
820
892
default_chunk_class ,
0 commit comments