Skip to content

Commit 8c8e865

Browse files
committed
✨ 允许插件从环境变量中读取配置项而不需要在envfile中声明
1 parent 581ba52 commit 8c8e865

File tree

3 files changed

+75
-7
lines changed

3 files changed

+75
-7
lines changed

nonebot/config.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414
"""
1515

1616
import abc
17-
from collections.abc import Mapping
17+
from collections.abc import Iterable, Mapping
1818
from datetime import timedelta
19+
from functools import lru_cache
1920
from ipaddress import IPv4Address
2021
import json
2122
import os
2223
from pathlib import Path
2324
from typing import TYPE_CHECKING, Any, Optional, Union
24-
from typing_extensions import TypeAlias, get_args, get_origin
25+
from typing_extensions import TypeAlias, get_args, get_origin, override
2526

2627
from dotenv import dotenv_values
2728
from pydantic import BaseModel, Field
@@ -79,7 +80,7 @@ def __repr__(self) -> str:
7980
return f"InitSettingsSource(init_kwargs={self.init_kwargs!r})"
8081

8182

82-
class DotEnvSettingsSource(BaseSettingsSource):
83+
class EnvSettingsSource(BaseSettingsSource):
8384
def __init__(
8485
self,
8586
settings_cls: type["BaseSettings"],
@@ -110,6 +111,11 @@ def __init__(
110111
else self.config.get("env_nested_delimiter", None)
111112
)
112113

114+
@abc.abstractmethod
115+
def get_setting_fields(self) -> Iterable[ModelField]:
116+
"""获取配置类的字段信息"""
117+
raise NotImplementedError
118+
113119
def _apply_case_sensitive(self, var_name: str) -> str:
114120
return var_name if self.case_sensitive else var_name.lower()
115121

@@ -133,6 +139,7 @@ def _read_env_file(self, file_path: Path) -> dict[str, Optional[str]]:
133139
file_vars = dotenv_values(file_path, encoding=self.env_file_encoding)
134140
return self._parse_env_vars(file_vars)
135141

142+
@lru_cache
136143
def _read_env_files(self) -> dict[str, Optional[str]]:
137144
env_files = self.env_file
138145
if env_files is None:
@@ -209,8 +216,8 @@ def __call__(self) -> dict[str, Any]:
209216
env_file_vars = self._read_env_files()
210217
env_vars = {**env_file_vars, **env_vars}
211218

212-
for field in model_fields(self.settings_cls):
213-
field_name = field.name
219+
for field in self.get_setting_fields():
220+
field_name = self._parse_field_name(field)
214221
env_name = self._apply_case_sensitive(field_name)
215222

216223
# try get values from env vars
@@ -283,6 +290,52 @@ def __call__(self) -> dict[str, Any]:
283290

284291
return d
285292

293+
def _parse_field_name(self, field: ModelField) -> str:
294+
return field.field_info.alias or field.name
295+
296+
297+
class DotEnvSettingsSource(EnvSettingsSource):
298+
def __init__(
299+
self,
300+
settings_cls: type["BaseSettings"],
301+
env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL,
302+
env_file_encoding: Optional[str] = None,
303+
case_sensitive: Optional[bool] = None,
304+
env_nested_delimiter: Optional[str] = None,
305+
) -> None:
306+
super().__init__(
307+
settings_cls,
308+
env_file,
309+
env_file_encoding,
310+
case_sensitive,
311+
env_nested_delimiter,
312+
)
313+
314+
@override
315+
def get_setting_fields(self) -> Iterable[ModelField]:
316+
return model_fields(self.settings_cls)
317+
318+
319+
class PluginEnvSettingsSource(EnvSettingsSource):
320+
def __init__(
321+
self,
322+
config_cls: type[BaseModel],
323+
driver_config: "Config",
324+
) -> None:
325+
setting_config: "SettingsConfig" = model_config(driver_config.__class__)
326+
super().__init__(
327+
BaseSettings,
328+
env_file=setting_config.get("env_file", None),
329+
env_file_encoding=setting_config.get("env_file_encoding", "utf-8"),
330+
case_sensitive=setting_config.get("case_sensitive", False),
331+
env_nested_delimiter=setting_config.get("env_nested_delimiter", None),
332+
)
333+
self.config_cls = config_cls
334+
335+
@override
336+
def get_setting_fields(self) -> Iterable[ModelField]:
337+
return model_fields(self.config_cls)
338+
286339

287340
if PYDANTIC_V2: # pragma: pydantic-v2
288341

nonebot/plugin/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
from pydantic import BaseModel
4747

4848
from nonebot import get_driver
49-
from nonebot.compat import model_dump, type_validate_python
49+
from nonebot.compat import type_validate_python
50+
from nonebot.config import PluginEnvSettingsSource
5051

5152
C = TypeVar("C", bound=BaseModel)
5253

@@ -172,7 +173,8 @@ def get_available_plugin_names() -> set[str]:
172173

173174
def get_plugin_config(config: type[C]) -> C:
174175
"""从全局配置获取当前插件需要的配置项。"""
175-
return type_validate_python(config, model_dump(get_driver().config))
176+
env_settings = PluginEnvSettingsSource(config, get_driver().config)
177+
return type_validate_python(config, env_settings())
176178

177179

178180
from .load import inherit_supported_adapters as inherit_supported_adapters

tests/test_plugin/test_load.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
import nonebot
11+
from nonebot.compat import model_dump
1112
from nonebot.plugin import (
1213
Plugin,
1314
PluginManager,
@@ -184,6 +185,18 @@ def test_plugin_metadata():
184185
assert plugin.metadata.get_supported_adapters() == {FakeAdapter}
185186

186187

188+
def test_plugin_load_env_config(monkeypatch: pytest.MonkeyPatch):
189+
no_dummy_val = "no_dummy_val"
190+
monkeypatch.setenv("CUSTOM", no_dummy_val)
191+
from plugins.metadata import Config
192+
193+
global_config = nonebot.get_driver().config
194+
assert "custom" not in model_dump(global_config)
195+
196+
config = nonebot.get_plugin_config(Config)
197+
assert config.custom == no_dummy_val
198+
199+
187200
def test_inherit_supported_adapters_not_found():
188201
with pytest.raises(RuntimeError):
189202
inherit_supported_adapters("some_plugin_not_exist")

0 commit comments

Comments
 (0)