|
1 | 1 | import copy |
| 2 | +from functools import partial |
2 | 3 | from typing import Any, Dict, List, Optional, Tuple, Type, Union |
3 | 4 |
|
4 | 5 | from pydantic import BaseModel |
@@ -259,37 +260,25 @@ async def async_call( |
259 | 260 | 2. Convert the response string to a dict, |
260 | 261 | 3. Log the output |
261 | 262 | """ |
| 263 | + # If the API supports a base model, pass it in. |
| 264 | + api_fn = api |
| 265 | + if api is not None: |
| 266 | + supports_base_model = getattr(api, "supports_base_model", False) |
| 267 | + if supports_base_model: |
| 268 | + api_fn = partial(api, base_model=self.base_model) |
| 269 | + |
262 | 270 | if output is not None: |
263 | 271 | llm_response = LLMResponse( |
264 | 272 | output=output, |
265 | 273 | ) |
266 | | - elif api is None: |
| 274 | + elif api_fn is None: |
267 | 275 | raise ValueError("Either API or output must be provided.") |
268 | 276 | elif msg_history: |
269 | | - try: |
270 | | - llm_response = await api( |
271 | | - msg_history=msg_history_source(msg_history), |
272 | | - base_model=self.base_model, |
273 | | - ) |
274 | | - except Exception: |
275 | | - # If the API call fails, try calling again without the base model. |
276 | | - llm_response = await api(msg_history=msg_history_source(msg_history)) |
| 277 | + llm_response = await api_fn(msg_history=msg_history_source(msg_history)) |
277 | 278 | elif prompt and instructions: |
278 | | - try: |
279 | | - llm_response = await api( |
280 | | - prompt.source, |
281 | | - instructions=instructions.source, |
282 | | - base_model=self.base_model, |
283 | | - ) |
284 | | - except Exception: |
285 | | - llm_response = await api( |
286 | | - prompt.source, instructions=instructions.source |
287 | | - ) |
| 279 | + llm_response = await api_fn(prompt.source, instructions=instructions.source) |
288 | 280 | elif prompt: |
289 | | - try: |
290 | | - llm_response = await api(prompt.source, base_model=self.base_model) |
291 | | - except Exception: |
292 | | - llm_response = await api(prompt.source) |
| 281 | + llm_response = await api_fn(prompt.source) |
293 | 282 | else: |
294 | 283 | raise ValueError("'output', 'prompt' or 'msg_history' must be provided.") |
295 | 284 |
|
|
0 commit comments