Skip to content

Commit 85eb0e7

Browse files
authored
refactor: decompose Parse into focused helpers and fix edge cases (#8)
1 parent 2382160 commit 85eb0e7

File tree

11 files changed

+423
-259
lines changed

11 files changed

+423
-259
lines changed

.github/workflows/ci.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,19 @@ on:
88
types: [opened, synchronize, reopened]
99

1010
jobs:
11+
lint:
12+
name: Lint
13+
runs-on: ubuntu-latest
14+
steps:
15+
- name: Checkout code
16+
uses: actions/checkout@v4
17+
- name: Set up Go
18+
uses: actions/setup-go@v5
19+
with:
20+
go-version: stable
21+
- name: Run golangci-lint
22+
uses: golangci/golangci-lint-action@v7
23+
1124
build:
1225
name: Build and test
1326
runs-on: ubuntu-latest

.golangci.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
version: "2"
2+
3+
linters:
4+
default: standard
5+
exclusions:
6+
paths:
7+
- examples
8+
9+
formatters:
10+
enable:
11+
- goimports

Makefile

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
.PHONY: build test lint format ci-test
2+
3+
build:
4+
go build -v .
5+
6+
test:
7+
go test $$(go list ./... | grep -v 'examples') -count=1 -v
8+
9+
lint:
10+
golangci-lint run ./...
11+
12+
format:
13+
goimports -w $$(find . -name '*.go' -not -path './examples/*')
14+
15+
ci-test:
16+
go test $$(go list ./... | grep -v 'examples') -count=1 -v -json -cover \
17+
| tparse -all -follow -sort=elapsed -trimpath=auto

graceful/graceful.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func Run(run func(context.Context) error, opts ...Option) {
100100
if cfg.logger != nil {
101101
cfg.logger.Error("function error", slog.Any("error", err))
102102
} else {
103-
fmt.Fprintln(cfg.stderr, err)
103+
_, _ = fmt.Fprintln(cfg.stderr, err)
104104
}
105105
exit(1)
106106
}
@@ -113,7 +113,7 @@ func Run(run func(context.Context) error, opts ...Option) {
113113
if cfg.logger != nil {
114114
cfg.logger.Warn(msg)
115115
} else {
116-
fmt.Fprintln(cfg.stderr, msg)
116+
_, _ = fmt.Fprintln(cfg.stderr, msg)
117117
}
118118
exit(130)
119119
}
@@ -127,7 +127,7 @@ func Run(run func(context.Context) error, opts ...Option) {
127127
if cfg.logger != nil {
128128
cfg.logger.Info(msg)
129129
} else {
130-
fmt.Fprintln(cfg.stderr, msg)
130+
_, _ = fmt.Fprintln(cfg.stderr, msg)
131131
}
132132

133133
// Set up shutdown timeout if configured
@@ -145,7 +145,7 @@ func Run(run func(context.Context) error, opts ...Option) {
145145
if cfg.logger != nil {
146146
cfg.logger.Error("function error", "error", err)
147147
} else {
148-
fmt.Fprintln(cfg.stderr, err)
148+
_, _ = fmt.Fprintln(cfg.stderr, err)
149149
}
150150
exit(1)
151151
}
@@ -157,7 +157,7 @@ func Run(run func(context.Context) error, opts ...Option) {
157157
if cfg.logger != nil {
158158
cfg.logger.Warn(msg)
159159
} else {
160-
fmt.Fprintln(cfg.stderr, msg)
160+
_, _ = fmt.Fprintln(cfg.stderr, msg)
161161
}
162162
exit(130)
163163

@@ -167,7 +167,7 @@ func Run(run func(context.Context) error, opts ...Option) {
167167
if cfg.logger != nil {
168168
cfg.logger.Error(msg)
169169
} else {
170-
fmt.Fprintln(cfg.stderr, msg)
170+
_, _ = fmt.Fprintln(cfg.stderr, msg)
171171
}
172172
exit(124)
173173
}

graceful/graceful_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func sendSignal(trigger <-chan struct{}, delay time.Duration) {
4040
if delay > 0 {
4141
time.Sleep(delay)
4242
}
43-
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
43+
_ = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
4444
}
4545

4646
func TestRun_Success(t *testing.T) {

parse.go

Lines changed: 112 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -36,32 +36,62 @@ func Parse(root *Command, args []string) error {
3636
// Reset command path but preserve other state
3737
root.state.path = []*Command{root}
3838
}
39-
// First split args at the -- delimiter if present
40-
var argsToParse []string
41-
var remainingArgs []string
39+
40+
argsToParse, remainingArgs := splitAtDelimiter(args)
41+
42+
current, err := resolveCommandPath(root, argsToParse)
43+
if err != nil {
44+
return err
45+
}
46+
current.Flags.Usage = func() { /* suppress default usage */ }
47+
48+
// Check for help flags after resolving the correct command
49+
for _, arg := range argsToParse {
50+
if arg == "-h" || arg == "--h" || arg == "-help" || arg == "--help" {
51+
// Combine flags first so the help message includes all inherited flags
52+
combineFlags(root.state.path)
53+
return flag.ErrHelp
54+
}
55+
}
56+
57+
combinedFlags := combineFlags(root.state.path)
58+
59+
// Let ParseToEnd handle the flag parsing
60+
if err := xflag.ParseToEnd(combinedFlags, argsToParse); err != nil {
61+
return fmt.Errorf("command %q: %w", getCommandPath(root.state.path), err)
62+
}
63+
64+
if err := checkRequiredFlags(root.state.path, combinedFlags); err != nil {
65+
return err
66+
}
67+
68+
root.state.Args = collectArgs(root.state.path, combinedFlags.Args(), remainingArgs)
69+
70+
if current.Exec == nil {
71+
return fmt.Errorf("command %q: no exec function defined", getCommandPath(root.state.path))
72+
}
73+
return nil
74+
}
75+
76+
// splitAtDelimiter splits args at the first "--" delimiter. Returns the args before the delimiter
77+
// and any args after it.
78+
func splitAtDelimiter(args []string) (argsToParse, remaining []string) {
4279
for i, arg := range args {
4380
if arg == "--" {
44-
argsToParse = args[:i]
45-
remainingArgs = args[i+1:]
46-
break
81+
return args[:i], args[i+1:]
4782
}
4883
}
49-
if argsToParse == nil {
50-
argsToParse = args
51-
}
84+
return args, nil
85+
}
5286

87+
// resolveCommandPath walks argsToParse to resolve the subcommand chain, building root.state.path
88+
// and initializing flag sets along the way. Returns the terminal (deepest) command.
89+
func resolveCommandPath(root *Command, argsToParse []string) (*Command, error) {
5390
current := root
5491
if current.Flags == nil {
5592
current.Flags = flag.NewFlagSet(root.Name, flag.ContinueOnError)
5693
}
57-
var commandChain []*Command
58-
commandChain = append(commandChain, root)
5994

60-
// Create combined flags with all parent flags
61-
combinedFlags := flag.NewFlagSet(root.Name, flag.ContinueOnError)
62-
combinedFlags.SetOutput(io.Discard)
63-
64-
// First pass: process commands and build the flag set
6595
i := 0
6696
for i < len(argsToParse) {
6797
arg := argsToParse[i]
@@ -74,15 +104,24 @@ func Parse(root *Command, args []string) error {
74104
continue
75105
}
76106

77-
// Check if this flag expects a value
107+
// Check if this flag expects a value across all commands in the chain (not just the
108+
// current command), since flags from ancestor commands are inherited and can appear
109+
// anywhere.
78110
name := strings.TrimLeft(arg, "-")
79-
if f := current.Flags.Lookup(name); f != nil {
80-
if _, isBool := f.Value.(interface{ IsBoolFlag() bool }); !isBool {
81-
// Skip both flag and its value
82-
i += 2
83-
continue
111+
skipValue := false
112+
for _, cmd := range root.state.path {
113+
if f := cmd.Flags.Lookup(name); f != nil {
114+
if _, isBool := f.Value.(interface{ IsBoolFlag() bool }); !isBool {
115+
skipValue = true
116+
}
117+
break
84118
}
85119
}
120+
if skipValue {
121+
// Skip both flag and its value
122+
i += 2
123+
continue
124+
}
86125
i++
87126
continue
88127
}
@@ -95,73 +134,55 @@ func Parse(root *Command, args []string) error {
95134
sub.Flags = flag.NewFlagSet(sub.Name, flag.ContinueOnError)
96135
}
97136
current = sub
98-
commandChain = append(commandChain, sub)
99137
i++
100138
continue
101139
}
102-
return current.formatUnknownCommandError(arg)
140+
return nil, current.formatUnknownCommandError(arg)
103141
}
104142
break
105143
}
106-
current.Flags.Usage = func() { /* suppress default usage */ }
107-
108-
// Add the help check here, after we've found the correct command
109-
hasHelp := false
110-
for _, arg := range argsToParse {
111-
if arg == "-h" || arg == "--h" || arg == "-help" || arg == "--help" {
112-
hasHelp = true
113-
break
114-
}
115-
}
144+
return current, nil
145+
}
116146

117-
// Add flags in reverse order for proper precedence
118-
for i := len(commandChain) - 1; i >= 0; i-- {
119-
cmd := commandChain[i]
147+
// combineFlags merges flags from the command path into a single FlagSet. Flags are added in reverse
148+
// order (deepest command first) so that child flags take precedence over parent flags.
149+
func combineFlags(path []*Command) *flag.FlagSet {
150+
combined := flag.NewFlagSet(path[0].Name, flag.ContinueOnError)
151+
combined.SetOutput(io.Discard)
152+
for i := len(path) - 1; i >= 0; i-- {
153+
cmd := path[i]
120154
if cmd.Flags != nil {
121155
cmd.Flags.VisitAll(func(f *flag.Flag) {
122-
if combinedFlags.Lookup(f.Name) == nil {
123-
combinedFlags.Var(f.Value, f.Name, f.Usage)
156+
if combined.Lookup(f.Name) == nil {
157+
combined.Var(f.Value, f.Name, f.Usage)
124158
}
125159
})
126160
}
127161
}
128-
// Make sure to return help only after combining all flags, this way we get the full list of
129-
// flags in the help message!
130-
if hasHelp {
131-
return flag.ErrHelp
132-
}
162+
return combined
163+
}
133164

134-
// Let ParseToEnd handle the flag parsing
135-
if err := xflag.ParseToEnd(combinedFlags, argsToParse); err != nil {
136-
return fmt.Errorf("command %q: %w", getCommandPath(root.state.path), err)
137-
}
165+
// checkRequiredFlags verifies that all flags marked as required in FlagsMetadata were explicitly
166+
// set during parsing.
167+
func checkRequiredFlags(path []*Command, combined *flag.FlagSet) error {
168+
// Build a set of flags that were explicitly set during parsing. Visit (unlike VisitAll) only
169+
// iterates over flags that were actually provided by the user, regardless of their value.
170+
setFlags := make(map[string]struct{})
171+
combined.Visit(func(f *flag.Flag) {
172+
setFlags[f.Name] = struct{}{}
173+
})
138174

139-
// Check required flags
140175
var missingFlags []string
141-
for _, cmd := range commandChain {
142-
if len(cmd.FlagsMetadata) > 0 {
143-
for _, flagMetadata := range cmd.FlagsMetadata {
144-
if !flagMetadata.Required {
145-
continue
146-
}
147-
flag := combinedFlags.Lookup(flagMetadata.Name)
148-
if flag == nil {
149-
return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(root.state.path), formatFlagName(flagMetadata.Name))
150-
}
151-
if _, isBool := flag.Value.(interface{ IsBoolFlag() bool }); isBool {
152-
isSet := false
153-
for _, arg := range argsToParse {
154-
if strings.HasPrefix(arg, "-"+flagMetadata.Name) || strings.HasPrefix(arg, "--"+flagMetadata.Name) {
155-
isSet = true
156-
break
157-
}
158-
}
159-
if !isSet {
160-
missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name))
161-
}
162-
} else if flag.Value.String() == flag.DefValue {
163-
missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name))
164-
}
176+
for _, cmd := range path {
177+
for _, flagMetadata := range cmd.FlagsMetadata {
178+
if !flagMetadata.Required {
179+
continue
180+
}
181+
if combined.Lookup(flagMetadata.Name) == nil {
182+
return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(path), formatFlagName(flagMetadata.Name))
183+
}
184+
if _, ok := setFlags[flagMetadata.Name]; !ok {
185+
missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name))
165186
}
166187
}
167188
}
@@ -170,40 +191,36 @@ func Parse(root *Command, args []string) error {
170191
if len(missingFlags) > 1 {
171192
msg += "s"
172193
}
173-
return fmt.Errorf("command %q: %s %q not set", getCommandPath(root.state.path), msg, strings.Join(missingFlags, ", "))
194+
return fmt.Errorf("command %q: %s %q not set", getCommandPath(path), msg, strings.Join(missingFlags, ", "))
174195
}
196+
return nil
197+
}
175198

176-
// Skip past command names in remaining args
177-
parsed := combinedFlags.Args()
199+
// collectArgs strips resolved command names from the parsed positional args and appends any args
200+
// that appeared after the "--" delimiter.
201+
func collectArgs(path []*Command, parsed, remaining []string) []string {
202+
// Skip past command names in remaining args. Only strip the exact command names that were
203+
// resolved during traversal (path[1:], since root never appears in user args), in order and
204+
// only once each.
178205
startIdx := 0
179-
for _, arg := range parsed {
180-
isCommand := false
181-
for _, cmd := range commandChain {
182-
if arg == cmd.Name {
183-
startIdx++
184-
isCommand = true
185-
break
186-
}
187-
}
188-
if !isCommand {
206+
chainIdx := 1 // Skip root
207+
for startIdx < len(parsed) && chainIdx < len(path) {
208+
if strings.EqualFold(parsed[startIdx], path[chainIdx].Name) {
209+
startIdx++
210+
chainIdx++
211+
} else {
189212
break
190213
}
191214
}
192215

193-
// Combine remaining parsed args and everything after delimiter
194216
var finalArgs []string
195217
if startIdx < len(parsed) {
196218
finalArgs = append(finalArgs, parsed[startIdx:]...)
197219
}
198-
if len(remainingArgs) > 0 {
199-
finalArgs = append(finalArgs, remainingArgs...)
220+
if len(remaining) > 0 {
221+
finalArgs = append(finalArgs, remaining...)
200222
}
201-
root.state.Args = finalArgs
202-
203-
if current.Exec == nil {
204-
return fmt.Errorf("command %q: no exec function defined", getCommandPath(root.state.path))
205-
}
206-
return nil
223+
return finalArgs
207224
}
208225

209226
var validNameRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`)

0 commit comments

Comments
 (0)