|
3 | 3 | from collections.abc import Iterable
|
4 | 4 | from dataclasses import dataclass, field
|
5 | 5 | from itertools import chain
|
6 |
| -from typing import Literal, TypeAlias, Union, cast |
| 6 | +from typing import Literal, Union, cast |
7 | 7 |
|
8 | 8 | from cohere import TextAssistantMessageContentItem
|
| 9 | +from httpx import AsyncClient as AsyncHTTPClient |
9 | 10 | from typing_extensions import assert_never
|
10 | 11 |
|
11 | 12 | from .. import result
|
|
51 | 52 | "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
|
52 | 53 | ) from _import_error
|
53 | 54 |
|
54 |
| -CohereModelName: TypeAlias = Union[ |
55 |
| - str, |
56 |
| - Literal[ |
57 |
| - 'c4ai-aya-expanse-32b', |
58 |
| - 'c4ai-aya-expanse-8b', |
59 |
| - 'command', |
60 |
| - 'command-light', |
61 |
| - 'command-light-nightly', |
62 |
| - 'command-nightly', |
63 |
| - 'command-r', |
64 |
| - 'command-r-03-2024', |
65 |
| - 'command-r-08-2024', |
66 |
| - 'command-r-plus', |
67 |
| - 'command-r-plus-04-2024', |
68 |
| - 'command-r-plus-08-2024', |
69 |
| - 'command-r7b-12-2024', |
70 |
| - ], |
| 55 | +NamedCohereModels = Literal[ |
| 56 | + 'c4ai-aya-expanse-32b', |
| 57 | + 'c4ai-aya-expanse-8b', |
| 58 | + 'command', |
| 59 | + 'command-light', |
| 60 | + 'command-light-nightly', |
| 61 | + 'command-nightly', |
| 62 | + 'command-r', |
| 63 | + 'command-r-03-2024', |
| 64 | + 'command-r-08-2024', |
| 65 | + 'command-r-plus', |
| 66 | + 'command-r-plus-04-2024', |
| 67 | + 'command-r-plus-08-2024', |
| 68 | + 'command-r7b-12-2024', |
71 | 69 | ]
|
| 70 | +"""Latest / most popular named Cohere models.""" |
| 71 | + |
| 72 | +CohereModelName = Union[NamedCohereModels, str] |
72 | 73 |
|
73 | 74 |
|
74 | 75 | class CohereModelSettings(ModelSettings):
|
@@ -96,23 +97,26 @@ def __init__(
|
96 | 97 | *,
|
97 | 98 | api_key: str | None = None,
|
98 | 99 | cohere_client: AsyncClientV2 | None = None,
|
| 100 | + http_client: AsyncHTTPClient | None = None, |
99 | 101 | ):
|
100 | 102 | """Initialize an Cohere model.
|
101 | 103 |
|
102 | 104 | Args:
|
103 | 105 | model_name: The name of the Cohere model to use. List of model names
|
104 | 106 | available [here](https://docs.cohere.com/docs/models#command).
|
105 | 107 | api_key: The API key to use for authentication, if not provided, the
|
106 |
| - `COHERE_API_KEY` environment variable will be used if available. |
| 108 | + `CO_API_KEY` environment variable will be used if available. |
107 | 109 | cohere_client: An existing Cohere async client to use. If provided,
|
108 |
| - `api_key` must be `None`. |
| 110 | + `api_key` and `http_client` must be `None`. |
| 111 | + http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. |
109 | 112 | """
|
110 | 113 | self.model_name: CohereModelName = model_name
|
111 | 114 | if cohere_client is not None:
|
| 115 | + assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`' |
112 | 116 | assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
|
113 | 117 | self.client = cohere_client
|
114 | 118 | else:
|
115 |
| - self.client = AsyncClientV2(api_key=api_key) # type: ignore |
| 119 | + self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client) # type: ignore |
116 | 120 |
|
117 | 121 | async def agent_model(
|
118 | 122 | self,
|
|
0 commit comments