from .model import Cohere, Endpoints


+@registry.llm_models("spacy.Cohere.v1")
+def cohere_v1(
+    name: str,
+    config: Dict[Any, Any] = SimpleFrozenDict(),
+    strict: bool = Cohere.DEFAULT_STRICT,
+    max_tries: int = Cohere.DEFAULT_MAX_TRIES,
+    interval: float = Cohere.DEFAULT_INTERVAL,
+    max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
+    context_length: Optional[int] = None,
+) -> Cohere:
| 20 | + """Returns Cohere model instance using REST to prompt API. |
| 21 | + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. |
| 22 | + name (str): Name of model to use. |
| 23 | + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON |
| 24 | + or other response object that does not conform to the expectation of how a well-formed response object from |
| 25 | + this API should look like). If False, the API error responses are returned by __call__(), but no error will |
| 26 | + be raised. |
| 27 | + max_tries (int): Max. number of tries for API request. |
| 28 | + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff |
| 29 | + at each retry. |
| 30 | + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. |
| 31 | + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length |
| 32 | + natively provided by spacy-llm. |
| 33 | + RETURNS (Cohere): Instance of Cohere model. |
| 34 | + """ |
+    return Cohere(
+        name=name,
+        endpoint=Endpoints.COMPLETION.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
+
+
@registry.llm_models("spacy.Command.v2")
def cohere_command_v2(
    config: Dict[Any, Any] = SimpleFrozenDict(),
@@ -56,7 +93,7 @@ def cohere_command(
    max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
58 | 95 | """Returns Cohere instance for 'command' model using REST to prompt API. |
59 | | - name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use. |
| 96 | + name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use. |
60 | 97 | config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. |
61 | 98 | strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON |
62 | 99 | or other response object that does not conform to the expectation of how a well-formed response object from |
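For reference, here is a minimal usage sketch for the newly registered `spacy.Cohere.v1` factory, assuming spacy-llm's documented `llm` pipeline component, the built-in `spacy.NoOp.v1` task, and a Cohere API key exported in the environment (spacy-llm reads it from `CO_API_KEY`; treat the exact variable name as an assumption). The model name `"command"` is purely illustrative:

```python
# Minimal sketch: wire the generic Cohere factory into a spaCy pipeline.
# Assumptions: spacy-llm is installed and a Cohere API key is exported.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        # No-op task, so only the model wiring is exercised.
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
        # The factory added in this diff: the model is selected via the
        # `name` argument rather than a model-specific registry entry.
        "model": {"@llm_models": "spacy.Cohere.v1", "name": "command"},
    },
)
doc = nlp("This sends one prompt to Cohere's completion endpoint.")
```

Compared with the model-specific `spacy.Command.v2` entry, whose `name` is constrained to a `Literal` of command variants, the generic `spacy.Cohere.v1` factory accepts any model name string, so newly released Cohere models can be used without adding further registry entries.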