diff --git a/pydantic_settings/sources/providers/env.py b/pydantic_settings/sources/providers/env.py index 5a350f1d..0c6e257e 100644 --- a/pydantic_settings/sources/providers/env.py +++ b/pydantic_settings/sources/providers/env.py @@ -7,6 +7,7 @@ Any, ) +from pydantic import Json, TypeAdapter, ValidationError from pydantic._internal._utils import deep_update, is_model_class from pydantic.dataclasses import is_pydantic_dataclass from pydantic.fields import FieldInfo @@ -17,6 +18,7 @@ from ..base import PydanticBaseEnvSettingsSource from ..types import EnvNoneType from ..utils import ( + _annotation_contains_types, _annotation_enum_name_to_val, _get_model_fields, _union_is_complex, @@ -125,7 +127,7 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val return value elif value is not None: # simplest case, field is not complex, we only need to add the value if it was found - return value + return self._coerce_env_val_strict(field, value) def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]: """ @@ -256,10 +258,31 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[ raise e if isinstance(env_var, dict): if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}: - env_var[last_key] = env_val - + env_var[last_key] = self._coerce_env_val_strict(target_field, env_val) return result + def _coerce_env_val_strict(self, field: FieldInfo | None, value: Any) -> Any: + """ + Coerce environment string values based on field annotation if model config is `strict=True`. + + Args: + field: The field. + value: The value to coerce. + + Returns: + The coerced value if successful, otherwise the original value. + """ + try: + if self.config.get('strict') and isinstance(value, str) and field is not None: + if value == self.env_parse_none_str: + return value + if not _annotation_contains_types(field.annotation, (Json,), is_instance=True): + return TypeAdapter(field.annotation).validate_python(value) + except ValidationError: + # Allow validation error to be raised at time of instatiation + pass + return value + def __repr__(self) -> str: return ( f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, ' diff --git a/pydantic_settings/sources/utils.py b/pydantic_settings/sources/utils.py index 56bfb3ea..3c004608 100644 --- a/pydantic_settings/sources/utils.py +++ b/pydantic_settings/sources/utils.py @@ -92,15 +92,24 @@ def _annotation_contains_types( types: tuple[Any, ...], is_include_origin: bool = True, is_strip_annotated: bool = False, + is_instance: bool = False, ) -> bool: """Check if a type annotation contains any of the specified types.""" if is_strip_annotated: annotation = _strip_annotated(annotation) - if is_include_origin is True and get_origin(annotation) in types: - return True + if is_include_origin is True: + origin = get_origin(annotation) + if origin in types: + return True + if is_instance and any(isinstance(origin, type_) for type_ in types): + return True for type_ in get_args(annotation): - if _annotation_contains_types(type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated): + if _annotation_contains_types( + type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated, is_instance=is_instance + ): return True + if is_instance and any(isinstance(annotation, type_) for type_ in types): + return True return annotation in types diff --git a/tests/test_settings.py b/tests/test_settings.py index 7949eb38..fd5c0154 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -3200,3 +3200,41 @@ class Settings(BaseSettings): f'source to the settings sources via the settings_customise_sources hook.' ) assert warning.message.args[0] == expected_message + + +def test_env_strict_coercion(env): + class SubModel(BaseModel): + my_str: str + my_int: int + + class Settings(BaseSettings, env_nested_delimiter='__'): + my_str: str + my_int: int + sub_model: SubModel + + env.set('MY_STR', '0') + env.set('MY_INT', '0') + env.set('SUB_MODEL__MY_STR', '1') + env.set('SUB_MODEL__MY_INT', '1') + Settings().model_dump() == { + 'my_str': '0', + 'my_int': 0, + 'sub_model': { + 'my_str': '1', + 'my_int': 1, + }, + } + + class StrictSettings(BaseSettings, env_nested_delimiter='__', strict=True): + my_str: str + my_int: int + sub_model: SubModel + + StrictSettings().model_dump() == { + 'my_str': '0', + 'my_int': 0, + 'sub_model': { + 'my_str': '1', + 'my_int': 1, + }, + }