Skip to content

Commit c5cb6c4

Browse files
committed
Add parse tests and improve required flags logic
1 parent 1239c35 commit c5cb6c4

File tree

3 files changed

+87
-36
lines changed

3 files changed

+87
-36
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
/internal/
1+
internal/
2+
tmp/

parse.go

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -102,36 +102,27 @@ func Parse(root *Command, args []string) error {
102102
return fmt.Errorf("command %q: %w", current.Name, err)
103103
}
104104

105-
// Check required flags by inspecting the args string for their presence
106-
if len(current.FlagsMetadata) > 0 {
107-
var missingFlags []string
108-
for _, flagMetadata := range current.FlagsMetadata {
109-
if !flagMetadata.Required {
110-
continue
111-
}
112-
// TODO(mf): we need to validate that the metadata flag is known to the flag set
113-
flag := combinedFlags.Lookup(flagMetadata.Name)
114-
if flag == nil {
115-
return fmt.Errorf("command %q: internal error: required flag %q not found in flag set", current.Name, flagMetadata.Name)
116-
}
117-
// Look for the flag in the original args before any delimiter
118-
found := false
119-
for _, arg := range argsToParse {
120-
// Match either -flag or --flag
121-
if arg == "-"+flagMetadata.Name || arg == "--"+flagMetadata.Name ||
122-
strings.HasPrefix(arg, "-"+flagMetadata.Name+"=") ||
123-
strings.HasPrefix(arg, "--"+flagMetadata.Name+"=") {
124-
found = true
125-
break
105+
// Check required flags by checking if they were actually set to non-default values
106+
var missingFlags []string
107+
for _, cmd := range commandChain {
108+
if len(cmd.FlagsMetadata) > 0 {
109+
for _, flagMetadata := range cmd.FlagsMetadata {
110+
if !flagMetadata.Required {
111+
continue
112+
}
113+
flag := combinedFlags.Lookup(flagMetadata.Name)
114+
if flag == nil {
115+
return fmt.Errorf("command %q: internal error: required flag %q not found in flag set", current.Name, flagMetadata.Name)
116+
}
117+
// Check if the flag was set by checking its actual value against default
118+
if flag.Value.String() == flag.DefValue {
119+
missingFlags = append(missingFlags, flagMetadata.Name)
126120
}
127-
}
128-
if !found {
129-
missingFlags = append(missingFlags, flagMetadata.Name)
130121
}
131122
}
132-
if len(missingFlags) > 0 {
133-
return fmt.Errorf("command %q: required flag(s) %q not set", current.Name, strings.Join(missingFlags, ", "))
134-
}
123+
}
124+
if len(missingFlags) > 0 {
125+
return fmt.Errorf("command %q: required flag(s) %q not set", current.Name, strings.Join(missingFlags, ", "))
135126
}
136127

137128
// Skip past command names in remaining args from flag parsing

parse_test.go

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@ import (
1313

1414
// testState is a helper struct to hold the commands for testing
1515
//
16-
// root --verbose --version
17-
// ├── add --dry-run
18-
// └── nested --force
19-
// └── sub --echo
16+
// root --verbose --version
17+
// ├── add --dry-run
18+
// └── nested --force
19+
// └── sub --echo
20+
// └── hello --mandatory-flag
2021
type testState struct {
21-
add *Command
22-
nested, sub *Command
23-
root *Command
22+
add *Command
23+
nested, sub, hello *Command
24+
root *Command
2425
}
2526

2627
func newTestState() testState {
@@ -37,14 +38,27 @@ func newTestState() testState {
3738
Flags: FlagsFunc(func(fset *flag.FlagSet) {
3839
fset.String("echo", "", "echo the message")
3940
}),
41+
FlagsMetadata: []FlagMetadata{
42+
{Name: "echo", Required: false}, // not required
43+
},
44+
Exec: exec,
45+
}
46+
hello := &Command{
47+
Name: "hello",
48+
Flags: FlagsFunc(func(fset *flag.FlagSet) {
49+
fset.Bool("mandatory-flag", false, "mandatory flag")
50+
}),
51+
FlagsMetadata: []FlagMetadata{
52+
{Name: "mandatory-flag", Required: true},
53+
},
4054
Exec: exec,
4155
}
4256
nested := &Command{
4357
Name: "nested",
4458
Flags: FlagsFunc(func(fset *flag.FlagSet) {
4559
fset.Bool("force", false, "force the operation")
4660
}),
47-
SubCommands: []*Command{sub},
61+
SubCommands: []*Command{sub, hello},
4862
Exec: exec,
4963
}
5064
root := &Command{
@@ -290,4 +304,49 @@ func TestParse(t *testing.T) {
290304
require.Error(t, err)
291305
require.ErrorContains(t, err, `subcommand in path "todo nested" has no name`)
292306
})
307+
t.Run("required flag not set", func(t *testing.T) {
308+
t.Parallel()
309+
s := newTestState()
310+
311+
err := Parse(s.root, []string{"nested", "hello"})
312+
require.Error(t, err)
313+
// TODO(mf): this error message should have the full path to the command, e.g., "todo nested hello"
314+
require.ErrorContains(t, err, `command "hello": required flag(s) "mandatory-flag" not set`)
315+
316+
// Correct type
317+
err = Parse(s.root, []string{"nested", "hello", "--mandatory-flag", "true"})
318+
require.NoError(t, err)
319+
require.True(t, GetFlag[bool](s.root.selected.state, "mandatory-flag"))
320+
// Incorrect type
321+
err = Parse(s.root, []string{"nested", "hello", "--mandatory-flag=not-a-bool"})
322+
require.Error(t, err)
323+
require.ErrorContains(t, err, `command "hello": invalid boolean value "not-a-bool" for -mandatory-flag: parse error`)
324+
})
325+
t.Run("unknown required flag set by cli author", func(t *testing.T) {
326+
t.Parallel()
327+
cmd := &Command{
328+
Name: "root",
329+
FlagsMetadata: []FlagMetadata{
330+
{Name: "some-other-flag", Required: true},
331+
},
332+
}
333+
err := Parse(cmd, nil)
334+
require.Error(t, err)
335+
// TODO(mf): consider improving this error message so it's obvious that a "required" flag
336+
// was set by the cli author but not registered in the flag set
337+
require.ErrorContains(t, err, `command "root": internal error: required flag "some-other-flag" not found in flag set`)
338+
})
339+
t.Run("space in command name", func(t *testing.T) {
340+
t.Parallel()
341+
cmd := &Command{
342+
Name: "root",
343+
SubCommands: []*Command{
344+
{Name: "sub command"},
345+
},
346+
}
347+
err := Parse(cmd, nil)
348+
require.Error(t, err)
349+
// TODO(mf): consider improving this error message so it's a bit more user-friendly
350+
require.ErrorContains(t, err, `command name "sub command" contains spaces`)
351+
})
293352
}

0 commit comments

Comments
 (0)