|
14 | 14 |
|
15 | 15 |
|
16 | 16 | class Gemini(tt.ModelInterface): |
| 17 | + |
17 | 18 | def __init__( |
18 | 19 | self, |
19 | | - id: Optional[str] = "gemini-1.5-pro-latest", |
| 20 | + id: Optional[str] = "gemini-1.5-flash", |
20 | 21 | base_url: str = "https://generativelanguage.googleapis.com/v1beta/models/{id}:{rpc}", |
21 | 22 | extra_headers: Optional[Dict[str, str]] = None, |
22 | 23 | ): |
@@ -120,21 +121,28 @@ def chat( |
120 | 121 | **kwargs, |
121 | 122 | ) -> Any: |
122 | 123 | output = "" |
123 | | - for x in self.stream_chat( |
124 | | - chats=chats, |
125 | | - model=model, |
126 | | - max_tokens=max_tokens, |
127 | | - temperature=temperature, |
128 | | - token=token, |
129 | | - timeout=timeout, |
130 | | - extra_headers=extra_headers, |
131 | | - raw=False, |
132 | | - **kwargs, |
133 | | - ): |
134 | | - if isinstance(x, dict): |
135 | | - output = x |
| 124 | + x = None |
| 125 | + try: |
| 126 | + for x in self.stream_chat( |
| 127 | + chats=chats, |
| 128 | + model=model, |
| 129 | + max_tokens=max_tokens, |
| 130 | + temperature=temperature, |
| 131 | + token=token, |
| 132 | + timeout=timeout, |
| 133 | + extra_headers=extra_headers, |
| 134 | + raw=False, |
| 135 | + **kwargs, |
| 136 | + ): |
| 137 | + if isinstance(x, dict): |
| 138 | + output = x |
| 139 | + else: |
| 140 | + output += x |
| 141 | + except Exception as e: |
| 142 | + if not x: |
| 143 | + raise e |
136 | 144 | else: |
137 | | - output += x |
| 145 | + raise ValueError(x) |
138 | 146 | return output |
139 | 147 |
|
140 | 148 | def stream_chat( |
@@ -194,7 +202,14 @@ def stream_chat( |
194 | 202 | "mode": "ANY", |
195 | 203 | } |
196 | 204 | } |
197 | | - data["tools"] = [{"function_declarations": tools}] |
| 205 | + std_tools = [] |
| 206 | + for i, t in enumerate(tools): |
| 207 | + props = t["parameters"]["properties"] |
| 208 | + t_copy = t.copy() |
| 209 | + if not props: |
| 210 | + t_copy.pop("parameters") |
| 211 | + std_tools.append(t_copy) |
| 212 | + data["tools"] = [{"function_declarations": std_tools}] |
198 | 213 | data.update(kwargs) |
199 | 214 |
|
200 | 215 | if debug: |
|
0 commit comments