Skip to content

Commit 40230ab

Browse files
kschwabhramezani
andauthored
Fix for JSON on optional nested types. (#217)
Co-authored-by: Hasan Ramezani <[email protected]>
1 parent a6f6fa4 commit 40230ab

File tree

3 files changed

+51
-14
lines changed

3 files changed

+51
-14
lines changed

docs/index.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,6 @@ print(Settings().model_dump())
324324
`env_nested_delimiter` can be configured via the `model_config` as shown above, or via the
325325
`_env_nested_delimiter` keyword argument on instantiation.
326326

327-
JSON is only parsed in top-level fields, if you need to parse JSON in sub-models, you will need to implement
328-
validators on those models.
329-
330327
Nested environment variables take precedence over the top-level environment variable JSON
331328
(e.g. in the example above, `SUB_MODEL__V2` trumps `SUB_MODEL`).
332329

pydantic_settings/sources.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
from dotenv import dotenv_values
1313
from pydantic import AliasChoices, AliasPath, BaseModel, Json, TypeAdapter
14-
from pydantic._internal._typing_extra import origin_is_union
15-
from pydantic._internal._utils import deep_update, lenient_issubclass
14+
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union
15+
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
1616
from pydantic.fields import FieldInfo
1717
from typing_extensions import get_args, get_origin
1818

@@ -188,6 +188,8 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
188188
)
189189
else: # string validation alias
190190
field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
191+
elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
192+
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
191193
else:
192194
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
193195

@@ -478,24 +480,21 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
478480
# simplest case, field is not complex, we only need to add the value if it was found
479481
return value
480482

481-
def _union_is_complex(self, annotation: type[Any] | None, metadata: list[Any]) -> bool:
482-
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))
483-
484483
def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
485484
"""
486485
Find out if a field is complex, and if so whether JSON errors should be ignored
487486
"""
488487
if self.field_is_complex(field):
489488
allow_parse_failure = False
490-
elif origin_is_union(get_origin(field.annotation)) and self._union_is_complex(field.annotation, field.metadata):
489+
elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
491490
allow_parse_failure = True
492491
else:
493492
return False, False
494493

495494
return True, allow_parse_failure
496495

497496
@staticmethod
498-
def next_field(field: FieldInfo | None, key: str) -> FieldInfo | None:
497+
def next_field(field: FieldInfo | Any | None, key: str) -> FieldInfo | None:
499498
"""
500499
Find the field in a sub model by key(env name)
501500
@@ -524,11 +523,17 @@ class Cfg(BaseSettings):
524523
Returns:
525524
Field if it finds the next field otherwise `None`.
526525
"""
527-
if not field or origin_is_union(get_origin(field.annotation)):
528-
# no support for Unions of complex BaseSettings fields
526+
if not field:
529527
return None
530-
elif field.annotation and hasattr(field.annotation, 'model_fields') and field.annotation.model_fields.get(key):
531-
return field.annotation.model_fields[key]
528+
529+
annotation = field.annotation if isinstance(field, FieldInfo) else field
530+
if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes):
531+
for type_ in get_args(annotation):
532+
type_has_key = EnvSettingsSource.next_field(type_, key)
533+
if type_has_key:
534+
return type_has_key
535+
elif is_model_class(annotation) and annotation.model_fields.get(key):
536+
return annotation.model_fields[key]
532537

533538
return None
534539

@@ -721,3 +726,7 @@ def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
721726
return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass(
722727
annotation
723728
)
729+
730+
731+
def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
732+
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))

tests/test_settings.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,22 @@ class Cfg(BaseSettings):
198198
}
199199

200200

201+
def test_nested_env_optional_json(env):
202+
class Child(BaseModel):
203+
num_list: Optional[List[int]] = None
204+
205+
class Cfg(BaseSettings, env_nested_delimiter='__'):
206+
child: Optional[Child] = None
207+
208+
env.set('CHILD__NUM_LIST', '[1,2,3]')
209+
cfg = Cfg()
210+
assert cfg.model_dump() == {
211+
'child': {
212+
'num_list': [1, 2, 3],
213+
},
214+
}
215+
216+
201217
def test_nested_env_delimiter_with_prefix(env):
202218
class Subsettings(BaseSettings):
203219
banana: str
@@ -1259,6 +1275,21 @@ class Settings(BaseSettings):
12591275
assert Settings().model_dump() == {'foo': {'a': 'b'}}
12601276

12611277

1278+
def test_secrets_nested_optional_json(tmp_path):
1279+
p = tmp_path / 'foo'
1280+
p.write_text('{"a": 10}')
1281+
1282+
class Foo(BaseModel):
1283+
a: int
1284+
1285+
class Settings(BaseSettings):
1286+
foo: Optional[Foo] = None
1287+
1288+
model_config = SettingsConfigDict(secrets_dir=tmp_path)
1289+
1290+
assert Settings().model_dump() == {'foo': {'a': 10}}
1291+
1292+
12621293
def test_secrets_path_invalid_json(tmp_path):
12631294
p = tmp_path / 'foo'
12641295
p.write_text('{"a": "b"')

0 commit comments

Comments
 (0)