
Commit e215a3b

fix docstring and clean up (#53)
* fix docstring and clean up
* fix linting
* fix linter
1 parent c50be28 commit e215a3b

File tree

2 files changed: +95 −32 lines changed

libs/ibm/langchain_ibm/chat_models.py

Lines changed: 42 additions & 32 deletions
@@ -401,19 +401,17 @@ class ChatWatsonx(BaseChatModel):
     Example:
         .. code-block:: python

-            from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
-            parameters = {
-                GenTextParamsMetaNames.DECODING_METHOD: "sample",
-                GenTextParamsMetaNames.MAX_NEW_TOKENS: 100,
-                GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
-                GenTextParamsMetaNames.TEMPERATURE: 0.5,
-                GenTextParamsMetaNames.TOP_K: 50,
-                GenTextParamsMetaNames.TOP_P: 1,
-            }
+            from ibm_watsonx_ai.foundation_models.schema import TextChatParameters
+
+            parameters = TextChatParameters(
+                max_tokens=100,
+                temperature=0.5,
+                top_p=1,
+            )

             from langchain_ibm import ChatWatsonx
             watsonx_llm = ChatWatsonx(
-                model_id="meta-llama/llama-3-70b-instruct",
+                model_id="meta-llama/llama-3-3-70b-instruct",
                 url="https://us-south.ml.cloud.ibm.com",
                 apikey="*****",
                 project_id="*****",
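For orientation, here is how the updated docstring example reads end to end. This is a minimal sketch, assuming the parameters object is wired in through the `params` field as elsewhere in the docstring; the credentials and the final invoke call are illustrative, not part of the commit:

    from ibm_watsonx_ai.foundation_models.schema import TextChatParameters
    from langchain_ibm import ChatWatsonx

    # Chat-specific parameters replace the old GenTextParamsMetaNames dict.
    parameters = TextChatParameters(max_tokens=100, temperature=0.5, top_p=1)

    watsonx_llm = ChatWatsonx(
        model_id="meta-llama/llama-3-3-70b-instruct",
        url="https://us-south.ml.cloud.ibm.com",
        apikey="*****",
        project_id="*****",
        params=parameters,  # assumption: parameters are passed via `params`
    )
    print(watsonx_llm.invoke("What is the capital of France?").content)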
@@ -527,6 +525,18 @@ class ChatWatsonx(BaseChatModel):
     """Time limit in milliseconds - if not completed within this time,
     generation will stop."""

+    logit_bias: Optional[dict] = None
+    """Increasing or decreasing probability of tokens being selected
+    during generation."""
+
+    seed: Optional[int] = None
+    """Random number generator seed to use in sampling mode
+    for experimental repeatability."""
+
+    stop: Optional[list[str]] = None
+    """Stop sequences are one or more strings which will cause the text generation
+    to stop if/when they are produced as part of the output."""
+
     verify: Union[str, bool, None] = None
     """You can pass one of following as verify:
         * the path to a CA_BUNDLE file
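The three new fields land on the model as ordinary top-level keyword arguments. A minimal sketch of using them together; the model id and credentials are illustrative placeholders:

    from langchain_ibm import ChatWatsonx

    chat = ChatWatsonx(
        model_id="ibm/granite-13b-instruct-v2",  # hypothetical model choice
        url="https://us-south.ml.cloud.ibm.com",
        apikey="*****",
        project_id="*****",
        logit_bias={"1003": -100, "1004": -100},  # suppress specific token ids
        seed=41,        # repeatable sampling runs
        stop=["\n\n"],  # halt generation at the first blank line
    )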
@@ -602,16 +612,8 @@ def validate_environment(self) -> Self:
             {
                 k: v
                 for k, v in {
-                    "frequency_penalty": self.frequency_penalty,
-                    "logprobs": self.logprobs,
-                    "top_logprobs": self.top_logprobs,
-                    "max_tokens": self.max_tokens,
-                    "n": self.n,
-                    "presence_penalty": self.presence_penalty,
-                    "response_format": self.response_format,
-                    "temperature": self.temperature,
-                    "top_p": self.top_p,
-                    "time_limit": self.time_limit,
+                    param: getattr(self, param)
+                    for param in ChatWatsonx._get_supported_chat_params()
                 }.items()
                 if v is not None
             }
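The refactor swaps ten hand-written dict entries for a getattr loop over one canonical list. A toy re-statement of the pattern outside the library, to show why it stays in sync by construction:

    class Toy:
        temperature = 0.7
        top_p = None
        seed = 41

        SUPPORTED = ["temperature", "top_p", "seed"]

        def collected_params(self) -> dict:
            # Read every supported name off the instance and drop unset ones,
            # mirroring the k/v filter in validate_environment.
            return {
                name: getattr(self, name)
                for name in self.SUPPORTED
                if getattr(self, name) is not None
            }

    assert Toy().collected_params() == {"temperature": 0.7, "seed": 41}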
@@ -768,18 +770,7 @@ def _stream(
     @staticmethod
     def _merge_params(params: dict, kwargs: dict) -> dict:
         param_updates = {}
-        for k in [
-            "frequency_penalty",
-            "logprobs",
-            "top_logprobs",
-            "max_tokens",
-            "n",
-            "presence_penalty",
-            "response_format",
-            "temperature",
-            "top_p",
-            "time_limit",
-        ]:
+        for k in ChatWatsonx._get_supported_chat_params():
             if kwargs.get(k) is not None:
                 param_updates[k] = kwargs.pop(k)

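The same list now drives `_merge_params`, whose contract is: per-call kwargs override instance-level params, and any merged key is popped from kwargs so it is not forwarded twice. A standalone sketch of that contract with illustrative names:

    def merge_params(params: dict, kwargs: dict, supported: list[str]) -> dict:
        # Per-call values win; pop them so they are not also passed as kwargs.
        updates = {k: kwargs.pop(k) for k in supported if kwargs.get(k) is not None}
        return {**params, **updates}

    defaults = {"temperature": 0.7, "top_p": 0.9}
    call_kwargs = {"top_p": 0.8, "messages": ["Hello"]}
    merged = merge_params(defaults, call_kwargs, ["temperature", "top_p"])
    assert merged == {"temperature": 0.7, "top_p": 0.8}
    assert "top_p" not in call_kwargs  # consumed by the merge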
@@ -837,6 +828,25 @@ def _create_chat_result(
 
         return ChatResult(generations=generations, llm_output=llm_output)
 
+    @staticmethod
+    def _get_supported_chat_params() -> list[str]:
+        # watsonx.ai Chat API doc: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
+        return [
+            "frequency_penalty",
+            "logprobs",
+            "top_logprobs",
+            "max_tokens",
+            "n",
+            "presence_penalty",
+            "response_format",
+            "temperature",
+            "top_p",
+            "time_limit",
+            "logit_bias",
+            "seed",
+            "stop",
+        ]
+
     def bind_functions(
         self,
         functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
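With `_get_supported_chat_params` as the single source of truth, both call sites stay consistent automatically. A hypothetical sanity check (not part of the commit), assuming the pydantic-v2 `model_fields` mapping that recent langchain-core models expose:

    from langchain_ibm import ChatWatsonx

    # Every supported name should be a declared field, so the getattr-based
    # collection in validate_environment can never raise AttributeError.
    for name in ChatWatsonx._get_supported_chat_params():
        assert name in ChatWatsonx.model_fields, f"missing field: {name}"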

libs/ibm/tests/unit_tests/test_chat_models.py

Lines changed: 53 additions & 0 deletions
@@ -1,6 +1,7 @@
 """Test ChatWatsonx API wrapper."""

 import os
+from typing import Any

 from langchain_ibm import ChatWatsonx

@@ -82,3 +83,55 @@ def test_initialize_chat_watsonx_cpd_bad_path_without_instance_id() -> None:
     except ValueError as e:
         assert "instance_id" in e.__str__()
         assert "WATSONX_INSTANCE_ID" in e.__str__()
+
+
+def test_initialize_chat_watsonx_with_all_supported_params(mocker: Any) -> None:
+    # All params values are taken from
+    # ibm_watsonx_ai.foundation_models.schema.TextChatParameters.get_sample_params()
+
+    from ibm_watsonx_ai.foundation_models.schema import (  # type: ignore[import-untyped]
+        TextChatParameters,
+    )
+
+    TOP_P = 0.8
+
+    def mock_modelinference_chat(*args: Any, **kwargs: Any) -> dict:
+        """Mock ModelInference.chat method."""
+
+        assert kwargs.get("params", None) == (
+            TextChatParameters.get_sample_params()
+            | dict(
+                logit_bias={"1003": -100, "1004": -100}, seed=41, stop=["this", "the"]
+            )
+            | dict(top_p=TOP_P)
+        )
+        # logit_bias, seed and stop are available in the SDK since 1.2.7
+        return {"id": "123", "choices": [{"message": dict(content="Hi", role="ai")}]}
+
+    with mocker.patch(
+        "ibm_watsonx_ai.foundation_models.ModelInference.__init__", return_value=None
+    ), mocker.patch(
+        "ibm_watsonx_ai.foundation_models.ModelInference.chat",
+        side_effect=mock_modelinference_chat,
+    ):
+        chat = ChatWatsonx(
+            model_id="google/flan-ul2",
+            url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+            apikey="test_apikey",  # type: ignore[arg-type]
+            frequency_penalty=0.5,
+            logprobs=True,
+            top_logprobs=3,
+            presence_penalty=0.3,
+            response_format={"type": "json_object"},
+            temperature=0.7,
+            max_tokens=100,
+            time_limit=600000,
+            top_p=0.9,
+            n=1,
+            logit_bias={"1003": -100, "1004": -100},
+            seed=41,
+            stop=["this", "the"],
+        )
+
+        # change only top_p
+        chat.invoke("Hello", top_p=TOP_P)
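The params assertion in the mock leans on PEP 584 dict union, where the right-most operand wins. A two-line illustration of that precedence:

    base = {"top_p": 0.9, "temperature": 0.7}
    assert (base | {"top_p": 0.8}) == {"top_p": 0.8, "temperature": 0.7}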
