1- from agno .db .base import AsyncBaseDb
2- from agno .db .mysql import AsyncMySQLDb
3- from agno .db .postgres import AsyncPostgresDb
4- from agno .models .aimlapi import AIMLAPI
5- from agno .models .anthropic import Claude
6- from agno .models .base import Model
7- from agno .models .cerebras import Cerebras , CerebrasOpenAI
8- from agno .models .cohere import Cohere
9- from agno .models .cometapi import CometAPI
10- from agno .models .dashscope import DashScope
11- from agno .models .deepinfra import DeepInfra
12- from agno .models .deepseek import DeepSeek
13- from agno .models .fireworks import Fireworks
14- from agno .models .google import Gemini
15- from agno .models .groq import Groq
16- from agno .models .huggingface import HuggingFace
17- from agno .models .langdb import LangDB
18- from agno .models .litellm import LiteLLM , LiteLLMOpenAI
19- from agno .models .llama_cpp import LlamaCpp
20- from agno .models .lmstudio import LMStudio
21- from agno .models .meta import Llama
22- from agno .models .mistral import MistralChat
23- from agno .models .n1n import N1N
24- from agno .models .nebius import Nebius
25- from agno .models .nexus import Nexus
26- from agno .models .nvidia import Nvidia
27- from agno .models .ollama import Ollama
28- from agno .models .openai import OpenAIChat
29- from agno .models .openai .responses import OpenAIResponses
30- from agno .models .openrouter import OpenRouter
31- from agno .models .perplexity import Perplexity
32- from agno .models .portkey import Portkey
33- from agno .models .requesty import Requesty
34- from agno .models .sambanova import Sambanova
35- from agno .models .siliconflow import Siliconflow
36- from agno .models .together import Together
37- from agno .models .vercel import V0
38- from agno .models .vllm import VLLM
39- from agno .models .xai import xAI
1+ from importlib import import_module
2+ from typing import TYPE_CHECKING
403
414from config .database import async_engine
425from config .env import DataBaseConfig
436
44- provider_model_map : dict [str , type [Model ]] = {
45- 'AIMLAPI' : AIMLAPI ,
46- 'Anthropic' : Claude ,
47- 'Cerebras' : Cerebras ,
48- 'CerebrasOpenAI' : CerebrasOpenAI ,
49- 'Cohere' : Cohere ,
50- 'CometAPI' : CometAPI ,
51- 'DashScope' : DashScope ,
52- 'DeepInfra' : DeepInfra ,
53- 'DeepSeek' : DeepSeek ,
54- 'Fireworks' : Fireworks ,
55- 'Google' : Gemini ,
56- 'Groq' : Groq ,
57- 'HuggingFace' : HuggingFace ,
58- 'LangDB' : LangDB ,
59- 'LiteLLM' : LiteLLM ,
60- 'LiteLLMOpenAI' : LiteLLMOpenAI ,
61- 'LlamaCpp' : LlamaCpp ,
62- 'LMStudio' : LMStudio ,
63- 'Meta' : Llama ,
64- 'Mistral' : MistralChat ,
65- 'N1N' : N1N ,
66- 'Nebius' : Nebius ,
67- 'Nexus' : Nexus ,
68- 'Nvidia' : Nvidia ,
69- 'Ollama' : Ollama ,
70- 'OpenAI' : OpenAIChat ,
71- 'OpenAIResponses' : OpenAIResponses ,
72- 'OpenRouter' : OpenRouter ,
73- 'Perplexity' : Perplexity ,
74- 'Portkey' : Portkey ,
75- 'Requesty' : Requesty ,
76- 'Sambanova' : Sambanova ,
77- 'SiliconFlow' : Siliconflow ,
78- 'Together' : Together ,
79- 'Vercel' : V0 ,
80- 'VLLM' : VLLM ,
81- 'xAI' : xAI ,
82- }
7+ if TYPE_CHECKING :
8+ from agno .db .base import AsyncBaseDb
9+ from agno .models .base import Model
8310
11+ # 提供商名称 -> (模块路径, 类名) 的映射,延迟导入避免启动时加载所有AI SDK
12+ _PROVIDER_REGISTRY : dict [str , tuple [str , str ]] = {
13+ 'AIMLAPI' : ('agno.models.aimlapi' , 'AIMLAPI' ),
14+ 'Anthropic' : ('agno.models.anthropic' , 'Claude' ),
15+ 'Cerebras' : ('agno.models.cerebras' , 'Cerebras' ),
16+ 'CerebrasOpenAI' : ('agno.models.cerebras' , 'CerebrasOpenAI' ),
17+ 'Cohere' : ('agno.models.cohere' , 'Cohere' ),
18+ 'CometAPI' : ('agno.models.cometapi' , 'CometAPI' ),
19+ 'DashScope' : ('agno.models.dashscope' , 'DashScope' ),
20+ 'DeepInfra' : ('agno.models.deepinfra' , 'DeepInfra' ),
21+ 'DeepSeek' : ('agno.models.deepseek' , 'DeepSeek' ),
22+ 'Fireworks' : ('agno.models.fireworks' , 'Fireworks' ),
23+ 'Google' : ('agno.models.google' , 'Gemini' ),
24+ 'Groq' : ('agno.models.groq' , 'Groq' ),
25+ 'HuggingFace' : ('agno.models.huggingface' , 'HuggingFace' ),
26+ 'LangDB' : ('agno.models.langdb' , 'LangDB' ),
27+ 'LiteLLM' : ('agno.models.litellm' , 'LiteLLM' ),
28+ 'LiteLLMOpenAI' : ('agno.models.litellm' , 'LiteLLMOpenAI' ),
29+ 'LlamaCpp' : ('agno.models.llama_cpp' , 'LlamaCpp' ),
30+ 'LMStudio' : ('agno.models.lmstudio' , 'LMStudio' ),
31+ 'Meta' : ('agno.models.meta' , 'Llama' ),
32+ 'Mistral' : ('agno.models.mistral' , 'MistralChat' ),
33+ 'N1N' : ('agno.models.n1n' , 'N1N' ),
34+ 'Nebius' : ('agno.models.nebius' , 'Nebius' ),
35+ 'Nexus' : ('agno.models.nexus' , 'Nexus' ),
36+ 'Nvidia' : ('agno.models.nvidia' , 'Nvidia' ),
37+ 'Ollama' : ('agno.models.ollama' , 'Ollama' ),
38+ 'OpenAI' : ('agno.models.openai' , 'OpenAIChat' ),
39+ 'OpenAIResponses' : ('agno.models.openai.responses' , 'OpenAIResponses' ),
40+ 'OpenRouter' : ('agno.models.openrouter' , 'OpenRouter' ),
41+ 'Perplexity' : ('agno.models.perplexity' , 'Perplexity' ),
42+ 'Portkey' : ('agno.models.portkey' , 'Portkey' ),
43+ 'Requesty' : ('agno.models.requesty' , 'Requesty' ),
44+ 'Sambanova' : ('agno.models.sambanova' , 'Sambanova' ),
45+ 'SiliconFlow' : ('agno.models.siliconflow' , 'Siliconflow' ),
46+ 'Together' : ('agno.models.together' , 'Together' ),
47+ 'Vercel' : ('agno.models.vercel' , 'V0' ),
48+ 'VLLM' : ('agno.models.vllm' , 'VLLM' ),
49+ 'xAI' : ('agno.models.xai' , 'xAI' ),
50+ }
8451
85- storage_engine_map : dict [str , type [AsyncBaseDb ]] = {
86- 'mysql' : AsyncMySQLDb ,
87- 'postgresql' : AsyncPostgresDb ,
52+ # 存储引擎名称 -> (模块路径, 类名) 的映射
53+ _STORAGE_ENGINE_REGISTRY : dict [str , tuple [str , str ]] = {
54+ 'mysql' : ('agno.db.mysql' , 'AsyncMySQLDb' ),
55+ 'postgresql' : ('agno.db.postgres' , 'AsyncPostgresDb' ),
8856}
8957
58+ # 已加载的提供商类缓存,避免重复import_module
59+ _provider_class_cache : dict [str , 'type[Model]' ] = {}
60+ _storage_class_cache : dict [str , 'type[AsyncBaseDb]' ] = {}
61+
62+
63+ def _resolve_provider_class (provider : str ) -> 'type[Model] | None' :
64+ """
65+ 按需加载并缓存提供商模型类
66+
67+ :param provider: 提供商名称
68+ :return: 模型类,未找到返回None
69+ """
70+ if provider in _provider_class_cache :
71+ return _provider_class_cache [provider ]
72+ entry = _PROVIDER_REGISTRY .get (provider )
73+ if entry is None :
74+ return None
75+ module_path , class_name = entry
76+ cls = getattr (import_module (module_path ), class_name )
77+ _provider_class_cache [provider ] = cls
78+ return cls
79+
80+
81+ def _resolve_storage_class (db_type : str ) -> 'type[AsyncBaseDb]' :
82+ """
83+ 按需加载并缓存存储引擎类
84+
85+ :param db_type: 数据库类型
86+ :return: 存储引擎类
87+ """
88+ if db_type in _storage_class_cache :
89+ return _storage_class_cache [db_type ]
90+ entry = _STORAGE_ENGINE_REGISTRY .get (db_type )
91+ if entry is None :
92+ # 默认使用MySQL
93+ entry = _STORAGE_ENGINE_REGISTRY ['mysql' ]
94+ module_path , class_name = entry
95+ cls = getattr (import_module (module_path ), class_name )
96+ _storage_class_cache [db_type ] = cls
97+ return cls
98+
9099
91100class AiUtil :
92101 """
93102 AI工具类
94103 """
95104
96105 @classmethod
97- def get_storage_engine (cls ) -> AsyncBaseDb :
106+ def get_storage_engine (cls ) -> ' AsyncBaseDb' :
98107 """
99108 获取存储引擎实例
100109
101110 :return: 存储引擎实例
102111 """
103- storage_engine_class = storage_engine_map . get (DataBaseConfig .db_type , AsyncMySQLDb )
112+ storage_engine_class = _resolve_storage_class (DataBaseConfig .db_type )
104113
105114 return storage_engine_class (
106115 db_engine = async_engine ,
@@ -128,7 +137,7 @@ def get_model_from_factory(
128137 temperature : float | None = None ,
129138 max_tokens : int | None = None ,
130139 ** kwargs ,
131- ) -> Model :
140+ ) -> ' Model' :
132141 """
133142 从工厂获取模型实例
134143
@@ -155,6 +164,9 @@ def get_model_from_factory(
155164 params ['host' ] = base_url
156165 if provider == 'DashScope' and not base_url :
157166 params ['base_url' ] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
158- model_class = provider_model_map .get (provider , OpenAIChat )
167+ model_class = _resolve_provider_class (provider )
168+ if model_class is None :
169+ # 未知提供商,回退到OpenAI
170+ model_class = _resolve_provider_class ('OpenAI' )
159171
160172 return model_class (** params )
0 commit comments