Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pydantic_settings/sources/providers/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
_annotation_contains_types,
_annotation_enum_val_to_name,
_get_alias_names,
_get_class_types,
_get_model_fields,
_is_function,
_strip_annotated,
Expand Down Expand Up @@ -497,8 +498,9 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
raise SettingsError(f'CliSubCommand is not outermost annotation for {model.__name__}.{field_name}')
elif _annotation_contains_types(type_, (_CliPositionalArg,), is_include_origin=False):
raise SettingsError(f'CliPositionalArg is not outermost annotation for {model.__name__}.{field_name}')
if is_model_class(_strip_annotated(type_)) or is_pydantic_dataclass(_strip_annotated(type_)):
sub_models.append(_strip_annotated(type_))
for type_ in _get_class_types(type_):
if is_model_class(type_) or is_pydantic_dataclass(type_):
sub_models.append(cast(type[BaseModel], type_))
return sub_models

def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
Expand All @@ -523,7 +525,9 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
field_types = [
type_ for type_ in _get_class_types(field_info.annotation) if type_ is not type(None)
]
for field_type in field_types:
if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)):
raise SettingsError(
Expand Down
18 changes: 18 additions & 0 deletions pydantic_settings/sources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) ->

origin = get_origin(annotation)

# Check if annotation is of the form Union[type, ...].
if typing_objects.is_union(origin):
return _union_is_complex(annotation, metadata)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hramezani From what I observed, it appears to be related to incorrectly determining that nested annotated unions are not complex. This change fixed it, but causes other tests fail. I'm not sure if it is the correct fix. Can you take a look?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @kschwab for the delay.

I am also not sure what the correct fix is here. As the change breaks some existing tests, I would prefer to remove it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries @hramezani. I'm ok with closing this as a won't fix as well.


# Check if annotation is of the form Annotated[type, metadata].
if typing_objects.is_annotated(origin):
# Return result of recursive call on inner type.
Expand Down Expand Up @@ -102,6 +106,19 @@ def _annotation_contains_types(
return annotation in types


def _get_class_types(annotation: type[Any]) -> list[type[Any]]:
origin = get_origin(annotation)
if typing_objects.is_union(origin):
types = []
for arg in get_args(annotation):
types.extend(_get_class_types(arg))
return types
elif typing_objects.is_annotated(origin):
return _get_class_types(get_args(annotation)[0])
else:
return [annotation]


def _strip_annotated(annotation: Any) -> Any:
if typing_objects.is_annotated(get_origin(annotation)):
return annotation.__origin__
Expand Down Expand Up @@ -188,6 +205,7 @@ def _is_function(obj: Any) -> bool:
'_annotation_is_complex',
'_annotation_is_complex_inner',
'_get_alias_names',
'_get_class_types',
'_get_env_var_key',
'_get_model_fields',
'_is_function',
Expand Down
15 changes: 15 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,21 @@ class AnnotatedComplexSettings(BaseSettings):
]


def test_cli_nested_annotated_unions(env):
class Cat(BaseModel):
meow: str

class Dog(BaseModel):
woof: str

class Settings(BaseSettings):
model_config = SettingsConfigDict(env_nested_delimiter='__')
animals: Annotated[Union[Annotated[Union[Cat, Dog], 'my_nested_annotation'], None], 'my_annotation']
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hramezani I have the top level fix for the CLI in place. However, I discovered that the env parsing does not handle nested annotated unions properly. I've added a test case here to demonstrate the issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a special case. I am sure that there are more cases that fail(we are not aware of them).
I would prefer keeping the code base simple and preventing breaking change instead of supporting all complex scenarios.


env.set('ANIMALS__MEOW', 'hiss')
assert Settings().model_dump() == {'animals': {'meow': 'hiss'}}


def test_set_dict_model(env):
env.set('bananas', '[1, 2, 3, 3]')
env.set('CARROTS', '{"a": null, "b": 4}')
Expand Down
14 changes: 14 additions & 0 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,20 @@ class NotSettingsConfigDict(BaseModel):
get_subcommand(NotSettingsConfigDict(), cli_exit_on_error=False)


def test_cli_subcommand_with_annotated_union():
class Cat(BaseModel):
meow: str

class Dog(BaseModel):
woof: str

class Settings(BaseSettings):
animals: CliSubCommand[Annotated[Union[Cat, Dog], 'my_annotation']]

assert CliApp.run(Settings, cli_args=['Cat', '--meow=purr']).model_dump() == {'animals': {'meow': 'purr'}}
assert CliApp.run(Settings, cli_args=['Dog', '--woof=bark']).model_dump() == {'animals': {'woof': 'bark'}}


def test_cli_union_similar_sub_models():
class ChildA(BaseModel):
name: str = 'child a'
Expand Down
Loading