Skip to content

Commit 97198be

Browse files
authored
feat: add flagtype package with common flag.Value implementations (#13)
1 parent 53c0a5d commit 97198be

File tree

9 files changed

+538
-0
lines changed

9 files changed

+538
-0
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- New `flagtype` package with common `flag.Value` implementations: `StringSlice`, `Enum`,
13+
`StringMap`, `URL`, and `Regexp`
14+
1015
## [v0.5.0] - 2026-02-17
1116

1217
### Changed

docs/design/001-flagtype-api.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# 001 - flagtype API
2+
3+
**Date:** 2026-02-18
4+
5+
## Context
6+
7+
Users of pressly/cli must manually implement `flag.Value` (and `flag.Getter`) for common types like
8+
string slices, enums, and maps. This is repetitive boilerplate that most CLI tools need.
9+
10+
## Decision
11+
12+
Use stdlib-native constructors that return `flag.Value`, registered via `f.Var()`.
13+
14+
```go
15+
Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
16+
f.Bool("verbose", false, "enable verbose output")
17+
f.Var(flagtype.StringSlice(), "tag", "add a tag (repeatable)")
18+
f.Var(flagtype.Enum("json", "yaml", "table"), "format", "output format")
19+
f.Var(flagtype.StringMap(), "label", "key=value pair (repeatable)")
20+
})
21+
```
22+
23+
The flagtype package has no knowledge of `flag.FlagSet`. Each constructor returns a value that
24+
implements `flag.Value` and `flag.Getter`. Storage is internal -- no destination pointers needed
25+
since values are retrieved via `cli.GetFlag[T]`.
26+
27+
## Alternatives considered
28+
29+
### A: flagtype takes a FlagSet
30+
31+
```go
32+
Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
33+
f.Bool("verbose", false, "enable verbose output")
34+
flagtype.StringSlice(f, "tag", "add a tag (repeatable)")
35+
flagtype.Enum(f, "format", "output format", "json", "yaml", "table")
36+
})
37+
```
38+
39+
One-liner registration, no `f.Var()` ceremony. Rejected because it introduces a second calling
40+
convention in the same block -- stdlib flags use `f.Type(name, default, usage)` while flagtype would
41+
use `flagtype.Type(f, name, usage)`. The argument ordering inconsistency makes it harder to read at
42+
a glance.
43+
44+
### B: FlagSet wrapper
45+
46+
```go
47+
Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
48+
f.Bool("verbose", false, "enable verbose output")
49+
ft := flagtype.From(f)
50+
ft.StringSlice("tag", "add a tag (repeatable)")
51+
ft.Enum("format", "output format", "json", "yaml", "table")
52+
})
53+
```
54+
55+
Feels like a natural extension of FlagSet. Rejected because it requires managing two objects in the
56+
same closure -- `f` for standard types and `ft` for custom types. Also adds a layer of indirection
57+
that doesn't pull its weight.
58+
59+
### C: Declarative flag list
60+
61+
```go
62+
Flags: []cli.Flag{
63+
cli.String("output", "", "output file"),
64+
cli.Bool("verbose", false, "enable verbose output"),
65+
flagtype.StringSlice("tag", "add a tag (repeatable)"),
66+
flagtype.Enum("format", "output format", "json", "yaml", "table"),
67+
}
68+
```
69+
70+
Fully declarative, no callback, no FlagSet. Rejected because it's a significant departure from the
71+
stdlib `flag` package and would require rethinking the core `Command` type. Essentially a different
72+
framework.
73+
74+
### D: Destination pointer pattern
75+
76+
```go
77+
var tags []string
78+
var re *regexp.Regexp
79+
f.Var(flagtype.StringSlice(&tags), "tag", "add a tag (repeatable)")
80+
f.Var(flagtype.Regexp(&re), "pattern", "regex pattern")
81+
```
82+
83+
The initial implementation. Each constructor takes a pointer to the destination variable. Rejected
84+
because pointer types like `*regexp.Regexp` and `*url.URL` require double pointers
85+
(`**regexp.Regexp`), which is awkward. Since values are always retrieved via `cli.GetFlag[T]`, the
86+
destination pointer serves no purpose.
87+
88+
## Why this approach
89+
90+
- **Zero new concepts.** Anyone who knows `flag.Var` already knows how to use flagtype.
91+
- **No coupling.** flagtype has no dependency on the cli package or `flag.FlagSet`.
92+
- **Consistent with stdlib.** Custom flag types in Go have always been registered via `f.Var()`.
93+
This follows that convention exactly.
94+
- **No double pointers.** Internal storage means the API is clean for all types, including pointer
95+
types like `*url.URL` and `*regexp.Regexp`.

flagtype/doc.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Package flagtype provides common [flag.Value] implementations for use with [flag.FlagSet.Var].
2+
//
3+
// All types implement [flag.Getter] so they work with [cli.GetFlag].
4+
//
5+
// The following types are available:
6+
// - [StringSlice] - repeatable flag that collects values into []string
7+
// - [Enum] - restricts values to a predefined set, retrieved as string
8+
// - [StringMap] - repeatable flag that parses key=value pairs into map[string]string
9+
// - [URL] - parses and validates a URL (must have scheme and host), retrieved as *url.URL
10+
// - [Regexp] - compiles a regular expression, retrieved as *regexp.Regexp
11+
//
12+
// Example registration:
13+
//
14+
// Flags: cli.FlagsFunc(func(f *flag.FlagSet) {
15+
// f.Var(flagtype.StringSlice(), "tag", "add a tag (repeatable)")
16+
// f.Var(flagtype.Enum("json", "yaml", "table"), "format", "output format")
17+
// f.Var(flagtype.StringMap(), "label", "key=value pair (repeatable)")
18+
// })
19+
//
20+
// Example retrieval in Exec:
21+
//
22+
// tags := cli.GetFlag[[]string](s, "tag")
23+
// format := cli.GetFlag[string](s, "format")
24+
// labels := cli.GetFlag[map[string]string](s, "label")
25+
package flagtype

flagtype/enum.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package flagtype
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"slices"
7+
"strings"
8+
)
9+
10+
type enumValue struct {
11+
val string
12+
allowed []string
13+
}
14+
15+
// Enum returns a [flag.Value] that restricts the flag to one of the allowed values. If a value not
16+
// in the allowed list is provided, an error is returned listing valid options.
17+
//
18+
// Use [cli.GetFlag] with type string to retrieve the value.
19+
func Enum(allowed ...string) flag.Value {
20+
return &enumValue{allowed: allowed}
21+
}
22+
23+
func (v *enumValue) String() string {
24+
return v.val
25+
}
26+
27+
func (v *enumValue) Set(s string) error {
28+
if !slices.Contains(v.allowed, s) {
29+
return fmt.Errorf("invalid value %q, must be one of: %s", s, strings.Join(v.allowed, ", "))
30+
}
31+
v.val = s
32+
return nil
33+
}
34+
35+
func (v *enumValue) Get() any {
36+
return v.val
37+
}

flagtype/flagtype_test.go

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
package flagtype
2+
3+
import (
4+
"flag"
5+
"net/url"
6+
"regexp"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestStringSlice(t *testing.T) {
14+
t.Parallel()
15+
16+
t.Run("single value", func(t *testing.T) {
17+
t.Parallel()
18+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
19+
fs.Var(StringSlice(), "tag", "")
20+
err := fs.Parse([]string{"--tag=foo"})
21+
require.NoError(t, err)
22+
got := fs.Lookup("tag").Value.(flag.Getter).Get().([]string)
23+
assert.Equal(t, []string{"foo"}, got)
24+
})
25+
t.Run("multiple values", func(t *testing.T) {
26+
t.Parallel()
27+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
28+
fs.Var(StringSlice(), "tag", "")
29+
err := fs.Parse([]string{"--tag=foo", "--tag=bar", "--tag=baz"})
30+
require.NoError(t, err)
31+
got := fs.Lookup("tag").Value.(flag.Getter).Get().([]string)
32+
assert.Equal(t, []string{"foo", "bar", "baz"}, got)
33+
})
34+
t.Run("string output", func(t *testing.T) {
35+
t.Parallel()
36+
v := StringSlice()
37+
require.NoError(t, v.Set("a"))
38+
require.NoError(t, v.Set("b"))
39+
assert.Equal(t, "a,b", v.String())
40+
})
41+
t.Run("empty", func(t *testing.T) {
42+
t.Parallel()
43+
v := StringSlice()
44+
assert.Equal(t, "", v.String())
45+
got := v.(flag.Getter).Get().([]string)
46+
assert.Nil(t, got)
47+
})
48+
}
49+
50+
func TestEnum(t *testing.T) {
51+
t.Parallel()
52+
53+
t.Run("valid value", func(t *testing.T) {
54+
t.Parallel()
55+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
56+
fs.Var(Enum("json", "yaml", "table"), "format", "")
57+
err := fs.Parse([]string{"--format=yaml"})
58+
require.NoError(t, err)
59+
got := fs.Lookup("format").Value.(flag.Getter).Get().(string)
60+
assert.Equal(t, "yaml", got)
61+
})
62+
t.Run("invalid value", func(t *testing.T) {
63+
t.Parallel()
64+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
65+
fs.SetOutput(nopWriter{})
66+
fs.Var(Enum("json", "yaml"), "format", "")
67+
err := fs.Parse([]string{"--format=xml"})
68+
require.Error(t, err)
69+
assert.Contains(t, err.Error(), "must be one of")
70+
assert.Contains(t, err.Error(), "json, yaml")
71+
})
72+
t.Run("empty default", func(t *testing.T) {
73+
t.Parallel()
74+
v := Enum("a", "b")
75+
assert.Equal(t, "", v.String())
76+
assert.Equal(t, "", v.(flag.Getter).Get())
77+
})
78+
}
79+
80+
func TestStringMap(t *testing.T) {
81+
t.Parallel()
82+
83+
t.Run("single pair", func(t *testing.T) {
84+
t.Parallel()
85+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
86+
fs.Var(StringMap(), "label", "")
87+
err := fs.Parse([]string{"--label=env=prod"})
88+
require.NoError(t, err)
89+
got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string)
90+
assert.Equal(t, map[string]string{"env": "prod"}, got)
91+
})
92+
t.Run("multiple pairs", func(t *testing.T) {
93+
t.Parallel()
94+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
95+
fs.Var(StringMap(), "label", "")
96+
err := fs.Parse([]string{"--label=env=prod", "--label=tier=web"})
97+
require.NoError(t, err)
98+
got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string)
99+
assert.Equal(t, map[string]string{"env": "prod", "tier": "web"}, got)
100+
})
101+
t.Run("value contains equals", func(t *testing.T) {
102+
t.Parallel()
103+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
104+
fs.Var(StringMap(), "label", "")
105+
err := fs.Parse([]string{"--label=query=a=b"})
106+
require.NoError(t, err)
107+
got := fs.Lookup("label").Value.(flag.Getter).Get().(map[string]string)
108+
assert.Equal(t, map[string]string{"query": "a=b"}, got)
109+
})
110+
t.Run("missing equals", func(t *testing.T) {
111+
t.Parallel()
112+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
113+
fs.SetOutput(nopWriter{})
114+
fs.Var(StringMap(), "label", "")
115+
err := fs.Parse([]string{"--label=nope"})
116+
require.Error(t, err)
117+
assert.Contains(t, err.Error(), "missing '='")
118+
})
119+
t.Run("empty key", func(t *testing.T) {
120+
t.Parallel()
121+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
122+
fs.SetOutput(nopWriter{})
123+
fs.Var(StringMap(), "label", "")
124+
err := fs.Parse([]string{"--label==value"})
125+
require.Error(t, err)
126+
assert.Contains(t, err.Error(), "empty key")
127+
})
128+
t.Run("string output sorted", func(t *testing.T) {
129+
t.Parallel()
130+
v := StringMap()
131+
require.NoError(t, v.Set("b=2"))
132+
require.NoError(t, v.Set("a=1"))
133+
assert.Equal(t, "a=1,b=2", v.String())
134+
})
135+
t.Run("empty", func(t *testing.T) {
136+
t.Parallel()
137+
v := StringMap()
138+
assert.Equal(t, "", v.String())
139+
assert.Nil(t, v.(flag.Getter).Get())
140+
})
141+
}
142+
143+
func TestURL(t *testing.T) {
144+
t.Parallel()
145+
146+
t.Run("valid url", func(t *testing.T) {
147+
t.Parallel()
148+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
149+
fs.Var(URL(), "endpoint", "")
150+
err := fs.Parse([]string{"--endpoint=https://example.com/api"})
151+
require.NoError(t, err)
152+
got := fs.Lookup("endpoint").Value.(flag.Getter).Get().(*url.URL)
153+
require.NotNil(t, got)
154+
assert.Equal(t, "https", got.Scheme)
155+
assert.Equal(t, "example.com", got.Host)
156+
assert.Equal(t, "/api", got.Path)
157+
})
158+
t.Run("missing scheme", func(t *testing.T) {
159+
t.Parallel()
160+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
161+
fs.SetOutput(nopWriter{})
162+
fs.Var(URL(), "endpoint", "")
163+
err := fs.Parse([]string{"--endpoint=example.com"})
164+
require.Error(t, err)
165+
assert.Contains(t, err.Error(), "must have a scheme and host")
166+
})
167+
t.Run("empty", func(t *testing.T) {
168+
t.Parallel()
169+
v := URL()
170+
assert.Equal(t, "", v.String())
171+
assert.Nil(t, v.(flag.Getter).Get())
172+
})
173+
}
174+
175+
func TestRegexp(t *testing.T) {
176+
t.Parallel()
177+
178+
t.Run("valid pattern", func(t *testing.T) {
179+
t.Parallel()
180+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
181+
fs.Var(Regexp(), "pattern", "")
182+
err := fs.Parse([]string{"--pattern=^foo.*bar$"})
183+
require.NoError(t, err)
184+
got := fs.Lookup("pattern").Value.(flag.Getter).Get().(*regexp.Regexp)
185+
require.NotNil(t, got)
186+
assert.True(t, got.MatchString("fooXbar"))
187+
assert.False(t, got.MatchString("baz"))
188+
})
189+
t.Run("invalid pattern", func(t *testing.T) {
190+
t.Parallel()
191+
fs := flag.NewFlagSet("test", flag.ContinueOnError)
192+
fs.SetOutput(nopWriter{})
193+
fs.Var(Regexp(), "pattern", "")
194+
err := fs.Parse([]string{"--pattern=[invalid"})
195+
require.Error(t, err)
196+
})
197+
t.Run("empty", func(t *testing.T) {
198+
t.Parallel()
199+
v := Regexp()
200+
assert.Equal(t, "", v.String())
201+
assert.Nil(t, v.(flag.Getter).Get())
202+
})
203+
}
204+
205+
// nopWriter discards all writes, used to suppress flag.FlagSet error output in tests.
206+
type nopWriter struct{}
207+
208+
func (nopWriter) Write(p []byte) (int, error) { return len(p), nil }

0 commit comments

Comments
 (0)