Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@ import (
// After calling ResetForTesting, parse errors in flag handling will not
// exit the program.
func ResetForTesting(usage func()) {
CommandLine = &FlagSet{
name: os.Args[0],
errorHandling: ContinueOnError,
output: ioutil.Discard,
}
CommandLine = NewFlagSet(os.Args[0], ContinueOnError)
CommandLine.output = ioutil.Discard
Usage = usage
}

Expand Down
236 changes: 111 additions & 125 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,15 @@ type FlagSet struct {
output io.Writer // nil means stderr; use Output() accessor
interspersed bool // allow interspersed option/non-option args
normalizeNameFunc func(f *FlagSet, name string) NormalizedName
unknownFlags *[]string

addedGoFlagSets []*goflag.FlagSet
unknownFlags []*UnknownFlag
}

// A UnknownFlag represents the state of a flag that is not expected.
type UnknownFlag struct {
Name string // name, as it appears on command line
Value string // argument, if provided
}

// A Flag represents the state of a flag.
Expand Down Expand Up @@ -193,20 +199,18 @@ func (f *FlagSet) VisitAll(fn func(*Flag)) {
if len(f.formal) == 0 {
return
}

var flags []*Flag
if f.SortFlags {
if len(f.formal) != len(f.sortedFormal) {
f.sortedFormal = sortFlags(f.formal)
}
flags = f.sortedFormal
} else {
flags = f.orderedFormal
for _, flag := range f.GetAllFlags() {
fn(flag)
}
}

for _, flag := range flags {
fn(flag)
// GetAllFlags return the flags in lexicographical order or
// in primordial order if f.SortFlags is false.
func (f *FlagSet) GetAllFlags() []*Flag {
if f.SortFlags && len(f.formal) != len(f.sortedFormal) {
f.sortedFormal = sortFlags(f.formal)
}
return f.sortedFormal
}

// HasFlags returns a bool to indicate if the FlagSet has any flags defined.
Expand All @@ -232,27 +236,31 @@ func VisitAll(fn func(*Flag)) {
CommandLine.VisitAll(fn)
}

// GetAllFlags return the flags in lexicographical order or
// in primordial order if f.SortFlags is false.
func GetAllFlags() []*Flag {
return CommandLine.GetAllFlags()
}

// Visit visits the flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits only those flags that have been set.
func (f *FlagSet) Visit(fn func(*Flag)) {
if len(f.actual) == 0 {
return
}

var flags []*Flag
if f.SortFlags {
if len(f.actual) != len(f.sortedActual) {
f.sortedActual = sortFlags(f.actual)
}
flags = f.sortedActual
} else {
flags = f.orderedActual
for _, flag := range f.GetFlags() {
fn(flag)
}
}

for _, flag := range flags {
fn(flag)
// GetFlags return the flags in lexicographical order or
// in primordial order if f.SortFlags is false.
func (f *FlagSet) GetFlags() []*Flag {
if f.SortFlags && len(f.actual) != len(f.sortedActual) {
f.sortedActual = sortFlags(f.actual)
}
return f.sortedActual
}

// Visit visits the command-line flags in lexicographical order or
Expand All @@ -262,6 +270,45 @@ func Visit(fn func(*Flag)) {
CommandLine.Visit(fn)
}

// GetFlags return the flags in lexicographical order or
// in primordial order if f.SortFlags is false.
func GetFlags() []*Flag {
return CommandLine.GetFlags()
}

// VisitUnknowns visits all the flags that have not been registered.
func (f *FlagSet) VisitUnknowns(fn func(*UnknownFlag)) {
if len(f.unknownFlags) == 0 {
return
}
for _, flag := range f.unknownFlags {
fn(flag)
}
}

// GetUnknownFlags returns unknown flags found during Parse.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func (f *FlagSet) GetUnknownFlags() []*UnknownFlag {
return f.unknownFlags
}

func (f *FlagSet) addUnknownFlag(name, value string) {
f.unknownFlags = append(f.unknownFlags, &UnknownFlag{name, value})
}

// VisitUnknowns visits all the flags that have not been registered.
func VisitUnknowns(fn func(*UnknownFlag)) {
CommandLine.VisitUnknowns(fn)
}

// GetUnknownFlags returns unknown flags found during Parse.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func GetUnknownFlags() []*UnknownFlag {
return CommandLine.GetUnknownFlags()
}

// Lookup returns the Flag structure of the named flag, returning nil if none exists.
func (f *FlagSet) Lookup(name string) *Flag {
return f.lookup(f.normalizeFlagName(name))
Expand Down Expand Up @@ -883,36 +930,6 @@ func (f *FlagSet) usage() {
}
}

func (f *FlagSet) addUnknownFlag(s string) {
if f.unknownFlags == nil {
f.unknownFlags = new([]string)
}
*f.unknownFlags = append(*f.unknownFlags, s)
}

//--unknown (args will be empty)
//--unknown --next-flag ... (args will be --next-flag ...)
//--unknown arg ... (args will be arg ...)
func (f *FlagSet) stripUnknownFlagValue(args []string) []string {
if len(args) == 0 {
//--unknown
return args
}

first := args[0]
if len(first) > 0 && first[0] == '-' {
//--unknown --next-flag ...
return args
}

//--unknown arg ... (args will be arg ...)
if len(args) > 1 {
f.addUnknownFlag(args[0])
return args[1:]
}
return nil
}

func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) {
a = args
name := s[2:]
Expand All @@ -926,21 +943,12 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
flag, exists := f.formal[f.normalizeFlagName(name)]

if !exists || flag.ShorthandOnly {
switch {
case !f.DisableBuiltinHelp && name == "help":
if !f.DisableBuiltinHelp && name == "help" {
f.usage()
err = ErrHelp
return
case f.ParseErrorsWhitelist.UnknownFlags:
f.addUnknownFlag(s)
// --unknown=unknownval arg ...
// we do not want to lose arg in this case
if len(split) >= 2 {
return a, nil
}

return f.stripUnknownFlagValue(a), nil
default:
}
if !f.ParseErrorsWhitelist.UnknownFlags {
err = f.failf("unknown flag: --%s", name)
return
}
Expand All @@ -950,16 +958,23 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
if len(split) == 2 {
// '--flag=arg'
value = split[1]
} else if flag.NoOptDefVal != "" {
} else if exists && flag.NoOptDefVal != "" {
// '--flag' (arg was optional)
value = flag.NoOptDefVal
} else if len(a) > 0 {
// '--flag arg'
value = a[0]
a = a[1:]
} else {
// '--flag' (arg was required)
err = f.failf("flag needs an argument: %s", s)
if !exists && strings.HasPrefix(a[0], "-") {
value = ""
} else {
value = a[0]
a = a[1:]
}
} else if f.ParseErrorsWhitelist.UnknownFlags {
value = ""
}

if !exists && f.ParseErrorsWhitelist.UnknownFlags {
f.addUnknownFlag(name, value)
return
}

Expand All @@ -978,26 +993,12 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse

flag, exists := f.shorthands[c]
if !exists {
switch {
case !f.DisableBuiltinHelp && c == 'h':
if !f.DisableBuiltinHelp && c == 'h' {
f.usage()
err = ErrHelp
return
case f.ParseErrorsWhitelist.UnknownFlags:
// '-f=arg arg ...'
// we do not want to lose arg in this case
if len(shorthands) > 2 && shorthands[1] == '=' {
f.addUnknownFlag("-" + shorthands)
outShorts = ""
return
}

f.addUnknownFlag("-" + string(c))
if len(outShorts) == 0 {
outArgs = f.stripUnknownFlagValue(outArgs)
}
return
default:
}
if !f.ParseErrorsWhitelist.UnknownFlags {
err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
return
}
Expand All @@ -1008,18 +1009,33 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
// '-f=arg'
value = shorthands[2:]
outShorts = ""
} else if flag.NoOptDefVal != "" {
} else if exists && flag.NoOptDefVal != "" {
// '-f' (arg was optional)
value = flag.NoOptDefVal
} else if len(shorthands) > 1 {
// '-farg'
value = shorthands[1:]
outShorts = ""
if next := f.ShorthandLookup(string(shorthands[1])); next == nil {
// preserve arg if it's a known flag
value = shorthands[1:]
outShorts = ""
}
} else if len(args) > 0 {
// '-f arg'
value = args[0]
outArgs = args[1:]
} else {
if !exists && strings.HasPrefix(args[0], "-") {
value = ""
} else {
value = args[0]
outArgs = args[1:]
}
} else if f.ParseErrorsWhitelist.UnknownFlags {
value = ""
}

if !exists && f.ParseErrorsWhitelist.UnknownFlags {
f.addUnknownFlag(string(c), value)
return
}

if flag.NoOptDefVal == "" && value == "" {
// '-f' (arg was required)
err = f.failf("flag needs an argument: %q in -%s", c, shorthands)
return
Expand Down Expand Up @@ -1150,21 +1166,6 @@ func (f *FlagSet) Parsed() bool {
return f.parsed
}

// SetUnknownFlags sets the store for unknown flags found during Parse.
// The argument s points to a slice variable in which to store the values.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func (f *FlagSet) SetUnknownFlags(s *[]string) {
f.unknownFlags = s
}

// GetUnknownFlags returns unknown flags found during Parse.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func (f *FlagSet) GetUnknownFlags() *[]string {
return f.unknownFlags
}

// Parse parses the command-line flags from os.Args[1:]. Must be called
// after all flags are defined and before flags are accessed by the program.
func Parse() {
Expand All @@ -1190,21 +1191,6 @@ func Parsed() bool {
return CommandLine.Parsed()
}

// SetUnknownFlags sets the store for unknown flags found during Parse.
// The argument s points to a slice variable in which to store the values.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func SetUnknownFlags(s *[]string) {
CommandLine.SetUnknownFlags(s)
}

// GetUnknownFlags returns unknown flags found during Parse.
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
// parsing does not abort on the first unknown flag.
func GetUnknownFlags() *[]string {
return CommandLine.GetUnknownFlags()
}

// CommandLine is the default set of command-line flags, parsed from os.Args.
var CommandLine = NewFlagSet(os.Args[0], ExitOnError)

Expand Down
Loading