Skip to content

Commit a10efb5

Browse files
authored
feat: Add option to unwrap embedded structs 1 level down (#111)
1 parent 6bdaf68 commit a10efb5

File tree

3 files changed

+226
-94
lines changed

3 files changed

+226
-94
lines changed

codegen/golang.go

Lines changed: 86 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,20 @@ func WithExtraColumns(columns []ColumnDefinition) TableOptions {
7272
}
7373
}
7474

75+
// Unwrap specific struct fields (1 level deep only)
76+
func WithUnwrapFieldsStructs(fields []string) TableOptions {
77+
return func(t *TableDefinition) {
78+
t.structFieldsToUnwrap = fields
79+
}
80+
}
81+
82+
// Unwrap all fields that are embedded structs (1 level deep only)
83+
func WithUnwrapAllEmbeddedStructs() TableOptions {
84+
return func(t *TableDefinition) {
85+
t.unwrapAllEmbeddedStructFields = true
86+
}
87+
}
88+
7589
func defaultTransformer(name string) string {
7690
return strcase.ToSnake(name)
7791
}
@@ -85,13 +99,70 @@ func sliceContains(arr []string, s string) bool {
8599
return false
86100
}
87101

102+
func isFieldStruct(reflectType reflect.Type) bool {
103+
return reflectType.Kind() == reflect.Struct || (reflectType.Kind() == reflect.Ptr && reflectType.Elem().Kind() == reflect.Struct)
104+
}
105+
106+
func (t *TableDefinition) shouldUnwrapField(field reflect.StructField) bool {
107+
return isFieldStruct(field.Type) && (t.unwrapAllEmbeddedStructFields && field.Anonymous || sliceContains(t.structFieldsToUnwrap, field.Name))
108+
}
109+
110+
func (t *TableDefinition) getUnwrappedFields(field reflect.StructField) []reflect.StructField {
111+
reflectType := field.Type
112+
if reflectType.Kind() == reflect.Ptr {
113+
reflectType = reflectType.Elem()
114+
}
115+
116+
fields := make([]reflect.StructField, 0)
117+
for i := 0; i < reflectType.NumField(); i++ {
118+
sf := reflectType.Field(i)
119+
if t.ignoreField(sf) {
120+
continue
121+
}
122+
123+
fields = append(fields, sf)
124+
}
125+
return fields
126+
}
127+
128+
func (t *TableDefinition) ignoreField(field reflect.StructField) bool {
129+
return len(field.Name) == 0 || unicode.IsLower(rune(field.Name[0])) || sliceContains(t.skipFields, field.Name)
130+
}
131+
132+
func (t *TableDefinition) addColumnFromField(field reflect.StructField, parentFieldName string) {
133+
if t.ignoreField(field) {
134+
return
135+
}
136+
137+
columnType, err := valueToSchemaType(field.Type)
138+
if err != nil {
139+
fmt.Printf("skipping field %s, got err: %v\n", field.Name, err)
140+
return
141+
}
142+
143+
// generate a PathResolver to use by default
144+
pathResolver := fmt.Sprintf(`schema.PathResolver("%s")`, field.Name)
145+
name := t.nameTransformer(field.Name)
146+
if parentFieldName != "" {
147+
pathResolver = fmt.Sprintf(`schema.PathResolver("%s.%s")`, parentFieldName, field.Name)
148+
name = t.nameTransformer(parentFieldName) + "_" + name
149+
}
150+
151+
column := ColumnDefinition{
152+
Name: name,
153+
Type: columnType,
154+
Resolver: pathResolver,
155+
}
156+
t.Columns = append(t.Columns, column)
157+
}
158+
88159
func NewTableFromStruct(name string, obj interface{}, opts ...TableOptions) (*TableDefinition, error) {
89-
t := TableDefinition{
160+
t := &TableDefinition{
90161
Name: name,
91162
nameTransformer: defaultTransformer,
92163
}
93164
for _, opt := range opts {
94-
opt(&t)
165+
opt(t)
95166
}
96167

97168
e := reflect.ValueOf(obj)
@@ -106,33 +177,23 @@ func NewTableFromStruct(name string, obj interface{}, opts ...TableOptions) (*Ta
106177

107178
for i := 0; i < e.NumField(); i++ {
108179
field := e.Type().Field(i)
109-
if len(field.Name) == 0 {
110-
continue
111-
}
112-
if unicode.IsLower(rune(field.Name[0])) {
113-
continue
114-
}
115-
if sliceContains(t.skipFields, field.Name) {
116-
continue
117-
}
118-
119-
columnType, err := valueToSchemaType(field.Type)
120-
if err != nil {
121-
fmt.Printf("skipping field %s, got err: %v\n", field.Name, err)
122-
continue
123-
}
124180

125-
// generate a PathResolver to use by default
126-
pathResolver := fmt.Sprintf("schema.PathResolver(%q)", field.Name)
127-
column := ColumnDefinition{
128-
Name: t.nameTransformer(field.Name),
129-
Type: columnType,
130-
Resolver: pathResolver,
181+
if t.shouldUnwrapField(field) {
182+
unwrappedFields := t.getUnwrappedFields(field)
183+
parentFieldName := ""
184+
// For non embedded structs we need to add the parent field name to the path
185+
if !field.Anonymous {
186+
parentFieldName = field.Name
187+
}
188+
for _, f := range unwrappedFields {
189+
t.addColumnFromField(f, parentFieldName)
190+
}
191+
} else {
192+
t.addColumnFromField(field, "")
131193
}
132-
t.Columns = append(t.Columns, column)
133194
}
134195

135-
return &t, nil
196+
return t, nil
136197
}
137198

138199
func (t *TableDefinition) GenerateTemplate(wr io.Writer) error {

codegen/golang_test.go

Lines changed: 129 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ import (
1313
"github.com/stretchr/testify/require"
1414
)
1515

16+
type embeddedStruct struct {
17+
EmbeddedString string
18+
}
19+
1620
type testStruct struct {
1721
// IntCol this is an example documentation comment
1822
IntCol int `json:"int_col,omitempty"`
@@ -27,80 +31,145 @@ type testStruct struct {
2731
StringArrayCol []string `json:"string_array_col,omitempty"`
2832
TimeCol time.Time `json:"time_col,omitempty"`
2933
TimePointerCol *time.Time `json:"time_pointer_col,omitempty"`
34+
*embeddedStruct
35+
}
36+
37+
type testStructWithEmbeddedStruct struct {
38+
*testStruct
39+
*embeddedStruct
40+
}
41+
42+
type testStructWithNonEmbeddedStruct struct {
43+
TestStruct *testStruct
44+
NonEmbedded *embeddedStruct
45+
}
46+
47+
var expectedColumns = []ColumnDefinition{
48+
{
49+
Name: "int_col",
50+
Type: schema.TypeInt,
51+
Resolver: `schema.PathResolver("IntCol")`,
52+
},
53+
{
54+
Name: "string_col",
55+
Type: schema.TypeString,
56+
Resolver: `schema.PathResolver("StringCol")`,
57+
},
58+
{
59+
Name: "float_col",
60+
Type: schema.TypeFloat,
61+
Resolver: `schema.PathResolver("FloatCol")`,
62+
},
63+
{
64+
Name: "bool_col",
65+
Type: schema.TypeBool,
66+
Resolver: `schema.PathResolver("BoolCol")`,
67+
},
68+
{
69+
Name: "json_col",
70+
Type: schema.TypeJSON,
71+
Resolver: `schema.PathResolver("JSONCol")`,
72+
},
73+
{
74+
Name: "int_array_col",
75+
Type: schema.TypeIntArray,
76+
Resolver: `schema.PathResolver("IntArrayCol")`,
77+
},
78+
{
79+
Name: "string_array_col",
80+
Type: schema.TypeStringArray,
81+
Resolver: `schema.PathResolver("StringArrayCol")`,
82+
},
83+
{
84+
Name: "time_col",
85+
Type: schema.TypeTimestamp,
86+
Resolver: `schema.PathResolver("TimeCol")`,
87+
},
88+
{
89+
Name: "time_pointer_col",
90+
Type: schema.TypeTimestamp,
91+
Resolver: `schema.PathResolver("TimePointerCol")`,
92+
},
3093
}
3194

3295
var expectedTestTable = TableDefinition{
96+
Name: "test_struct",
97+
Columns: expectedColumns,
98+
nameTransformer: defaultTransformer,
99+
}
100+
101+
var expectedTestTableEmbeddedStruct = TableDefinition{
102+
Name: "test_struct",
103+
Columns: append(expectedColumns, ColumnDefinition{Name: "embedded_string", Type: schema.TypeString, Resolver: `schema.PathResolver("EmbeddedString")`}),
104+
nameTransformer: defaultTransformer,
105+
}
106+
107+
var expectedTestTableNonEmbeddedStruct = TableDefinition{
33108
Name: "test_struct",
34-
Columns: []ColumnDefinition{
35-
{
36-
Name: "int_col",
37-
Type: schema.TypeInt,
38-
Resolver: `schema.PathResolver("IntCol")`,
39-
},
40-
{
41-
Name: "string_col",
42-
Type: schema.TypeString,
43-
Resolver: `schema.PathResolver("StringCol")`,
44-
},
45-
{
46-
Name: "float_col",
47-
Type: schema.TypeFloat,
48-
Resolver: `schema.PathResolver("FloatCol")`,
49-
},
50-
{
51-
Name: "bool_col",
52-
Type: schema.TypeBool,
53-
Resolver: `schema.PathResolver("BoolCol")`,
54-
},
55-
{
56-
Name: "json_col",
57-
Type: schema.TypeJSON,
58-
Resolver: `schema.PathResolver("JSONCol")`,
59-
},
60-
{
61-
Name: "int_array_col",
62-
Type: schema.TypeIntArray,
63-
Resolver: `schema.PathResolver("IntArrayCol")`,
64-
},
65-
{
66-
Name: "string_array_col",
67-
Type: schema.TypeStringArray,
68-
Resolver: `schema.PathResolver("StringArrayCol")`,
69-
},
70-
{
71-
Name: "time_col",
72-
Type: schema.TypeTimestamp,
73-
Resolver: `schema.PathResolver("TimeCol")`,
74-
},
75-
{
76-
Name: "time_pointer_col",
77-
Type: schema.TypeTimestamp,
78-
Resolver: `schema.PathResolver("TimePointerCol")`,
79-
},
109+
Columns: ColumnDefinitions{
110+
// Should not be unwrapped
111+
ColumnDefinition{Name: "test_struct", Type: schema.TypeJSON, Resolver: `schema.PathResolver("TestStruct")`},
112+
// Should be unwrapped
113+
ColumnDefinition{Name: "non_embedded_embedded_string", Type: schema.TypeString, Resolver: `schema.PathResolver("NonEmbedded.EmbeddedString")`},
80114
},
81115
nameTransformer: defaultTransformer,
82116
}
83117

84118
func TestTableFromGoStruct(t *testing.T) {
85-
table, err := NewTableFromStruct("test_struct", testStruct{})
86-
if err != nil {
87-
t.Fatal(err)
119+
type args struct {
120+
testStruct interface{}
121+
options []TableOptions
88122
}
89-
if diff := cmp.Diff(table, &expectedTestTable,
90-
cmpopts.IgnoreUnexported(TableDefinition{})); diff != "" {
91-
t.Fatalf("table does not match expected. diff (-got, +want): %v", diff)
123+
124+
tests := []struct {
125+
name string
126+
args args
127+
want TableDefinition
128+
}{
129+
{
130+
name: "should generate table from struct with default options",
131+
args: args{
132+
testStruct: testStruct{},
133+
},
134+
want: expectedTestTable,
135+
},
136+
{
137+
name: "should unwrap all embedded structs when option is set",
138+
args: args{
139+
testStruct: testStructWithEmbeddedStruct{},
140+
options: []TableOptions{WithUnwrapAllEmbeddedStructs()},
141+
},
142+
want: expectedTestTableEmbeddedStruct,
143+
},
144+
{
145+
name: "should_unwrap_specific_structs_when_option_is_set",
146+
args: args{
147+
testStruct: testStructWithNonEmbeddedStruct{},
148+
options: []TableOptions{WithUnwrapFieldsStructs([]string{"NonEmbedded"})},
149+
},
150+
want: expectedTestTableNonEmbeddedStruct,
151+
},
92152
}
93-
buf := bytes.NewBufferString("")
94-
if err := table.GenerateTemplate(buf); err != nil {
95-
t.Fatal(err)
153+
154+
for _, tt := range tests {
155+
t.Run(tt.name, func(t *testing.T) {
156+
table, err := NewTableFromStruct("test_struct", tt.args.testStruct, tt.args.options...)
157+
if err != nil {
158+
t.Fatal(err)
159+
}
160+
if diff := cmp.Diff(table, &tt.want,
161+
cmpopts.IgnoreUnexported(TableDefinition{})); diff != "" {
162+
t.Fatalf("table does not match expected. diff (-got, +want): %v", diff)
163+
}
164+
buf := bytes.NewBufferString("")
165+
if err := table.GenerateTemplate(buf); err != nil {
166+
t.Fatal(err)
167+
}
168+
fmt.Println(buf.String())
169+
})
96170
}
97-
fmt.Println(buf.String())
98171
}
99172

100-
// func TestReadComments(t *testing.T) {
101-
// readComments("github.com/google/go-cmp/cmp")
102-
// }
103-
104173
func TestGenerateTemplate(t *testing.T) {
105174
type args struct {
106175
table *TableDefinition

codegen/table.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@ type ResourceDefinition struct {
1111

1212
type TableDefinition struct {
1313
Name string
14-
Description string
1514
Columns ColumnDefinitions
15+
Description string
1616
Relations []string
1717

18-
Resolver string
19-
IgnoreError string
20-
Multiplex string
21-
PostResourceResolver string
22-
PreResourceResolver string
23-
nameTransformer func(string) string
24-
skipFields []string
25-
extraColumns ColumnDefinitions
18+
Resolver string
19+
IgnoreError string
20+
Multiplex string
21+
PostResourceResolver string
22+
PreResourceResolver string
23+
nameTransformer func(string) string
24+
skipFields []string
25+
extraColumns ColumnDefinitions
26+
structFieldsToUnwrap []string
27+
unwrapAllEmbeddedStructFields bool
2628
}
2729

2830
type ColumnDefinitions []ColumnDefinition

0 commit comments

Comments
 (0)