Skip to content

Commit c6c5804

Browse files
authored
Merge pull request urfave#2108 from bystones/fix_2098
use correct context in After function with subcommand
2 parents 65c6366 + a5cfa4f commit c6c5804

File tree

2 files changed

+157
-15
lines changed

2 files changed

+157
-15
lines changed

command_run.go

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ outer:
9191
// arguments are parsed according to the Flag and Command
9292
// definitions and the matching Action functions are run.
9393
func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
94+
_, deferErr = cmd.run(ctx, osArgs)
95+
return
96+
}
97+
98+
func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context, deferErr error) {
9499
tracef("running with arguments %[1]q (cmd=%[2]q)", osArgs, cmd.Name)
95100
cmd.setupDefaults(osArgs)
96101

@@ -102,7 +107,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
102107
if cmd.parent == nil {
103108
if cmd.ReadArgsFromStdin {
104109
if args, err := cmd.parseArgsFromStdin(); err != nil {
105-
return err
110+
return ctx, err
106111
} else {
107112
osArgs = append(osArgs, args...)
108113
}
@@ -132,7 +137,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
132137
var rargs Args = &stringSliceArgs{v: osArgs}
133138
for _, f := range cmd.allFlags() {
134139
if err := f.PreParse(); err != nil {
135-
return err
140+
return ctx, err
136141
}
137142
}
138143

@@ -149,7 +154,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
149154
tracef("using post-parse arguments %[1]q (cmd=%[2]q)", args, cmd.Name)
150155

151156
if checkCompletions(ctx, cmd) {
152-
return nil
157+
return ctx, nil
153158
}
154159

155160
if err != nil {
@@ -160,7 +165,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
160165
if cmd.OnUsageError != nil {
161166
err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil)
162167
err = cmd.handleExitCoder(ctx, err)
163-
return err
168+
return ctx, err
164169
}
165170
fmt.Fprintf(cmd.Root().ErrWriter, "Incorrect Usage: %s\n\n", err.Error())
166171
if cmd.Suggest {
@@ -182,23 +187,23 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
182187
}
183188
}
184189

185-
return err
190+
return ctx, err
186191
}
187192

188193
if cmd.checkHelp() {
189-
return helpCommandAction(ctx, cmd)
194+
return ctx, helpCommandAction(ctx, cmd)
190195
} else {
191196
tracef("no help is wanted (cmd=%[1]q)", cmd.Name)
192197
}
193198

194199
if cmd.parent == nil && !cmd.HideVersion && checkVersion(cmd) {
195200
ShowVersion(cmd)
196-
return nil
201+
return ctx, nil
197202
}
198203

199204
for _, flag := range cmd.allFlags() {
200205
if err := flag.PostParse(); err != nil {
201-
return err
206+
return ctx, err
202207
}
203208
}
204209

@@ -219,7 +224,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
219224
for _, grp := range cmd.MutuallyExclusiveFlags {
220225
if err := grp.check(cmd); err != nil {
221226
_ = ShowSubcommandHelp(cmd)
222-
return err
227+
return ctx, err
223228
}
224229
}
225230

@@ -262,7 +267,12 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
262267
// If a subcommand has been resolved, let it handle the remaining execution.
263268
if subCmd != nil {
264269
tracef("running sub-command %[1]q with arguments %[2]q (cmd=%[3]q)", subCmd.Name, cmd.Args(), cmd.Name)
265-
return subCmd.Run(ctx, cmd.Args().Slice())
270+
271+
// It is important that we overwrite the ctx variable in the current
272+
// function so any defer'd functions use the new context returned
273+
// from the sub command.
274+
ctx, err = subCmd.run(ctx, cmd.Args().Slice())
275+
return ctx, err
266276
}
267277

268278
// This code path is the innermost command execution. Here we actually
@@ -282,7 +292,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
282292
}
283293
if bctx, err := cmd.Before(ctx, cmd); err != nil {
284294
deferErr = cmd.handleExitCoder(ctx, err)
285-
return deferErr
295+
return ctx, deferErr
286296
} else if bctx != nil {
287297
ctx = bctx
288298
}
@@ -294,14 +304,14 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
294304
tracef("running flag actions (cmd=%[1]q)", cmd.Name)
295305
if err := cmd.runFlagActions(ctx); err != nil {
296306
deferErr = cmd.handleExitCoder(ctx, err)
297-
return deferErr
307+
return ctx, deferErr
298308
}
299309
}
300310

301311
if err := cmd.checkAllRequiredFlags(); err != nil {
302312
cmd.isInError = true
303313
_ = ShowSubcommandHelp(cmd)
304-
return err
314+
return ctx, err
305315
}
306316

307317
// Run the command action.
@@ -317,7 +327,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
317327
err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil)
318328
}
319329
err = cmd.handleExitCoder(ctx, err)
320-
return err
330+
return ctx, err
321331
}
322332
}
323333
cmd.parsedArgs = &stringSliceArgs{v: rargs}
@@ -329,5 +339,5 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
329339
}
330340

331341
tracef("returning deferErr (cmd=%[1]q) %[2]q", cmd.Name, deferErr)
332-
return deferErr
342+
return ctx, deferErr
333343
}

command_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,138 @@ func TestCommand_Run_BeforeReturnNewContext(t *testing.T) {
351351
require.Equal(t, "bval", receivedValFromAction)
352352
}
353353

354+
type ctxKey string
355+
356+
// ctxCollector is a small helper to collect context values.
357+
type ctxCollector struct {
358+
// keys are the keys to check the context for.
359+
keys []ctxKey
360+
361+
// m maps from function name to context name to value.
362+
m map[string]map[ctxKey]string
363+
}
364+
365+
func (cc *ctxCollector) collect(ctx context.Context, fnName string) {
366+
if cc.m == nil {
367+
cc.m = make(map[string]map[ctxKey]string)
368+
}
369+
370+
if _, ok := cc.m[fnName]; !ok {
371+
cc.m[fnName] = make(map[ctxKey]string)
372+
}
373+
374+
for _, k := range cc.keys {
375+
if val := ctx.Value(k); val != nil {
376+
cc.m[fnName][k] = val.(string)
377+
}
378+
}
379+
}
380+
381+
func TestCommand_Run_BeforeReturnNewContextSubcommand(t *testing.T) {
382+
bkey := ctxKey("bkey")
383+
bkey2 := ctxKey("bkey2")
384+
385+
cc := &ctxCollector{keys: []ctxKey{bkey, bkey2}}
386+
cmd := &Command{
387+
Name: "bar",
388+
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
389+
return context.WithValue(ctx, bkey, "bval"), nil
390+
},
391+
After: func(ctx context.Context, cmd *Command) error {
392+
cc.collect(ctx, "bar.After")
393+
return nil
394+
},
395+
Commands: []*Command{
396+
{
397+
Name: "baz",
398+
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
399+
return context.WithValue(ctx, bkey2, "bval2"), nil
400+
},
401+
Action: func(ctx context.Context, cmd *Command) error {
402+
cc.collect(ctx, "baz.Action")
403+
return nil
404+
},
405+
After: func(ctx context.Context, cmd *Command) error {
406+
cc.collect(ctx, "baz.After")
407+
return nil
408+
},
409+
},
410+
},
411+
}
412+
413+
require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "baz"}))
414+
expected := map[string]map[ctxKey]string{
415+
"bar.After": {
416+
bkey: "bval",
417+
bkey2: "bval2",
418+
},
419+
"baz.Action": {
420+
bkey: "bval",
421+
bkey2: "bval2",
422+
},
423+
"baz.After": {
424+
bkey: "bval",
425+
bkey2: "bval2",
426+
},
427+
}
428+
require.Equal(t, expected, cc.m)
429+
}
430+
431+
func TestCommand_Run_FlagActionContext(t *testing.T) {
432+
bkey := ctxKey("bkey")
433+
bkey2 := ctxKey("bkey2")
434+
435+
cc := &ctxCollector{keys: []ctxKey{bkey, bkey2}}
436+
cmd := &Command{
437+
Name: "bar",
438+
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
439+
return context.WithValue(ctx, bkey, "bval"), nil
440+
},
441+
Flags: []Flag{
442+
&StringFlag{
443+
Name: "foo",
444+
Action: func(ctx context.Context, cmd *Command, _ string) error {
445+
cc.collect(ctx, "bar.foo.Action")
446+
return nil
447+
},
448+
},
449+
},
450+
Commands: []*Command{
451+
{
452+
Name: "baz",
453+
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
454+
return context.WithValue(ctx, bkey2, "bval2"), nil
455+
},
456+
Flags: []Flag{
457+
&StringFlag{
458+
Name: "goo",
459+
Action: func(ctx context.Context, cmd *Command, _ string) error {
460+
cc.collect(ctx, "baz.goo.Action")
461+
return nil
462+
},
463+
},
464+
},
465+
Action: func(ctx context.Context, cmd *Command) error {
466+
return nil
467+
},
468+
},
469+
},
470+
}
471+
472+
require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "--foo", "value", "baz", "--goo", "value"}))
473+
expected := map[string]map[ctxKey]string{
474+
"bar.foo.Action": {
475+
bkey: "bval",
476+
bkey2: "bval2",
477+
},
478+
"baz.goo.Action": {
479+
bkey: "bval",
480+
bkey2: "bval2",
481+
},
482+
}
483+
require.Equal(t, expected, cc.m)
484+
}
485+
354486
func TestCommand_OnUsageError_hasCommandContext(t *testing.T) {
355487
cmd := &Command{
356488
Name: "bar",

0 commit comments

Comments
 (0)