Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 76 additions & 8 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from pydantic._internal._repr import Representation
from pydantic._internal._signature import _field_name_for_signature
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
from pydantic._internal._utils import KeyType, deep_update, is_model_class, lenient_issubclass
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
Expand Down Expand Up @@ -658,6 +658,48 @@ def __repr__(self) -> str:
return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'


def _deep_update(
mapping: Dict[KeyType, Any] | List[Any], *updating_mappings: Dict[KeyType, Any]
) -> Dict[KeyType, Any] | List[Any]:
# if no updating mappings, return the original mapping
if not all(updating_mappings):
return mapping

updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
if isinstance(updated_mapping, list):
# list case
if isinstance(updating_mapping, list):
for i, v in enumerate(updating_mapping):
# if i < len(updated_mapping):
updated_mapping[i] = _deep_update(updated_mapping[i], v)
# else:
# updated_mapping.append(v)
elif isinstance(updating_mapping, dict):
for key, value in updating_mapping.items():
# index is a stored as a key in the dict
index = int(key)
if len(updated_mapping) < index:
continue
# Add empty dict so we can update it
if len(updated_mapping) == index:
updated_mapping.append({})
# Update it
if isinstance(value, dict):
updated_mapping[index] = _deep_update(updated_mapping[index], value)
else:
updated_mapping[index] = value
else:
raise NotImplementedError
else:
for k, v in updating_mapping.items():
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
updated_mapping[k] = _deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping


class EnvSettingsSource(PydanticBaseEnvSettingsSource):
"""
Source class for loading settings values from environment variables.
Expand Down Expand Up @@ -745,8 +787,8 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
if not allow_parse_failure:
raise e

if isinstance(value, dict):
return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
if isinstance(value, (dict, list)):
return _deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
else:
return value
elif value is not None:
Expand Down Expand Up @@ -804,6 +846,15 @@ class Cfg(BaseSettings):
return None

annotation = field.annotation if isinstance(field, FieldInfo) else field
if get_origin(annotation) is list:
try:
# check if key is an integer. If so, it's an index. we fake a field info with the list type
# so future calls can continue to traverse the model and set proper types for leaf nodes
int(key)
Copy link
Member

@hramezani hramezani Sep 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need a test case with non int index

if list_type := [*get_args(annotation), None].pop(0):
return FieldInfo(annotation=list_type)
except ValueError:
pass
if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes):
for type_ in get_args(annotation):
type_has_key = EnvSettingsSource.next_field(type_, key, case_sensitive)
Expand All @@ -822,7 +873,9 @@ class Cfg(BaseSettings):

return None

def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
def explode_env_vars(
self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]
) -> dict[str, Any] | list[Any]:
"""
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.

Expand All @@ -834,14 +887,15 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
env_vars: Environment variables.

Returns:
A dictionary contains extracted values from nested env values.
A list or a dictionary contains extracted values from nested env values.
"""
is_dict = lenient_issubclass(get_origin(field.annotation), dict)
is_list = lenient_issubclass(get_origin(field.annotation), list)

prefixes = [
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
]
result: dict[str, Any] = {}
result: dict[str, Any] | list[Any] = {}
for env_name, env_val in env_vars.items():
if not any(env_name.startswith(prefix) for prefix in prefixes):
continue
Expand All @@ -859,11 +913,11 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
target_field = self.next_field(target_field, last_key, self.case_sensitive)

# check if env_val maps to a complex field and if so, parse the env_val
if (target_field or is_dict) and env_val:
if (target_field or is_dict or is_list) and env_val:
if target_field:
is_complex, allow_json_failure = self._field_is_complex(target_field)
else:
# nested field type is dict
# nested field type is dict or list
is_complex, allow_json_failure = True, True
if is_complex:
try:
Expand All @@ -875,6 +929,20 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
env_var[last_key] = env_val

def _transform_list_based_field(annotation: Annotated, values: dict[str, Any]) -> dict[str, Any] | list[Any]:
assert lenient_issubclass(get_origin(annotation), list)
if lenient_issubclass(get_args(annotation)[0], list):
result = []
for i in (str(i) for i in range(len(values))):
if i not in values:
raise ValueError(f'Expected entry with index {i} for {field_name}')
result.append(_transform_list_based_field(get_args(annotation)[0], values[i]))
return result
else:
return values

if is_list:
result = _transform_list_based_field(field.annotation, result)
return result

def __repr__(self) -> str:
Expand Down
56 changes: 56 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class Cfg(BaseSettings):
v0: str
v0_union: Union[SubValue, int]
top: TopValue
top_collection: List[TopValue]

model_config = SettingsConfigDict(env_nested_delimiter='__')

Expand All @@ -268,6 +269,18 @@ class Cfg(BaseSettings):
env.set('v0_union', '0')
env.set('top__sub__sub_sub__v6', '6')
env.set('top__sub__v4', '4')

env.set(
'top_collection',
'[{"v1": "json-1", "v2": "json-2", "sub": { "v5": "xx"}}]',
)

env.set('top_collection__0__sub__v5', '5')
env.set('top_collection__0__v2', '2')
env.set('top_collection__0__v3', '3')
env.set('top_collection__0__sub__sub_sub__v6', '6')
env.set('top_collection__0__sub__v4', '4')

cfg = Cfg()
assert cfg.model_dump() == {
'v0': '0',
Expand All @@ -278,6 +291,49 @@ class Cfg(BaseSettings):
'v3': '3',
'sub': {'v4': '4', 'v5': 5, 'sub_sub': {'v6': '6'}},
},
'top_collection': [
{
'v1': 'json-1',
'v2': '2',
'v3': '3',
'sub': {'v4': '4', 'v5': 5, 'sub_sub': {'v6': '6'}},
}
],
}


def test_nested_env_delimiter_lists_index(env):
class TopValue(BaseSettings):
v1: Optional[str] = None
v2: int

class ListCfg(BaseSettings):
top_list: List[int]
top_collection: List[List[TopValue]]

model_config = SettingsConfigDict(env_nested_delimiter='__')

env.set('top_list__0', '3')
env.set('top_list__1', '2')
env.set('top_list__2', '1')
env.set('top_list__5', 'out of bounds index')
env.set('top_list', '[1]')

env.set('top_collection', '[[{"v1": "json-1", "v2": "json-2"}]]')
env.set('top_collection__0__0__v1', 'json-1')
env.set('top_collection__0__1__v2', '6')
env.set('top_collection__0__0__v2', '5')
env.set('top_collection__0__3__v1', 'out of bounds index')

cfg = ListCfg()
assert cfg.model_dump() == {
'top_list': [3, 2, 1],
'top_collection': [
[
{'v1': 'json-1', 'v2': 5},
{'v1': None, 'v2': 6},
]
],
}


Expand Down
Loading