1414"""
1515
1616import abc
17- from collections .abc import Mapping
17+ from collections .abc import Iterable , Mapping
1818from datetime import timedelta
19+ from functools import lru_cache
1920from ipaddress import IPv4Address
2021import json
2122import os
2223from pathlib import Path
2324from typing import TYPE_CHECKING , Any , Optional , Union
24- from typing_extensions import TypeAlias , get_args , get_origin
25+ from typing_extensions import TypeAlias , get_args , get_origin , override
2526
2627from dotenv import dotenv_values
2728from pydantic import BaseModel , Field
@@ -79,7 +80,7 @@ def __repr__(self) -> str:
7980 return f"InitSettingsSource(init_kwargs={ self .init_kwargs !r} )"
8081
8182
82- class DotEnvSettingsSource (BaseSettingsSource ):
83+ class EnvSettingsSource (BaseSettingsSource ):
8384 def __init__ (
8485 self ,
8586 settings_cls : type ["BaseSettings" ],
@@ -110,6 +111,11 @@ def __init__(
110111 else self .config .get ("env_nested_delimiter" , None )
111112 )
112113
114+ @abc .abstractmethod
115+ def get_setting_fields (self ) -> Iterable [ModelField ]:
116+ """获取配置类的字段信息"""
117+ raise NotImplementedError
118+
113119 def _apply_case_sensitive (self , var_name : str ) -> str :
114120 return var_name if self .case_sensitive else var_name .lower ()
115121
@@ -133,6 +139,7 @@ def _read_env_file(self, file_path: Path) -> dict[str, Optional[str]]:
133139 file_vars = dotenv_values (file_path , encoding = self .env_file_encoding )
134140 return self ._parse_env_vars (file_vars )
135141
142+ @lru_cache
136143 def _read_env_files (self ) -> dict [str , Optional [str ]]:
137144 env_files = self .env_file
138145 if env_files is None :
@@ -209,8 +216,8 @@ def __call__(self) -> dict[str, Any]:
209216 env_file_vars = self ._read_env_files ()
210217 env_vars = {** env_file_vars , ** env_vars }
211218
212- for field in model_fields ( self .settings_cls ):
213- field_name = field . name
219+ for field in self .get_setting_fields ( ):
220+ field_name = self . _parse_field_name ( field )
214221 env_name = self ._apply_case_sensitive (field_name )
215222
216223 # try get values from env vars
@@ -283,6 +290,52 @@ def __call__(self) -> dict[str, Any]:
283290
284291 return d
285292
293+ def _parse_field_name (self , field : ModelField ) -> str :
294+ return field .field_info .alias or field .name
295+
296+
297+ class DotEnvSettingsSource (EnvSettingsSource ):
298+ def __init__ (
299+ self ,
300+ settings_cls : type ["BaseSettings" ],
301+ env_file : Optional [DOTENV_TYPE ] = ENV_FILE_SENTINEL ,
302+ env_file_encoding : Optional [str ] = None ,
303+ case_sensitive : Optional [bool ] = None ,
304+ env_nested_delimiter : Optional [str ] = None ,
305+ ) -> None :
306+ super ().__init__ (
307+ settings_cls ,
308+ env_file ,
309+ env_file_encoding ,
310+ case_sensitive ,
311+ env_nested_delimiter ,
312+ )
313+
314+ @override
315+ def get_setting_fields (self ) -> Iterable [ModelField ]:
316+ return model_fields (self .settings_cls )
317+
318+
319+ class PluginEnvSettingsSource (EnvSettingsSource ):
320+ def __init__ (
321+ self ,
322+ config_cls : type [BaseModel ],
323+ driver_config : "Config" ,
324+ ) -> None :
325+ setting_config : "SettingsConfig" = model_config (driver_config .__class__ )
326+ super ().__init__ (
327+ BaseSettings ,
328+ env_file = setting_config .get ("env_file" , None ),
329+ env_file_encoding = setting_config .get ("env_file_encoding" , "utf-8" ),
330+ case_sensitive = setting_config .get ("case_sensitive" , False ),
331+ env_nested_delimiter = setting_config .get ("env_nested_delimiter" , None ),
332+ )
333+ self .config_cls = config_cls
334+
335+ @override
336+ def get_setting_fields (self ) -> Iterable [ModelField ]:
337+ return model_fields (self .config_cls )
338+
286339
287340if PYDANTIC_V2 : # pragma: pydantic-v2
288341
0 commit comments