diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 63520afe..452eea94 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -43,7 +43,7 @@ from pydantic._internal._repr import Representation from pydantic._internal._signature import _field_name_for_signature from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base -from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass +from pydantic._internal._utils import KeyType, deep_update, is_model_class, lenient_issubclass from pydantic.dataclasses import is_pydantic_dataclass from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined @@ -658,6 +658,48 @@ def __repr__(self) -> str: return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})' +def _deep_update( + mapping: Dict[KeyType, Any] | List[Any], *updating_mappings: Dict[KeyType, Any] +) -> Dict[KeyType, Any] | List[Any]: + # if no updating mappings, return the original mapping + if not all(updating_mappings): + return mapping + + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + if isinstance(updated_mapping, list): + # list case + if isinstance(updating_mapping, list): + for i, v in enumerate(updating_mapping): + # if i < len(updated_mapping): + updated_mapping[i] = _deep_update(updated_mapping[i], v) + # else: + # updated_mapping.append(v) + elif isinstance(updating_mapping, dict): + for key, value in updating_mapping.items(): + # index is a stored as a key in the dict + index = int(key) + if len(updated_mapping) < index: + continue + # Add empty dict so we can update it + if len(updated_mapping) == index: + updated_mapping.append({}) + # Update it + if isinstance(value, dict): + updated_mapping[index] = _deep_update(updated_mapping[index], value) + else: + updated_mapping[index] = value + else: + raise NotImplementedError + else: + for k, v in updating_mapping.items(): + if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict): + updated_mapping[k] = _deep_update(updated_mapping[k], v) + else: + updated_mapping[k] = v + return updated_mapping + + class EnvSettingsSource(PydanticBaseEnvSettingsSource): """ Source class for loading settings values from environment variables. @@ -745,8 +787,8 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val if not allow_parse_failure: raise e - if isinstance(value, dict): - return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars)) + if isinstance(value, (dict, list)): + return _deep_update(value, self.explode_env_vars(field_name, field, self.env_vars)) else: return value elif value is not None: @@ -804,6 +846,15 @@ class Cfg(BaseSettings): return None annotation = field.annotation if isinstance(field, FieldInfo) else field + if get_origin(annotation) is list: + try: + # check if key is an integer. If so, it's an index. we fake a field info with the list type + # so future calls can continue to traverse the model and set proper types for leaf nodes + int(key) + if list_type := [*get_args(annotation), None].pop(0): + return FieldInfo(annotation=list_type) + except ValueError: + pass if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes): for type_ in get_args(annotation): type_has_key = EnvSettingsSource.next_field(type_, key, case_sensitive) @@ -822,7 +873,9 @@ class Cfg(BaseSettings): return None - def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]: + def explode_env_vars( + self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None] + ) -> dict[str, Any] | list[Any]: """ Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries. @@ -834,14 +887,15 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[ env_vars: Environment variables. Returns: - A dictionary contains extracted values from nested env values. + A list or a dictionary contains extracted values from nested env values. """ is_dict = lenient_issubclass(get_origin(field.annotation), dict) + is_list = lenient_issubclass(get_origin(field.annotation), list) prefixes = [ f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name) ] - result: dict[str, Any] = {} + result: dict[str, Any] | list[Any] = {} for env_name, env_val in env_vars.items(): if not any(env_name.startswith(prefix) for prefix in prefixes): continue @@ -859,11 +913,11 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[ target_field = self.next_field(target_field, last_key, self.case_sensitive) # check if env_val maps to a complex field and if so, parse the env_val - if (target_field or is_dict) and env_val: + if (target_field or is_dict or is_list) and env_val: if target_field: is_complex, allow_json_failure = self._field_is_complex(target_field) else: - # nested field type is dict + # nested field type is dict or list is_complex, allow_json_failure = True, True if is_complex: try: @@ -875,6 +929,20 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[ if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}: env_var[last_key] = env_val + def _transform_list_based_field(annotation: Annotated, values: dict[str, Any]) -> dict[str, Any] | list[Any]: + assert lenient_issubclass(get_origin(annotation), list) + if lenient_issubclass(get_args(annotation)[0], list): + result = [] + for i in (str(i) for i in range(len(values))): + if i not in values: + raise ValueError(f'Expected entry with index {i} for {field_name}') + result.append(_transform_list_based_field(get_args(annotation)[0], values[i])) + return result + else: + return values + + if is_list: + result = _transform_list_based_field(field.annotation, result) return result def __repr__(self) -> str: diff --git a/tests/test_settings.py b/tests/test_settings.py index d6da3119..ec1a70e6 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -257,6 +257,7 @@ class Cfg(BaseSettings): v0: str v0_union: Union[SubValue, int] top: TopValue + top_collection: List[TopValue] model_config = SettingsConfigDict(env_nested_delimiter='__') @@ -268,6 +269,18 @@ class Cfg(BaseSettings): env.set('v0_union', '0') env.set('top__sub__sub_sub__v6', '6') env.set('top__sub__v4', '4') + + env.set( + 'top_collection', + '[{"v1": "json-1", "v2": "json-2", "sub": { "v5": "xx"}}]', + ) + + env.set('top_collection__0__sub__v5', '5') + env.set('top_collection__0__v2', '2') + env.set('top_collection__0__v3', '3') + env.set('top_collection__0__sub__sub_sub__v6', '6') + env.set('top_collection__0__sub__v4', '4') + cfg = Cfg() assert cfg.model_dump() == { 'v0': '0', @@ -278,6 +291,49 @@ class Cfg(BaseSettings): 'v3': '3', 'sub': {'v4': '4', 'v5': 5, 'sub_sub': {'v6': '6'}}, }, + 'top_collection': [ + { + 'v1': 'json-1', + 'v2': '2', + 'v3': '3', + 'sub': {'v4': '4', 'v5': 5, 'sub_sub': {'v6': '6'}}, + } + ], + } + + +def test_nested_env_delimiter_lists_index(env): + class TopValue(BaseSettings): + v1: Optional[str] = None + v2: int + + class ListCfg(BaseSettings): + top_list: List[int] + top_collection: List[List[TopValue]] + + model_config = SettingsConfigDict(env_nested_delimiter='__') + + env.set('top_list__0', '3') + env.set('top_list__1', '2') + env.set('top_list__2', '1') + env.set('top_list__5', 'out of bounds index') + env.set('top_list', '[1]') + + env.set('top_collection', '[[{"v1": "json-1", "v2": "json-2"}]]') + env.set('top_collection__0__0__v1', 'json-1') + env.set('top_collection__0__1__v2', '6') + env.set('top_collection__0__0__v2', '5') + env.set('top_collection__0__3__v1', 'out of bounds index') + + cfg = ListCfg() + assert cfg.model_dump() == { + 'top_list': [3, 2, 1], + 'top_collection': [ + [ + {'v1': 'json-1', 'v2': 5}, + {'v1': None, 'v2': 6}, + ] + ], }