4343)
4444from ..output import OutputMode
4545from ..profiles import DEFAULT_PROFILE , ModelProfile , ModelProfileSpec
46+ from ..providers import infer_provider
4647from ..settings import ModelSettings , merge_model_settings
4748from ..tools import ToolDefinition
4849from ..usage import RequestUsage
@@ -637,41 +638,39 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
637638 return TestModel ()
638639
639640 try :
640- provider , model_name = model .split (':' , maxsplit = 1 )
641+ provider_name , model_name = model .split (':' , maxsplit = 1 )
641642 except ValueError :
642- provider = None
643+ provider_name = None
643644 model_name = model
644645 if model_name .startswith (('gpt' , 'o1' , 'o3' )):
645- provider = 'openai'
646+ provider_name = 'openai'
646647 elif model_name .startswith ('claude' ):
647- provider = 'anthropic'
648+ provider_name = 'anthropic'
648649 elif model_name .startswith ('gemini' ):
649- provider = 'google-gla'
650+ provider_name = 'google-gla'
650651
651- if provider is not None :
652+ if provider_name is not None :
652653 warnings .warn (
653- f"Specifying a model name without a provider prefix is deprecated. Instead of { model_name !r} , use '{ provider } :{ model_name } '." ,
654+ f"Specifying a model name without a provider prefix is deprecated. Instead of { model_name !r} , use '{ provider_name } :{ model_name } '." ,
654655 DeprecationWarning ,
655656 )
656657 else :
657658 raise UserError (f'Unknown model: { model } ' )
658659
659- if provider == 'vertexai' : # pragma: no cover
660+ if provider_name == 'vertexai' : # pragma: no cover
660661 warnings .warn (
661662 "The 'vertexai' provider name is deprecated. Use 'google-vertex' instead." ,
662663 DeprecationWarning ,
663664 )
664- provider = 'google-vertex'
665+ provider_name = 'google-vertex'
665666
666- if provider == 'gateway' :
667- from ..providers .gateway import infer_model as infer_model_from_gateway
667+ provider = infer_provider (provider_name )
668668
669- return infer_model_from_gateway (model_name )
670- elif provider == 'cohere' :
671- from .cohere import CohereModel
672-
673- return CohereModel (model_name , provider = provider )
674- elif provider in (
669+ model_kind = provider_name
670+ if model_kind .startswith ('gateway/' ):
671+ model_kind = provider_name .removeprefix ('gateway/' )
672+ if model_kind in (
673+ 'openai' ,
675674 'azure' ,
676675 'deepseek' ,
677676 'cerebras' ,
@@ -681,43 +680,50 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
681680 'heroku' ,
682681 'moonshotai' ,
683682 'ollama' ,
684- 'openai' ,
685- 'openai-chat' ,
686683 'openrouter' ,
687684 'together' ,
688685 'vercel' ,
689686 'litellm' ,
690687 'nebius' ,
691688 'ovhcloud' ,
692689 ):
690+ model_kind = 'openai-chat'
691+ elif model_kind in ('google-gla' , 'google-vertex' ):
692+ model_kind = 'google'
693+
694+ if model_kind == 'openai-chat' :
693695 from .openai import OpenAIChatModel
694696
695697 return OpenAIChatModel (model_name , provider = provider )
696- elif provider == 'openai-responses' :
698+ elif model_kind == 'openai-responses' :
697699 from .openai import OpenAIResponsesModel
698700
699- return OpenAIResponsesModel (model_name , provider = 'openai' )
700- elif provider in ( 'google-gla' , 'google-vertex' ) :
701+ return OpenAIResponsesModel (model_name , provider = provider )
702+ elif model_kind == 'google' :
701703 from .google import GoogleModel
702704
703705 return GoogleModel (model_name , provider = provider )
704- elif provider == 'groq' :
706+ elif model_kind == 'groq' :
705707 from .groq import GroqModel
706708
707709 return GroqModel (model_name , provider = provider )
708- elif provider == 'mistral' :
710+ elif model_kind == 'cohere' :
711+ from .cohere import CohereModel
712+
713+ return CohereModel (model_name , provider = provider )
714+ elif model_kind == 'mistral' :
709715 from .mistral import MistralModel
710716
711717 return MistralModel (model_name , provider = provider )
712- elif provider == 'anthropic' :
718+ elif model_kind == 'anthropic' :
713719 from .anthropic import AnthropicModel
714720
715721 return AnthropicModel (model_name , provider = provider )
716- elif provider == 'bedrock' :
722+ elif model_kind == 'bedrock' :
717723 from .bedrock import BedrockConverseModel
718724
719725 return BedrockConverseModel (model_name , provider = provider )
720- elif provider == 'huggingface' :
726+ elif model_kind == 'huggingface' :
721727 from .huggingface import HuggingFaceModel
722728
723729 return HuggingFaceModel (model_name , provider = provider )
0 commit comments