|
12 | 12 | from dataclasses import dataclass, field
|
13 | 13 | from datetime import datetime
|
14 | 14 | from functools import cache
|
15 |
| -from typing import TYPE_CHECKING |
| 15 | +from typing import TYPE_CHECKING, cast |
16 | 16 |
|
17 | 17 | import httpx
|
18 | 18 | from typing_extensions import Literal
|
@@ -383,6 +383,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
383 | 383 |
|
384 | 384 | try:
|
385 | 385 | provider, model_name = model.split(':', maxsplit=1)
|
| 386 | + provider = cast(str, provider) |
386 | 387 | except ValueError:
|
387 | 388 | model_name = model
|
388 | 389 | # TODO(Marcelo): We should deprecate this way.
|
@@ -414,22 +415,19 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
414 | 415 | elif provider == 'groq':
|
415 | 416 | from .groq import GroqModel
|
416 | 417 |
|
417 |
| - # TODO(Marcelo): Missing provider API. |
418 |
| - return GroqModel(model_name) |
| 418 | + return GroqModel(model_name, provider=provider) |
419 | 419 | elif provider == 'mistral':
|
420 | 420 | from .mistral import MistralModel
|
421 | 421 |
|
422 |
| - # TODO(Marcelo): Missing provider API. |
423 |
| - return MistralModel(model_name) |
| 422 | + return MistralModel(model_name, provider=provider) |
424 | 423 | elif provider == 'anthropic':
|
425 | 424 | from .anthropic import AnthropicModel
|
426 | 425 |
|
427 |
| - # TODO(Marcelo): Missing provider API. |
428 |
| - return AnthropicModel(model_name) |
| 426 | + return AnthropicModel(model_name, provider=provider) |
429 | 427 | elif provider == 'bedrock':
|
430 | 428 | from .bedrock import BedrockConverseModel
|
431 | 429 |
|
432 |
| - return BedrockConverseModel(model_name) |
| 430 | + return BedrockConverseModel(model_name, provider=provider) |
433 | 431 | else:
|
434 | 432 | raise UserError(f'Unknown model: {model}')
|
435 | 433 |
|
|
0 commit comments