Skip to content

Commit 6876699

Browse files
committed
Provide option to flatten structs
1 parent 786c076 commit 6876699

File tree

3 files changed

+86
-10
lines changed

3 files changed

+86
-10
lines changed

docs.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,15 @@ the field should be ignored and no flag is declared. For example,
174174
Host string `flag:"server_address"
175175
GetsIgnored string `flag:""`
176176
177+
To help with organization, struct fields can be flattened such that the resolved flag name does not include the name of the struct itself. For example, this struct will accept the flags named simply `-host` and `-port`.
178+
179+
type struct Config {
180+
NetworkConfig struct {
181+
Host string
182+
Port int
183+
}
184+
}
185+
177186
Environment variable naming and processing can be overridden with the `env:"name"` tag, where
178187
the given name will be used exactly as the mapped environment variable name. If the WithEnv
179188
or WithEnvRenamer options were enabled, a field can be excluded from environment variable

flagset.go

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ var (
1818
stringToStringMapType = reflect.TypeOf(map[string]string{})
1919
)
2020

21+
const (
22+
TagDefault = "default"
23+
TagEnv = "env"
24+
TagFlag = "flag"
25+
TagFlatten = "flatten"
26+
TagOverrideValue = "override-value"
27+
TagType = "type"
28+
)
29+
2130
// FlagSetFiller is used to map the fields of a struct into flags of a flag.FlagSet
2231
type FlagSetFiller struct {
2332
options *fillerOptions
@@ -57,6 +66,8 @@ func (f *FlagSetFiller) Fill(flagSet *flag.FlagSet, from interface{}) error {
5766
}
5867
}
5968

69+
// isSupportedStruct checks if the given field reference is a registered extended type or implements
70+
// encoding.TextUnmarshaler
6071
func isSupportedStruct(in any) bool {
6172
t := reflect.TypeOf(in)
6273
_, ok := extendedTypes[getTypeName(t)]
@@ -106,15 +117,14 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string,
106117
field := structType.Field(i)
107118
fieldValue := structVal.Field(i)
108119

109-
if flagTag, ok := field.Tag.Lookup("flag"); ok {
120+
if flagTag, ok := field.Tag.Lookup(TagFlag); ok {
110121
if flagTag == "" {
111122
continue
112123
}
113124
}
114125

115126
switch field.Type.Kind() {
116127
case reflect.Struct:
117-
// fieldTypeName := getTypeName(field.Type)
118128
if field.IsExported() {
119129
if isSupportedStruct(fieldValue.Addr().Interface()) {
120130
err := handleDefault(field, fieldValue)
@@ -124,7 +134,11 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string,
124134
continue
125135
}
126136
}
127-
err := f.walkFields(flagSet, prefix+field.Name, fieldValue, field.Type)
137+
138+
err := f.walkFields(flagSet,
139+
qualifiedNameForStructField(field, prefix),
140+
fieldValue,
141+
field.Type)
128142
if err != nil {
129143
return fmt.Errorf("failed to process %s of %s: %w", field.Name, structType.String(), err)
130144
}
@@ -146,7 +160,10 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string,
146160
}
147161
}
148162

149-
err := f.walkFields(flagSet, field.Name, fieldValue.Elem(), field.Type.Elem())
163+
err := f.walkFields(flagSet,
164+
qualifiedNameForStructField(field, prefix),
165+
fieldValue.Elem(),
166+
field.Type.Elem())
150167
if err != nil {
151168
return fmt.Errorf("failed to process %s of %s: %w", field.Name, structType.String(), err)
152169
}
@@ -163,11 +180,27 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string,
163180
return nil
164181
}
165182

183+
func qualifiedNameForStructField(field reflect.StructField, prefix string) string {
184+
if !shouldFlatten(field) {
185+
return prefix + field.Name
186+
} else {
187+
return prefix
188+
}
189+
}
190+
191+
func shouldFlatten(field reflect.StructField) bool {
192+
value, ok := field.Tag.Lookup(TagFlatten)
193+
if !ok {
194+
return false
195+
}
196+
return value == "true"
197+
}
198+
166199
func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{},
167200
name string, t reflect.Type, tag reflect.StructTag) (err error) {
168201

169202
var envName string
170-
if override, exists := tag.Lookup("env"); exists {
203+
if override, exists := tag.Lookup(TagEnv); exists {
171204
envName = override
172205
} else if len(f.options.envRenamer) > 0 {
173206
envName = name
@@ -182,12 +215,12 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}
182215
usage = fmt.Sprintf("%s (env %s)", usage, envName)
183216
}
184217

185-
tagDefault, hasDefaultTag := tag.Lookup("default")
218+
tagDefault, hasDefaultTag := tag.Lookup(TagDefault)
186219

187-
fieldType, _ := tag.Lookup("type")
220+
fieldType, _ := tag.Lookup(TagType)
188221

189222
var renamed string
190-
if override, exists := tag.Lookup("flag"); exists {
223+
if override, exists := tag.Lookup(TagFlag); exists {
191224
if override == "" {
192225
// empty flag override signal to skip this field
193226
return nil
@@ -231,7 +264,7 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}
231264

232265
case t == stringSliceType, fieldType == "stringSlice":
233266
var override bool
234-
if overrideValue, exists := tag.Lookup("override-value"); exists {
267+
if overrideValue, exists := tag.Lookup(TagOverrideValue); exists {
235268
if value, err := strconv.ParseBool(overrideValue); err == nil {
236269
override = value
237270
}

flagset_test.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ func TestNestedStructPtr(t *testing.T) {
278278
type Config struct {
279279
Host string
280280
SomeGrouping *Nested
281+
Inner struct {
282+
Deeper *Nested
283+
}
281284
}
282285

283286
var config Config
@@ -288,11 +291,15 @@ func TestNestedStructPtr(t *testing.T) {
288291
err := filler.Fill(&flagset, &config)
289292
require.NoError(t, err)
290293

291-
err = flagset.Parse([]string{"--host", "h1", "--some-grouping-some-field", "val1"})
294+
err = flagset.Parse([]string{"--host", "h1",
295+
"--some-grouping-some-field", "val1",
296+
"--inner-deeper-some-field", "val2"})
297+
require.NoError(t, err)
292298
require.NoError(t, err)
293299

294300
assert.Equal(t, "h1", config.Host)
295301
assert.Equal(t, "val1", config.SomeGrouping.SomeField)
302+
assert.Equal(t, "val2", config.Inner.Deeper.SomeField)
296303
}
297304

298305
func TestNestedUnexportedStructPtr(t *testing.T) {
@@ -880,6 +887,33 @@ func TestFlagNameOverride(t *testing.T) {
880887

881888
}
882889

890+
func TestFlatten(t *testing.T) {
891+
type Config struct {
892+
Flattened struct {
893+
FlattenedField string
894+
} `flatten:"true"`
895+
PtrFlattened *struct {
896+
PtrFlattenedField string
897+
} `flatten:"true"`
898+
}
899+
900+
var config Config
901+
902+
filler := flagsfiller.New()
903+
904+
var flagset flag.FlagSet
905+
err := filler.Fill(&flagset, &config)
906+
require.NoError(t, err)
907+
908+
err = flagset.Parse([]string{"--flattened-field", "val1",
909+
"--ptr-flattened-field", "val2"})
910+
require.NoError(t, err)
911+
require.NoError(t, err)
912+
913+
assert.Equal(t, "val1", config.Flattened.FlattenedField)
914+
assert.Equal(t, "val2", config.PtrFlattened.PtrFlattenedField)
915+
}
916+
883917
type flagSet interface {
884918
SetOutput(io.Writer)
885919
PrintDefaults()

0 commit comments

Comments
 (0)