Skip to content

Commit 9ad0964

Browse files
committed
Refactor parsing of embedded structs to types table conversion
1 parent 48a3857 commit 9ad0964

File tree

2 files changed

+53
-38
lines changed

2 files changed

+53
-38
lines changed

parser.go

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ type parser struct {
6464
position int
6565
current token
6666
strict bool
67-
types map[string]Type
67+
types typesTable
6868
}
6969

7070
// OptionFn for configuring parser.
@@ -81,7 +81,7 @@ func Parse(input string, ops ...OptionFn) (Node, error) {
8181
input: input,
8282
tokens: tokens,
8383
current: tokens[0],
84-
types: make(map[string]Type),
84+
types: make(typesTable),
8585
}
8686

8787
for _, op := range ops {
@@ -116,60 +116,69 @@ func Define(name string, t interface{}) OptionFn {
116116
}
117117

118118
// With sets variables for type checks during parsing.
119-
// If struct is passed, all fields will be treated as variables.
119+
// If struct is passed, all fields will be treated as variables,
120+
// as well as all fields of embedded structs.
121+
//
120122
// If map is passed, all items will be treated as variables
121123
// (key as name, value as type).
122124
func With(i interface{}) OptionFn {
123125
return func(p *parser) {
124126
p.strict = true
125-
v := reflect.ValueOf(i)
126-
t := reflect.TypeOf(i)
127-
t = dereference(t)
128-
if t == nil {
129-
return
127+
for k, v := range p.createTypesTable(i) {
128+
p.types[k] = v
130129
}
130+
}
131+
}
131132

132-
switch t.Kind() {
133-
case reflect.Struct:
134-
for i := 0; i < t.NumField(); i++ {
135-
f := t.Field(i)
136-
p.types[f.Name] = f.Type
133+
func (p *parser) createTypesTable(i interface{}) typesTable {
134+
types := make(typesTable)
135+
v := reflect.ValueOf(i)
136+
t := reflect.TypeOf(i)
137137

138-
for name, typ := range p.findEmbeddedFieldNames(f.Type) {
139-
p.types[name] = typ
140-
}
141-
}
142-
case reflect.Map:
143-
for _, key := range v.MapKeys() {
144-
value := v.MapIndex(key)
145-
if key.Kind() == reflect.String && value.IsValid() && value.CanInterface() {
146-
p.types[key.String()] = reflect.TypeOf(value.Interface())
147-
}
138+
t = dereference(t)
139+
if t == nil {
140+
return types
141+
}
142+
143+
switch t.Kind() {
144+
case reflect.Struct:
145+
types = p.fromStruct(t)
146+
147+
case reflect.Map:
148+
for _, key := range v.MapKeys() {
149+
value := v.MapIndex(key)
150+
if key.Kind() == reflect.String && value.IsValid() && value.CanInterface() {
151+
types[key.String()] = reflect.TypeOf(value.Interface())
148152
}
149153
}
150154
}
155+
156+
return types
151157
}
152158

153-
func (p *parser) findEmbeddedFieldNames(t reflect.Type) map[string]Type {
159+
func (p *parser) fromStruct(t reflect.Type) typesTable {
160+
types := make(typesTable)
154161
t = dereference(t)
162+
if t == nil {
163+
return types
164+
}
155165

156-
res := make(map[string]Type)
157-
if t.Kind() == reflect.Struct {
166+
switch t.Kind() {
167+
case reflect.Struct:
158168
for i := 0; i < t.NumField(); i++ {
159169
f := t.Field(i)
160170

161-
fType := dereference(f.Type)
162-
if fType.Kind() == reflect.Struct && f.Anonymous && fType.Name() == f.Name {
163-
for name, typ := range p.findEmbeddedFieldNames(fType) {
164-
res[name] = typ
171+
if f.Anonymous {
172+
for name, typ := range p.fromStruct(f.Type) {
173+
types[name] = typ
165174
}
175+
} else {
176+
types[f.Name] = f.Type
166177
}
167-
168-
res[f.Name] = fType
169178
}
170179
}
171180

172-
return res
181+
return types
173182
}
174183

175184
func (p *parser) errorf(format string, args ...interface{}) *syntaxError {

parser_test.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ func TestParse_error(t *testing.T) {
208208
}
209209
}
210210

211-
func TestParser_findEmbeddedFieldNames(t *testing.T) {
211+
func TestParser_createTypesTable(t *testing.T) {
212+
var intType = reflect.TypeOf(0)
213+
212214
type (
213215
D struct {
214216
F2 int
@@ -228,12 +230,16 @@ func TestParser_findEmbeddedFieldNames(t *testing.T) {
228230
}
229231
)
230232

231-
res := new(parser).findEmbeddedFieldNames(reflect.TypeOf(A{}))
232-
if res["F"] != reflect.TypeOf(1) {
233+
p := parser{}
234+
types := p.createTypesTable(A{})
235+
236+
if len(types) != 2 {
237+
t.Error("unexpected number of fields")
238+
}
239+
if types["F"] != intType {
233240
t.Error("expected embedded struct field 'F'")
234241
}
235-
236-
if res["F2"] != reflect.TypeOf(1) {
242+
if types["F2"] != intType {
237243
t.Error("expected embedded struct field 'F2'")
238244
}
239245
}

0 commit comments

Comments
 (0)