
Commit 966611b

add model kwargs to handle stop token from cohere (#773)
1 parent 7198a1c commit 966611b

File tree

1 file changed: +17 -6 lines changed

langchain/llms/cohere.py

Lines changed: 17 additions & 6 deletions
@@ -1,12 +1,15 @@
 """Wrapper around Cohere APIs."""
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
+logger = logging.getLogger(__name__)
+
 
 class Cohere(LLM, BaseModel):
     """Wrapper around Cohere large language models.
@@ -46,6 +49,8 @@ class Cohere(LLM, BaseModel):
 
     cohere_api_key: Optional[str] = None
 
+    stop: Optional[List[str]] = None
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -69,7 +74,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         return values
 
     @property
-    def _default_params(self) -> Mapping[str, Any]:
+    def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Cohere API."""
         return {
             "max_tokens": self.max_tokens,
@@ -81,7 +86,7 @@ def _default_params(self) -> Mapping[str, Any]:
         }
 
     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
+    def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {**{"model": self.model}, **self._default_params}
 
@@ -105,9 +110,15 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
 
                 response = cohere("Tell me a joke.")
         """
-        response = self.client.generate(
-            model=self.model, prompt=prompt, stop_sequences=stop, **self._default_params
-        )
+        params = self._default_params
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            params["stop_sequences"] = self.stop
+        else:
+            params["stop_sequences"] = stop
+
+        response = self.client.generate(model=self.model, prompt=prompt, **params)
         text = response.generations[0].text
         # If stop tokens are provided, Cohere's endpoint returns them.
         # In order to make this consistent with other endpoints, we strip them.
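
For reference, a minimal usage sketch of the change, assuming the usual langchain import path and placeholder values for the API key and stop sequence:

from langchain.llms import Cohere

# New in this commit: stop sequences can be fixed once at construction time
# via the added `stop` field.
llm = Cohere(cohere_api_key="...", stop=["\n\n"])  # placeholder key and stop value
print(llm("Tell me a joke."))

# The per-call form still works as before.
llm = Cohere(cohere_api_key="...")
print(llm("Tell me a joke.", stop=["\n\n"]))

# Supplying stop sequences both at construction and at call time raises
# ValueError("`stop` found in both the input and default params.")

Setting `stop` on the wrapper is convenient when the same stop sequences apply to every call; the per-call argument keeps the previous behavior, and passing both is rejected rather than silently merged.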
