22
33import dataclasses
44import json
5- from typing import Any , Dict , Generic , List , Optional , Tuple , TypeVar , Union
5+ from typing import Any , Dict , Generic , List , Literal , Optional , Tuple , TypeVar , Union
66from unittest .mock import patch
77
88import pytest
@@ -713,6 +713,18 @@ def test_pydantic_annotated_nested_annotated_dataclass_with_default_factory(pars
713713 cfg = parser .parse_args (["--n" , "{}" ])
714714 assert cfg .n == Namespace (a1 = Namespace (a2 = 1 ))
715715
716+ class PingTask (pydantic .BaseModel ):
717+ type : Literal ["ping" ] = "ping"
718+ attr : str = ""
719+
720+ class PongTask (pydantic .BaseModel ):
721+ type : Literal ["pong" ] = "pong"
722+
723+ PingPongTask = annotated [
724+ Union [PingTask , PongTask ],
725+ pydantic .Field (discriminator = "type" ),
726+ ]
727+
716728
717729length = "length"
718730if pydantic_support :
@@ -806,6 +818,8 @@ def test_subclass(self, parser):
806818 parser .add_argument ("--model" , type = PydanticSubModel , default = PydanticSubModel (p1 = "a" ))
807819 cfg = parser .parse_args (["--model.p3=0.2" ])
808820 assert Namespace (p1 = "a" , p2 = 3 , p3 = 0.2 ) == cfg .model
821+ init = parser .instantiate_classes (cfg )
822+ assert isinstance (init .model , PydanticSubModel )
809823
810824 def test_field_default_factory (self , parser ):
811825 parser .add_argument ("--model" , type = PydanticFieldFactory )
@@ -831,6 +845,18 @@ def test_annotated_field(self, parser):
831845 parser .parse_args (["--model.p1=0" ])
832846 ctx .match ("model.p1" )
833847
848+ @pytest .mark .skipif (not (annotated and pydantic_support > 1 ), reason = "Annotated is required" )
849+ def test_field_union_discriminator_dot_syntax (self , parser ):
850+ parser .add_argument ("--model" , type = PingPongTask )
851+ cfg = parser .parse_args (["--model.type=pong" ])
852+ assert cfg .model == Namespace (type = "pong" )
853+ init = parser .instantiate_classes (cfg )
854+ assert isinstance (init .model , PongTask )
855+ cfg = parser .parse_args (["--model.type=ping" , "--model.attr=abc" ])
856+ assert cfg .model == Namespace (type = "ping" , attr = "abc" )
857+ init = parser .instantiate_classes (cfg )
858+ assert isinstance (init .model , PingTask )
859+
834860 @pytest .mark .parametrize (
835861 ["valid_value" , "invalid_value" , "cast" , "type_str" ],
836862 [
0 commit comments