Skip to content
This repository was archived by the owner on Apr 19, 2024. It is now read-only.

Commit 8a89877

Browse files
Add WriteAnswer support for promoted fields (#366)
1 parent 3fff19a commit 8a89877

File tree

2 files changed

+193
-39
lines changed

2 files changed

+193
-39
lines changed

core/write.go

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ type OptionAnswer struct {
2424
Index int
2525
}
2626

27+
type reflectField struct {
28+
value reflect.Value
29+
fieldType reflect.StructField
30+
}
31+
2732
func OptionAnswerList(incoming []string) []OptionAnswer {
2833
list := []OptionAnswer{}
2934
for i, opt := range incoming {
@@ -63,13 +68,12 @@ func WriteAnswer(t interface{}, name string, v interface{}) (err error) {
6368
}
6469

6570
// get the name of the field that matches the string we were given
66-
fieldIndex, err := findFieldIndex(elem, name)
71+
field, _, err := findField(elem, name)
6772
// if something went wrong
6873
if err != nil {
6974
// bubble up
7075
return err
7176
}
72-
field := elem.Field(fieldIndex)
7377
// handle references to the Settable interface aswell
7478
if s, ok := field.Interface().(Settable); ok {
7579
// use the interface method
@@ -156,37 +160,51 @@ func IsFieldNotMatch(err error) (string, bool) {
156160

157161
// BUG(AlecAivazis): the current implementation might cause weird conflicts if there are
158162
// two fields with same name that only differ by casing.
159-
func findFieldIndex(s reflect.Value, name string) (int, error) {
160-
// the type of the value
161-
sType := s.Type()
163+
func findField(s reflect.Value, name string) (reflect.Value, reflect.StructField, error) {
162164

163-
// first look for matching tags so we can overwrite matching field names
164-
for i := 0; i < sType.NumField(); i++ {
165-
// the field we are current scanning
166-
field := sType.Field(i)
165+
fields := flattenFields(s)
167166

167+
// first look for matching tags so we can overwrite matching field names
168+
for _, f := range fields {
168169
// the value of the survey tag
169-
tag := field.Tag.Get(tagName)
170+
tag := f.fieldType.Tag.Get(tagName)
170171
// if the tag matches the name we are looking for
171172
if tag != "" && tag == name {
172173
// then we found our index
173-
return i, nil
174+
return f.value, f.fieldType, nil
174175
}
175176
}
176177

177178
// then look for matching names
178-
for i := 0; i < sType.NumField(); i++ {
179-
// the field we are current scanning
180-
field := sType.Field(i)
181-
179+
for _, f := range fields {
182180
// if the name of the field matches what we're looking for
183-
if strings.ToLower(field.Name) == strings.ToLower(name) {
184-
return i, nil
181+
if strings.ToLower(f.fieldType.Name) == strings.ToLower(name) {
182+
return f.value, f.fieldType, nil
185183
}
186184
}
187185

188186
// we didn't find the field
189-
return -1, errFieldNotMatch{name}
187+
return reflect.Value{}, reflect.StructField{}, errFieldNotMatch{name}
188+
}
189+
190+
func flattenFields(s reflect.Value) []reflectField {
191+
sType := s.Type()
192+
numField := sType.NumField()
193+
fields := make([]reflectField, 0, numField)
194+
for i := 0; i < numField; i++ {
195+
fieldType := sType.Field(i)
196+
field := s.Field(i)
197+
198+
if field.Kind() == reflect.Struct && fieldType.Anonymous {
199+
// field is a promoted structure
200+
for _, f := range flattenFields(field) {
201+
fields = append(fields, f)
202+
}
203+
continue
204+
}
205+
fields = append(fields, reflectField{field, fieldType})
206+
}
207+
return fields
190208
}
191209

192210
// isList returns true if the element is something we can Len()

core/write_test.go

Lines changed: 157 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,12 @@ func TestWriteAnswer_returnsErrWhenFieldNotFound(t *testing.T) {
305305
}
306306
}
307307

308-
func TestFindFieldIndex_canFindExportedField(t *testing.T) {
308+
func TestFindField_canFindExportedField(t *testing.T) {
309309
// create a reflective wrapper over the struct to look through
310-
val := reflect.ValueOf(struct{ Name string }{})
310+
val := reflect.ValueOf(struct{ Name string }{Name: "Jack"})
311311

312312
// find the field matching "name"
313-
fieldIndex, err := findFieldIndex(val, "name")
313+
field, fieldType, err := findField(val, "name")
314314
// if something went wrong
315315
if err != nil {
316316
// the test failed
@@ -319,20 +319,28 @@ func TestFindFieldIndex_canFindExportedField(t *testing.T) {
319319
}
320320

321321
// make sure we got the right value
322-
if val.Type().Field(fieldIndex).Name != "Name" {
322+
if field.Interface() != "Jack" {
323323
// the test failed
324-
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", val.Type().Field(fieldIndex).Name)
324+
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
325+
}
326+
327+
// make sure we got the right field type
328+
if fieldType.Name != "Name" {
329+
// the test failed
330+
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
325331
}
326332
}
327333

328-
func TestFindFieldIndex_canFindTaggedField(t *testing.T) {
334+
func TestFindField_canFindTaggedField(t *testing.T) {
329335
// the struct to look through
330336
val := reflect.ValueOf(struct {
331337
Username string `survey:"name"`
332-
}{})
338+
}{
339+
Username: "Jack",
340+
})
333341

334342
// find the field matching "name"
335-
fieldIndex, err := findFieldIndex(val, "name")
343+
field, fieldType, err := findField(val, "name")
336344
// if something went wrong
337345
if err != nil {
338346
// the test failed
@@ -341,52 +349,180 @@ func TestFindFieldIndex_canFindTaggedField(t *testing.T) {
341349
}
342350

343351
// make sure we got the right value
344-
if val.Type().Field(fieldIndex).Name != "Username" {
352+
if field.Interface() != "Jack" {
345353
// the test failed
346-
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", val.Type().Field(fieldIndex).Name)
354+
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
355+
}
356+
357+
// make sure we got the right fieldType
358+
if fieldType.Name != "Username" {
359+
// the test failed
360+
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
347361
}
348362
}
349363

350-
func TestFindFieldIndex_canHandleCapitalAnswerNames(t *testing.T) {
364+
func TestFindField_canHandleCapitalAnswerNames(t *testing.T) {
351365
// create a reflective wrapper over the struct to look through
352-
val := reflect.ValueOf(struct{ Name string }{})
366+
val := reflect.ValueOf(struct{ Name string }{Name: "Jack"})
353367

354368
// find the field matching "name"
355-
fieldIndex, err := findFieldIndex(val, "Name")
369+
field, fieldType, err := findField(val, "Name")
356370
// if something went wrong
357371
if err != nil {
358372
// the test failed
359373
t.Error(err.Error())
360374
return
361375
}
362-
363376
// make sure we got the right value
364-
if val.Type().Field(fieldIndex).Name != "Name" {
377+
if field.Interface() != "Jack" {
378+
// the test failed
379+
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
380+
}
381+
382+
// make sure we got the right fieldType
383+
if fieldType.Name != "Name" {
365384
// the test failed
366-
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", val.Type().Field(fieldIndex).Name)
385+
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
367386
}
368387
}
369388

370-
func TestFindFieldIndex_tagOverwriteFieldName(t *testing.T) {
389+
func TestFindField_tagOverwriteFieldName(t *testing.T) {
371390
// the struct to look through
372391
val := reflect.ValueOf(struct {
373392
Name string
374393
Username string `survey:"name"`
375-
}{})
394+
}{
395+
Name: "Ralf",
396+
Username: "Jack",
397+
})
398+
399+
// find the field matching "name"
400+
field, fieldType, err := findField(val, "name")
401+
// if something went wrong
402+
if err != nil {
403+
// the test failed
404+
t.Error(err.Error())
405+
return
406+
}
407+
408+
// make sure we got the right value
409+
if field.Interface() != "Jack" {
410+
// the test failed
411+
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
412+
}
413+
414+
// make sure we got the right fieldType
415+
if fieldType.Name != "Username" {
416+
// the test failed
417+
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
418+
}
419+
}
420+
421+
func TestFindField_supportsPromotedFields(t *testing.T) {
422+
// create a reflective wrapper over the struct to look through
423+
type Common struct {
424+
Name string
425+
}
426+
427+
type Strct struct {
428+
Common // Name field added by composition
429+
Username string
430+
}
431+
432+
val := reflect.ValueOf(Strct{Common: Common{Name: "Jack"}})
433+
434+
// find the field matching "name"
435+
field, fieldType, err := findField(val, "Name")
436+
// if something went wrong
437+
if err != nil {
438+
// the test failed
439+
t.Error(err.Error())
440+
return
441+
}
442+
// make sure we got the right value
443+
if field.Interface() != "Jack" {
444+
// the test failed
445+
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
446+
}
447+
448+
// make sure we got the right fieldType
449+
if fieldType.Name != "Name" {
450+
// the test failed
451+
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
452+
}
453+
}
454+
455+
func TestFindField_promotedFieldsWithTag(t *testing.T) {
456+
// create a reflective wrapper over the struct to look through
457+
type Common struct {
458+
Username string `survey:"name"`
459+
}
460+
461+
type Strct struct {
462+
Common // Name field added by composition
463+
Name string
464+
}
465+
466+
val := reflect.ValueOf(Strct{
467+
Common: Common{Username: "Jack"},
468+
Name: "Ralf",
469+
})
376470

377471
// find the field matching "name"
378-
fieldIndex, err := findFieldIndex(val, "name")
472+
field, fieldType, err := findField(val, "name")
379473
// if something went wrong
380474
if err != nil {
381475
// the test failed
382476
t.Error(err.Error())
383477
return
384478
}
479+
// make sure we got the right value
480+
if field.Interface() != "Jack" {
481+
// the test failed
482+
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
483+
}
385484

485+
// make sure we got the right fieldType
486+
if fieldType.Name != "Username" {
487+
// the test failed
488+
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
489+
}
490+
}
491+
492+
func TestFindField_promotedFieldsDontHavePriorityOverTags(t *testing.T) {
493+
// create a reflective wrapper over the struct to look through
494+
type Common struct {
495+
Name string
496+
}
497+
498+
type Strct struct {
499+
Common // Name field added by composition
500+
Username string `survey:"name"`
501+
}
502+
503+
val := reflect.ValueOf(Strct{
504+
Common: Common{Name: "Ralf"},
505+
Username: "Jack",
506+
})
507+
508+
// find the field matching "name"
509+
field, fieldType, err := findField(val, "name")
510+
// if something went wrong
511+
if err != nil {
512+
// the test failed
513+
t.Error(err.Error())
514+
return
515+
}
386516
// make sure we got the right value
387-
if val.Type().Field(fieldIndex).Name != "Username" {
517+
if field.Interface() != "Jack" {
518+
// the test failed
519+
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
520+
}
521+
522+
// make sure we got the right fieldType
523+
if fieldType.Name != "Username" {
388524
// the test failed
389-
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", val.Type().Field(fieldIndex).Name)
525+
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
390526
}
391527
}
392528

0 commit comments

Comments
 (0)