Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions pydantic_settings/sources/providers/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Any,
)

from pydantic import Json, TypeAdapter, ValidationError
from pydantic._internal._utils import deep_update, is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic.fields import FieldInfo
Expand All @@ -17,6 +18,7 @@
from ..base import PydanticBaseEnvSettingsSource
from ..types import EnvNoneType
from ..utils import (
_annotation_contains_types,
_annotation_enum_name_to_val,
_get_model_fields,
_union_is_complex,
Expand Down Expand Up @@ -125,7 +127,7 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
return value
elif value is not None:
# simplest case, field is not complex, we only need to add the value if it was found
return value
return self._coerce_env_val_strict(field, value)

def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
"""
Expand Down Expand Up @@ -256,10 +258,31 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
raise e
if isinstance(env_var, dict):
if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
env_var[last_key] = env_val

env_var[last_key] = self._coerce_env_val_strict(target_field, env_val)
return result

def _coerce_env_val_strict(self, field: FieldInfo | None, value: Any) -> Any:
"""
Coerce environment string values based on field annotation if model config is `strict=True`.

Args:
field: The field.
value: The value to coerce.

Returns:
The coerced value if successful, otherwise the original value.
"""
try:
if self.config.get('strict') and isinstance(value, str) and field is not None:
if value == self.env_parse_none_str:
return value
if not _annotation_contains_types(field.annotation, (Json,), is_instance=True):
return TypeAdapter(field.annotation).validate_python(value)
except ValidationError:
# Allow validation error to be raised at time of instatiation
pass
return value

def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
Expand Down
15 changes: 12 additions & 3 deletions pydantic_settings/sources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,24 @@ def _annotation_contains_types(
types: tuple[Any, ...],
is_include_origin: bool = True,
is_strip_annotated: bool = False,
is_instance: bool = False,
) -> bool:
"""Check if a type annotation contains any of the specified types."""
if is_strip_annotated:
annotation = _strip_annotated(annotation)
if is_include_origin is True and get_origin(annotation) in types:
return True
if is_include_origin is True:
origin = get_origin(annotation)
if origin in types:
return True
if is_instance and any(isinstance(origin, type_) for type_ in types):
return True
for type_ in get_args(annotation):
if _annotation_contains_types(type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated):
if _annotation_contains_types(
type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated, is_instance=is_instance
):
return True
if is_instance and any(isinstance(annotation, type_) for type_ in types):
return True
return annotation in types


Expand Down
38 changes: 38 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3175,3 +3175,41 @@ class Settings(BaseSettings):
f'source to the settings sources via the settings_customise_sources hook.'
)
assert warning.message.args[0] == expected_message


def test_env_strict_coercion(env):
class SubModel(BaseModel):
my_str: str
my_int: int

class Settings(BaseSettings, env_nested_delimiter='__'):
my_str: str
my_int: int
sub_model: SubModel

env.set('MY_STR', '0')
env.set('MY_INT', '0')
env.set('SUB_MODEL__MY_STR', '1')
env.set('SUB_MODEL__MY_INT', '1')
Settings().model_dump() == {
'my_str': '0',
'my_int': 0,
'sub_model': {
'my_str': '1',
'my_int': 1,
},
}

class StrictSettings(BaseSettings, env_nested_delimiter='__', strict=True):
my_str: str
my_int: int
sub_model: SubModel

StrictSettings().model_dump() == {
'my_str': '0',
'my_int': 0,
'sub_model': {
'my_str': '1',
'my_int': 1,
},
}