@@ -401,19 +401,17 @@ class ChatWatsonx(BaseChatModel):
401
401
Example:
402
402
.. code-block:: python
403
403
404
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
405
- parameters = {
406
- GenTextParamsMetaNames.DECODING_METHOD: "sample",
407
- GenTextParamsMetaNames.MAX_NEW_TOKENS: 100,
408
- GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
409
- GenTextParamsMetaNames.TEMPERATURE: 0.5,
410
- GenTextParamsMetaNames.TOP_K: 50,
411
- GenTextParamsMetaNames.TOP_P: 1,
412
- }
404
+ from ibm_watsonx_ai.foundation_models.schema import TextChatParameters
405
+
406
+ parameters = TextChatParameters(
407
+ max_tokens=100,
408
+ temperature=0.5,
409
+ top_p=1,
410
+ )
413
411
414
412
from langchain_ibm import ChatWatsonx
415
413
watsonx_llm = ChatWatsonx(
416
- model_id="meta-llama/llama-3-70b-instruct",
414
+ model_id="meta-llama/llama-3-3-70b-instruct",
417
415
url="https://us-south.ml.cloud.ibm.com",
418
416
apikey="*****",
419
417
project_id="*****",
@@ -527,6 +525,18 @@ class ChatWatsonx(BaseChatModel):
527
525
"""Time limit in milliseconds - if not completed within this time,
528
526
generation will stop."""
529
527
528
+ logit_bias: Optional[dict] = None
529
+ """Increasing or decreasing probability of tokens being selected
530
+ during generation."""
531
+
532
+ seed: Optional[int] = None
533
+ """Random number generator seed to use in sampling mode
534
+ for experimental repeatability."""
535
+
536
+ stop: Optional[list[str]] = None
537
+ """Stop sequences are one or more strings which will cause the text generation
538
+ to stop if/when they are produced as part of the output."""
539
+
530
540
verify: Union[str, bool, None] = None
531
541
"""You can pass one of following as verify:
532
542
* the path to a CA_BUNDLE file
@@ -602,16 +612,8 @@ def validate_environment(self) -> Self:
602
612
{
603
613
k : v
604
614
for k , v in {
605
- "frequency_penalty" : self .frequency_penalty ,
606
- "logprobs" : self .logprobs ,
607
- "top_logprobs" : self .top_logprobs ,
608
- "max_tokens" : self .max_tokens ,
609
- "n" : self .n ,
610
- "presence_penalty" : self .presence_penalty ,
611
- "response_format" : self .response_format ,
612
- "temperature" : self .temperature ,
613
- "top_p" : self .top_p ,
614
- "time_limit" : self .time_limit ,
615
+ param : getattr (self , param )
616
+ for param in ChatWatsonx ._get_supported_chat_params ()
615
617
}.items ()
616
618
if v is not None
617
619
}
@@ -768,18 +770,7 @@ def _stream(
768
770
@staticmethod
769
771
def _merge_params (params : dict , kwargs : dict ) -> dict :
770
772
param_updates = {}
771
- for k in [
772
- "frequency_penalty" ,
773
- "logprobs" ,
774
- "top_logprobs" ,
775
- "max_tokens" ,
776
- "n" ,
777
- "presence_penalty" ,
778
- "response_format" ,
779
- "temperature" ,
780
- "top_p" ,
781
- "time_limit" ,
782
- ]:
773
+ for k in ChatWatsonx ._get_supported_chat_params ():
783
774
if kwargs .get (k ) is not None :
784
775
param_updates [k ] = kwargs .pop (k )
785
776
@@ -837,6 +828,25 @@ def _create_chat_result(
837
828
838
829
return ChatResult (generations = generations , llm_output = llm_output )
839
830
831
+ @staticmethod
832
+ def _get_supported_chat_params () -> list [str ]:
833
+ # watsonx.ai Chat API doc: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
834
+ return [
835
+ "frequency_penalty" ,
836
+ "logprobs" ,
837
+ "top_logprobs" ,
838
+ "max_tokens" ,
839
+ "n" ,
840
+ "presence_penalty" ,
841
+ "response_format" ,
842
+ "temperature" ,
843
+ "top_p" ,
844
+ "time_limit" ,
845
+ "logit_bias" ,
846
+ "seed" ,
847
+ "stop" ,
848
+ ]
849
+
840
850
def bind_functions (
841
851
self ,
842
852
functions : Sequence [Union [Dict [str , Any ], Type [BaseModel ], Callable , BaseTool ]],
0 commit comments