Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 63 additions & 35 deletions nonebot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class SettingsError(ValueError): ...


class BaseSettingsSource(abc.ABC):
def __init__(self, settings_cls: type["BaseSettings"]) -> None:
def __init__(self, settings_cls: type[BaseModel]) -> None:
self.settings_cls = settings_cls

@property
Expand All @@ -67,7 +67,7 @@ class InitSettingsSource(BaseSettingsSource):
__slots__ = ("init_kwargs",)

def __init__(
self, settings_cls: type["BaseSettings"], init_kwargs: dict[str, Any]
self, settings_cls: type[BaseModel], init_kwargs: dict[str, Any]
) -> None:
self.init_kwargs = init_kwargs
super().__init__(settings_cls)
Expand All @@ -82,33 +82,17 @@ def __repr__(self) -> str:
class DotEnvSettingsSource(BaseSettingsSource):
def __init__(
self,
settings_cls: type["BaseSettings"],
env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL,
env_file_encoding: Optional[str] = None,
case_sensitive: Optional[bool] = None,
settings_cls: type[BaseModel],
env_file: Optional[DOTENV_TYPE],
env_file_encoding: str,
case_sensitive: Optional[bool] = False,
env_nested_delimiter: Optional[str] = None,
) -> None:
super().__init__(settings_cls)
self.env_file = (
env_file
if env_file is not ENV_FILE_SENTINEL
else self.config.get("env_file", (".env",))
)
self.env_file_encoding = (
env_file_encoding
if env_file_encoding is not None
else self.config.get("env_file_encoding", "utf-8")
)
self.case_sensitive = (
case_sensitive
if case_sensitive is not None
else self.config.get("case_sensitive", False)
)
self.env_nested_delimiter = (
env_nested_delimiter
if env_nested_delimiter is not None
else self.config.get("env_nested_delimiter", None)
)
self.env_file = env_file
self.env_file_encoding = env_file_encoding
self.case_sensitive = case_sensitive
self.env_nested_delimiter = env_nested_delimiter

def _apply_case_sensitive(self, var_name: str) -> str:
return var_name if self.case_sensitive else var_name.lower()
Expand Down Expand Up @@ -212,12 +196,33 @@ def __call__(self) -> dict[str, Any]:
for field in model_fields(self.settings_cls):
field_name = field.name
env_name = self._apply_case_sensitive(field_name)
alias_name = field.field_info.alias
alias_env_name = (
None if alias_name is None else self._apply_case_sensitive(alias_name)
)

# pydantic use alias name to validate if exist
if alias_name is not None:
field_name = alias_name

# try get values from env vars
env_val = env_vars.get(env_name, PydanticUndefined)
alias_env_val = (
PydanticUndefined
if alias_env_name is None
else env_vars.get(alias_env_name, PydanticUndefined)
)
# alias env value has higher priority
env_val = (
env_val
if isinstance(alias_env_val, PydanticUndefinedType)
else alias_env_val
)
# delete from file vars when used
if env_name in env_file_vars:
del env_file_vars[env_name]
if alias_env_name is not None and alias_env_name in env_file_vars:
del env_file_vars[alias_env_name]

is_complex, allow_parse_failure = self._field_is_complex(field)
if is_complex:
Expand Down Expand Up @@ -331,25 +336,48 @@ def __init__(
_env_nested_delimiter: Optional[str] = None,
**values: Any,
) -> None:
settings_config = model_config(__settings_self__.__class__)
env_file = (
_env_file
if _env_file is not ENV_FILE_SENTINEL
else settings_config.get("env_file", (".env",))
)
env_file_encoding = (
_env_file_encoding
if _env_file_encoding is not None
else settings_config.get("env_file_encoding", "utf-8")
)
env_nested_delimiter = (
_env_nested_delimiter
if _env_nested_delimiter is not None
else settings_config.get("env_nested_delimiter", None)
)

super().__init__(
**__settings_self__._settings_build_values(
__settings_self__.__class__,
values,
env_file=_env_file,
env_file_encoding=_env_file_encoding,
env_nested_delimiter=_env_nested_delimiter,
env_file=env_file,
env_file_encoding=env_file_encoding,
env_nested_delimiter=env_nested_delimiter,
)
)

__settings_self__._env_file = env_file
__settings_self__._env_file_encoding = env_file_encoding
__settings_self__._env_nested_delimiter = env_nested_delimiter

@staticmethod
def _settings_build_values(
self,
settings_cls: type[BaseModel],
init_kwargs: dict[str, Any],
env_file: Optional[DOTENV_TYPE] = None,
env_file_encoding: Optional[str] = None,
env_nested_delimiter: Optional[str] = None,
env_file: Optional[DOTENV_TYPE],
env_file_encoding: str,
env_nested_delimiter: Optional[str],
) -> dict[str, Any]:
init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs)
init_settings = InitSettingsSource(settings_cls, init_kwargs=init_kwargs)
env_settings = DotEnvSettingsSource(
self.__class__,
settings_cls,
env_file=env_file,
env_file_encoding=env_file_encoding,
env_nested_delimiter=env_nested_delimiter,
Expand Down
13 changes: 12 additions & 1 deletion nonebot/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

from nonebot import get_driver
from nonebot.compat import model_dump, type_validate_python
from nonebot.config import BaseSettings

C = TypeVar("C", bound=BaseModel)

Expand Down Expand Up @@ -172,7 +173,17 @@ def get_available_plugin_names() -> set[str]:

def get_plugin_config(config: type[C]) -> C:
"""从全局配置获取当前插件需要的配置项。"""
return type_validate_python(config, model_dump(get_driver().config))
global_config = get_driver().config
return type_validate_python(
config,
BaseSettings._settings_build_values(
config,
model_dump(global_config),
env_file=global_config._env_file,
env_file_encoding=global_config._env_file_encoding,
env_nested_delimiter=global_config._env_nested_delimiter,
),
)


from .load import inherit_supported_adapters as inherit_supported_adapters
Expand Down
1 change: 1 addition & 0 deletions tests/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ NESTED__C__C=3
NESTED__COMPLEX=[1, 2, 3]
NESTED_INNER__A=1
NESTED_INNER__B=2
ALIAS_SIMPLE=aliased_simple
OTHER_SIMPLE=simple
OTHER_NESTED={"a": 1}
OTHER_NESTED__B=2
Expand Down
3 changes: 3 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Config( # pyright: ignore[reportIncompatibleVariableOverride]
complex_union: Union[int, list[int]] = 1
nested: Simple = Simple()
nested_inner: Simple = Simple()
aliased_simple: str = Field(default="", alias="alias_simple")


class ExampleWithoutDelimiter(Example):
Expand Down Expand Up @@ -85,6 +86,8 @@ def test_config_with_env():
with pytest.raises(AttributeError):
config.nested_inner__b

assert config.aliased_simple == "aliased_simple"

assert config.common_config == "common"

assert config.other_simple == "simple"
Expand Down
27 changes: 26 additions & 1 deletion tests/test_plugin/test_get.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
import pytest

import nonebot
from nonebot.plugin import PluginManager, _managers
Expand Down Expand Up @@ -67,3 +68,27 @@ class Config(BaseModel):
config = nonebot.get_plugin_config(Config)
assert isinstance(config, Config)
assert config.plugin_config == 1


def test_get_plugin_config_with_env(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("PLUGIN_CONFIG_ONE", "no_dummy_val")
monkeypatch.setenv("PLUGIN_SUB_CONFIG__TWO", "two")
monkeypatch.setenv("PLUGIN_CFG_THREE", "33")
monkeypatch.setenv("CONFIG_FROM_INIT", "impossible")

class SubConfig(BaseModel):
two: str = "dummy_val"

class Config(BaseModel):
plugin_config: int
plugin_config_one: str = "dummy_val"
plugin_sub_config: SubConfig = Field(default_factory=SubConfig)
plugin_config_three: int = Field(default=3, alias="plugin_cfg_three")
config_from_init: str = "dummy_val"

config = nonebot.get_plugin_config(Config)
assert config.plugin_config == 1
assert config.plugin_config_one == "no_dummy_val"
assert config.plugin_sub_config.two == "two"
assert config.plugin_config_three == 33
assert config.config_from_init == "init"
12 changes: 9 additions & 3 deletions website/docs/appendices/config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export CUSTOM_CONFIG='config in environment variables'
那最终 NoneBot 所读取的内容为环境变量中的内容,即 `config in environment variables`。

:::caution 注意
NoneBot 不会自发读取未被定义的配置项的环境变量,如果需要读取某一环境变量需要在 dotenv 配置文件中进行声明
如果一个环境变量既不是 NoneBot 的[**内置配置项**](#内置配置项),也不是任何插件所定义的[**插件配置**](#插件配置),那么 NoneBot 不会自发读取该环境变量,需要在 dotenv 配置文件中先行声明
:::

### dotenv 配置文件
Expand Down Expand Up @@ -242,11 +242,17 @@ weather = on_command(

这种方式可以简洁、高效地读取配置项,同时也可以设置默认值或者在运行时对配置项进行合法性检查,防止由于配置项导致的插件出错等情况出现。

:::tip 提示
:::tip 可配置的事件响应优先级
发布插件应该为自身的事件响应器提供可配置的优先级,以便插件使用者可以自定义多个插件间的响应顺序。
:::

由于插件配置项是从全局配置中读取的,通常我们需要在配置项名称前面添加前缀名,以防止配置项冲突。例如在上方的示例中,我们就添加了配置项前缀 `weather_`。但是这样会导致在使用配置项时过长的变量名,因此我们可以使用 `pydantic` 的 `alias` 或者通过配置 scope 来简化配置项名称。这里我们以 scope 配置为例:
:::tip 插件配置获取逻辑
无论是否在 dotenv 文件中声明了插件配置项,使用 `get_plugin_config` 获取插件配置模型中定义的配置项时都遵循[**配置项的加载**](#配置项的加载)一节中的优先级顺序进行读取。
:::

### 避免插件配置名称冲突

由于插件配置项是从全局配置和环境变量中读取的,通常我们需要在配置项名称前面添加前缀名,以防止配置项冲突。例如在上方的示例中,我们就添加了配置项前缀 `weather_`。但是这样会导致使用配置项时变量名过长,此时我们可以使用 `pydantic` 的 `alias` 或者通过配置 scope 来简化配置项名称。这里我们以 scope 配置为例:

```python title=weather/config.py
from pydantic import BaseModel
Expand Down