Skip to content

Commit 18c1f3a

Browse files
committed
feat: add EnumDefault constructor for enums with an initial default value
1 parent b28286b commit 18c1f3a

File tree

4 files changed

+48
-0
lines changed

4 files changed

+48
-0
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- `flagtype.EnumDefault` constructor for enums with an initial default value
13+
1014
## [v0.6.0] - 2026-02-18
1115

1216
### Added

flagtype/doc.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// The following types are available:
66
// - [StringSlice] - repeatable flag that collects values into []string
77
// - [Enum] - restricts values to a predefined set, retrieved as string
8+
// - [EnumDefault] - like [Enum] but with an initial default value
89
// - [StringMap] - repeatable flag that parses key=value pairs into map[string]string
910
// - [URL] - parses and validates a URL (must have scheme and host), retrieved as *url.URL
1011
// - [Regexp] - compiles a regular expression, retrieved as *regexp.Regexp
@@ -14,6 +15,7 @@
1415
// Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
1516
// f.Var(flagtype.StringSlice(), "tag", "add a tag (repeatable)")
1617
// f.Var(flagtype.Enum("json", "yaml", "table"), "format", "output format")
18+
// f.Var(flagtype.EnumDefault("sql", []string{"sql", "go"}), "type", "migration type")
1719
// f.Var(flagtype.StringMap(), "label", "key=value pair (repeatable)")
1820
// })
1921
//

flagtype/enum.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ func Enum(allowed ...string) flag.Value {
2020
return &enumValue{allowed: allowed}
2121
}
2222

23+
// EnumDefault is like [Enum] but sets an initial default value. The default must be one of the
24+
// allowed values, otherwise EnumDefault panics.
25+
//
26+
// Use [cli.GetFlag] with type string to retrieve the value.
27+
func EnumDefault(defaultVal string, allowed []string) flag.Value {
28+
if !slices.Contains(allowed, defaultVal) {
29+
panic(fmt.Sprintf("flagtype: default value %q is not in allowed values: %s",
30+
defaultVal, strings.Join(allowed, ", ")))
31+
}
32+
return &enumValue{val: defaultVal, allowed: allowed}
33+
}
34+
2335
func (v *enumValue) String() string {
2436
return v.val
2537
}

flagtype/flagtype_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,36 @@ func TestEnum(t *testing.T) {
7777
})
7878
}
7979

80+
func TestEnumDefault(t *testing.T) {
81+
t.Parallel()
82+
83+
t.Run("uses default when flag not set", func(t *testing.T) {
84+
t.Parallel()
85+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
86+
fs.Var(EnumDefault("sql", []string{"sql", "go"}), "type", "")
87+
err := fs.Parse([]string{})
88+
require.NoError(t, err)
89+
got := fs.Lookup("type").Value.(flag.Getter).Get().(string)
90+
assert.Equal(t, "sql", got)
91+
})
92+
t.Run("override default", func(t *testing.T) {
93+
t.Parallel()
94+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
95+
fs.Var(EnumDefault("sql", []string{"sql", "go"}), "type", "")
96+
err := fs.Parse([]string{"--type=go"})
97+
require.NoError(t, err)
98+
got := fs.Lookup("type").Value.(flag.Getter).Get().(string)
99+
assert.Equal(t, "go", got)
100+
})
101+
t.Run("invalid default panics", func(t *testing.T) {
102+
t.Parallel()
103+
assert.PanicsWithValue(t,
104+
`flagtype: default value "xml" is not in allowed values: sql, go`,
105+
func() { EnumDefault("xml", []string{"sql", "go"}) },
106+
)
107+
})
108+
}
109+
80110
func TestStringMap(t *testing.T) {
81111
t.Parallel()
82112

0 commit comments

Comments
 (0)