diff --git a/pydantic_settings/sources/providers/cli.py b/pydantic_settings/sources/providers/cli.py index 775a7cea..0275320d 100644 --- a/pydantic_settings/sources/providers/cli.py +++ b/pydantic_settings/sources/providers/cli.py @@ -52,6 +52,7 @@ _annotation_contains_types, _annotation_enum_val_to_name, _get_alias_names, + _get_class_types, _get_model_fields, _is_function, _strip_annotated, @@ -497,8 +498,9 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F raise SettingsError(f'CliSubCommand is not outermost annotation for {model.__name__}.{field_name}') elif _annotation_contains_types(type_, (_CliPositionalArg,), is_include_origin=False): raise SettingsError(f'CliPositionalArg is not outermost annotation for {model.__name__}.{field_name}') - if is_model_class(_strip_annotated(type_)) or is_pydantic_dataclass(_strip_annotated(type_)): - sub_models.append(_strip_annotated(type_)) + for type_ in _get_class_types(type_): + if is_model_class(type_) or is_pydantic_dataclass(type_): + sub_models.append(cast(type[BaseModel], type_)) return sub_models def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None: @@ -523,7 +525,9 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] alias_names, *_ = _get_alias_names(field_name, field_info) if len(alias_names) > 1: raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases') - field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)] + field_types = [ + type_ for type_ in _get_class_types(field_info.annotation) if type_ is not type(None) + ] for field_type in field_types: if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)): raise SettingsError( diff --git a/pydantic_settings/sources/utils.py b/pydantic_settings/sources/utils.py index 41c856fc..7b4e6297 100644 --- a/pydantic_settings/sources/utils.py +++ b/pydantic_settings/sources/utils.py @@ -54,6 +54,10 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> origin = get_origin(annotation) + # Check if annotation is of the form Union[type, ...]. + if typing_objects.is_union(origin): + return _union_is_complex(annotation, metadata) + # Check if annotation is of the form Annotated[type, metadata]. if typing_objects.is_annotated(origin): # Return result of recursive call on inner type. @@ -102,6 +106,19 @@ def _annotation_contains_types( return annotation in types +def _get_class_types(annotation: type[Any]) -> list[type[Any]]: + origin = get_origin(annotation) + if typing_objects.is_union(origin): + types = [] + for arg in get_args(annotation): + types.extend(_get_class_types(arg)) + return types + elif typing_objects.is_annotated(origin): + return _get_class_types(get_args(annotation)[0]) + else: + return [annotation] + + def _strip_annotated(annotation: Any) -> Any: if typing_objects.is_annotated(get_origin(annotation)): return annotation.__origin__ @@ -188,6 +205,7 @@ def _is_function(obj: Any) -> bool: '_annotation_is_complex', '_annotation_is_complex_inner', '_get_alias_names', + '_get_class_types', '_get_env_var_key', '_get_model_fields', '_is_function', diff --git a/tests/test_settings.py b/tests/test_settings.py index 5ca14d85..f37c3822 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -474,6 +474,21 @@ class AnnotatedComplexSettings(BaseSettings): ] +def test_cli_nested_annotated_unions(env): + class Cat(BaseModel): + meow: str + + class Dog(BaseModel): + woof: str + + class Settings(BaseSettings): + model_config = SettingsConfigDict(env_nested_delimiter='__') + animals: Annotated[Union[Annotated[Union[Cat, Dog], 'my_nested_annotation'], None], 'my_annotation'] + + env.set('ANIMALS__MEOW', 'hiss') + assert Settings().model_dump() == {'animals': {'meow': 'hiss'}} + + def test_set_dict_model(env): env.set('bananas', '[1, 2, 3, 3]') env.set('CARROTS', '{"a": null, "b": 4}') diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index 681e9743..9c772ef1 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -1275,6 +1275,20 @@ class NotSettingsConfigDict(BaseModel): get_subcommand(NotSettingsConfigDict(), cli_exit_on_error=False) +def test_cli_subcommand_with_annotated_union(): + class Cat(BaseModel): + meow: str + + class Dog(BaseModel): + woof: str + + class Settings(BaseSettings): + animals: CliSubCommand[Annotated[Union[Cat, Dog], 'my_annotation']] + + assert CliApp.run(Settings, cli_args=['Cat', '--meow=purr']).model_dump() == {'animals': {'meow': 'purr'}} + assert CliApp.run(Settings, cli_args=['Dog', '--woof=bark']).model_dump() == {'animals': {'woof': 'bark'}} + + def test_cli_union_similar_sub_models(): class ChildA(BaseModel): name: str = 'child a'