1
1
"""Wrapper around Cohere APIs."""
2
- from typing import Any , Dict , List , Mapping , Optional
2
+ import logging
3
+ from typing import Any , Dict , List , Optional
3
4
4
5
from pydantic import BaseModel , Extra , root_validator
5
6
6
7
from langchain .llms .base import LLM
7
8
from langchain .llms .utils import enforce_stop_tokens
8
9
from langchain .utils import get_from_dict_or_env
9
10
11
+ logger = logging .getLogger (__name__ )
12
+
10
13
11
14
class Cohere (LLM , BaseModel ):
12
15
"""Wrapper around Cohere large language models.
@@ -46,6 +49,8 @@ class Cohere(LLM, BaseModel):
46
49
47
50
cohere_api_key : Optional [str ] = None
48
51
52
+ stop : Optional [List [str ]] = None
53
+
49
54
class Config :
50
55
"""Configuration for this pydantic object."""
51
56
@@ -69,7 +74,7 @@ def validate_environment(cls, values: Dict) -> Dict:
69
74
return values
70
75
71
76
@property
72
- def _default_params (self ) -> Mapping [str , Any ]:
77
+ def _default_params (self ) -> Dict [str , Any ]:
73
78
"""Get the default parameters for calling Cohere API."""
74
79
return {
75
80
"max_tokens" : self .max_tokens ,
@@ -81,7 +86,7 @@ def _default_params(self) -> Mapping[str, Any]:
81
86
}
82
87
83
88
@property
84
- def _identifying_params (self ) -> Mapping [str , Any ]:
89
+ def _identifying_params (self ) -> Dict [str , Any ]:
85
90
"""Get the identifying parameters."""
86
91
return {** {"model" : self .model }, ** self ._default_params }
87
92
@@ -105,9 +110,15 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
105
110
106
111
response = cohere("Tell me a joke.")
107
112
"""
108
- response = self .client .generate (
109
- model = self .model , prompt = prompt , stop_sequences = stop , ** self ._default_params
110
- )
113
+ params = self ._default_params
114
+ if self .stop is not None and stop is not None :
115
+ raise ValueError ("`stop` found in both the input and default params." )
116
+ elif self .stop is not None :
117
+ params ["stop_sequences" ] = self .stop
118
+ else :
119
+ params ["stop_sequences" ] = stop
120
+
121
+ response = self .client .generate (model = self .model , prompt = prompt , ** params )
111
122
text = response .generations [0 ].text
112
123
# If stop tokens are provided, Cohere's endpoint returns them.
113
124
# In order to make this consistent with other endpoints, we strip them.
0 commit comments