@@ -77,6 +77,36 @@ func TestEnum(t *testing.T) {
7777 })
7878}
7979
80+ func TestEnumDefault (t * testing.T ) {
81+ t .Parallel ()
82+
83+ t .Run ("uses default when flag not set" , func (t * testing.T ) {
84+ t .Parallel ()
85+ fs := flag .NewFlagSet ("test" , flag .ContinueOnError )
86+ fs .Var (EnumDefault ("sql" , []string {"sql" , "go" }), "type" , "" )
87+ err := fs .Parse ([]string {})
88+ require .NoError (t , err )
89+ got := fs .Lookup ("type" ).Value .(flag.Getter ).Get ().(string )
90+ assert .Equal (t , "sql" , got )
91+ })
92+ t .Run ("override default" , func (t * testing.T ) {
93+ t .Parallel ()
94+ fs := flag .NewFlagSet ("test" , flag .ContinueOnError )
95+ fs .Var (EnumDefault ("sql" , []string {"sql" , "go" }), "type" , "" )
96+ err := fs .Parse ([]string {"--type=go" })
97+ require .NoError (t , err )
98+ got := fs .Lookup ("type" ).Value .(flag.Getter ).Get ().(string )
99+ assert .Equal (t , "go" , got )
100+ })
101+ t .Run ("invalid default panics" , func (t * testing.T ) {
102+ t .Parallel ()
103+ assert .PanicsWithValue (t ,
104+ `flagtype: default value "xml" is not in allowed values: sql, go` ,
105+ func () { EnumDefault ("xml" , []string {"sql" , "go" }) },
106+ )
107+ })
108+ }
109+
80110func TestStringMap (t * testing.T ) {
81111 t .Parallel ()
82112
0 commit comments