Skip to content

Commit 738641c

Browse files
authored
Merge branch 'main' into issue-482
2 parents 2e32d23 + a903697 commit 738641c

File tree

2 files changed

+71
-8
lines changed

2 files changed

+71
-8
lines changed

pydantic_settings/sources.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,9 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
661661
a flag to determine whether value is complex.
662662
"""
663663

664-
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
664+
field_infos = self._extract_field_info(field, field_name)
665+
preferred_key, *_ = field_infos[0]
666+
for field_key, env_name, value_is_complex in field_infos:
665667
# paths reversed to match the last-wins behaviour of `env_file`
666668
for secrets_path in reversed(self.secrets_paths):
667669
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
@@ -670,14 +672,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
670672
continue
671673

672674
if path.is_file():
673-
return path.read_text().strip(), field_key, value_is_complex
675+
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
676+
preferred_key = field_key
677+
return path.read_text().strip(), preferred_key, value_is_complex
674678
else:
675679
warnings.warn(
676680
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
677681
stacklevel=4,
678682
)
679683

680-
return None, field_key, value_is_complex
684+
return None, preferred_key, value_is_complex
681685

682686
def __repr__(self) -> str:
683687
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
@@ -725,12 +729,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
725729
"""
726730

727731
env_val: str | None = None
728-
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
732+
field_infos = self._extract_field_info(field, field_name)
733+
preferred_key, *_ = field_infos[0]
734+
for field_key, env_name, value_is_complex in field_infos:
729735
env_val = self.env_vars.get(env_name)
730736
if env_val is not None:
737+
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
738+
preferred_key = field_key
731739
break
732740

733-
return env_val, field_key, value_is_complex
741+
return env_val, preferred_key, value_is_complex
734742

735743
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
736744
"""
@@ -1426,8 +1434,8 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
14261434
raise SettingsError(f'CliSubCommand is not outermost annotation for {model.__name__}.{field_name}')
14271435
elif _annotation_contains_types(type_, (_CliPositionalArg,), is_include_origin=False):
14281436
raise SettingsError(f'CliPositionalArg is not outermost annotation for {model.__name__}.{field_name}')
1429-
if is_model_class(type_) or is_pydantic_dataclass(type_):
1430-
sub_models.append(type_) # type: ignore
1437+
if is_model_class(_strip_annotated(type_)) or is_pydantic_dataclass(_strip_annotated(type_)):
1438+
sub_models.append(_strip_annotated(type_))
14311439
return sub_models
14321440

14331441
def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:

tests/test_source_cli.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
import typing_extensions
1111
from pydantic import (
1212
AliasChoices,
13+
AliasGenerator,
1314
AliasPath,
1415
BaseModel,
1516
ConfigDict,
1617
DirectoryPath,
18+
Discriminator,
1719
Field,
20+
Tag,
1821
ValidationError,
1922
)
2023
from pydantic import (
@@ -107,7 +110,7 @@ def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:
107110
return self.parser.parse_args(*args, **kwargs)
108111

109112

110-
def test_validation_alias_with_cli_prefix():
113+
def test_cli_validation_alias_with_cli_prefix():
111114
class Settings(BaseSettings, cli_exit_on_error=False):
112115
foobar: str = Field(validation_alias='foo')
113116

@@ -119,6 +122,36 @@ class Settings(BaseSettings, cli_exit_on_error=False):
119122
assert CliApp.run(Settings, cli_args=['--p.foo', 'bar']).foobar == 'bar'
120123

121124

125+
@pytest.mark.parametrize(
126+
'alias_generator',
127+
[
128+
AliasGenerator(validation_alias=lambda s: AliasChoices(s, s.replace('_', '-'))),
129+
AliasGenerator(validation_alias=lambda s: AliasChoices(s.replace('_', '-'), s)),
130+
],
131+
)
132+
def test_cli_alias_resolution_consistency_with_env(env, alias_generator):
133+
class SubModel(BaseModel):
134+
v1: str = 'model default'
135+
136+
class Settings(BaseSettings):
137+
model_config = SettingsConfigDict(
138+
env_nested_delimiter='__',
139+
nested_model_default_partial_update=True,
140+
alias_generator=alias_generator,
141+
)
142+
143+
sub_model: SubModel = SubModel(v1='top default')
144+
145+
assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'top default'}}
146+
147+
env.set('SUB_MODEL__V1', 'env default')
148+
assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'env default'}}
149+
150+
assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli default']).model_dump() == {
151+
'sub_model': {'v1': 'cli default'}
152+
}
153+
154+
122155
def test_cli_nested_arg():
123156
class SubSubValue(BaseModel):
124157
v6: str
@@ -2237,3 +2270,25 @@ class MySettings(BaseSettings):
22372270
CliApp.run(
22382271
MySettings, cli_args=['--bac', 'cli abbrev are invalid for internal parser'], cli_exit_on_error=False
22392272
)
2273+
2274+
2275+
def test_cli_submodels_strip_annotated():
2276+
class PolyA(BaseModel):
2277+
a: int = 1
2278+
type: Literal['a'] = 'a'
2279+
2280+
class PolyB(BaseModel):
2281+
b: str = '2'
2282+
type: Literal['b'] = 'b'
2283+
2284+
def _get_type(model: Union[BaseModel, Dict]) -> str:
2285+
if isinstance(model, dict):
2286+
return model.get('type', 'a')
2287+
return model.type # type: ignore
2288+
2289+
Poly = Annotated[Union[Annotated[PolyA, Tag('a')], Annotated[PolyB, Tag('b')]], Discriminator(_get_type)]
2290+
2291+
class WithUnion(BaseSettings):
2292+
poly: Poly
2293+
2294+
assert CliApp.run(WithUnion, ['--poly.type=a']).model_dump() == {'poly': {'a': 1, 'type': 'a'}}

0 commit comments

Comments
 (0)