Skip to content

Commit 102481a

Browse files
committed
database: Introduce ColumnMap
`ColumnMap` provides a cached mapping of structs exported fields to their database column names. By default, all exported struct fields are mapped to their database column names using snake case notation. The `-` (hyphen) directive for the db tag can be used to exclude certain fields. Since `ColumnMap` uses cache, the returned slice MUST NOT be modified directly.
1 parent adca848 commit 102481a

File tree

2 files changed

+84
-22
lines changed

2 files changed

+84
-22
lines changed

database/column_map.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package database
2+
3+
import (
4+
"database/sql/driver"
5+
"github.com/jmoiron/sqlx/reflectx"
6+
"reflect"
7+
"sync"
8+
)
9+
10+
// ColumnMap provides a cached mapping of structs exported fields to their database column names.
11+
type ColumnMap interface {
12+
// Columns returns database column names for a struct's exported fields in a cached manner.
13+
// Thus, the returned slice MUST NOT be modified directly.
14+
// By default, all exported struct fields are mapped to database column names using snake case notation.
15+
// The - (hyphen) directive for the db tag can be used to exclude certain fields.
16+
Columns(any) []string
17+
}
18+
19+
// NewColumnMap returns a new ColumnMap.
20+
func NewColumnMap(mapper *reflectx.Mapper) ColumnMap {
21+
return &columnMap{
22+
cache: make(map[reflect.Type][]string),
23+
mapper: mapper,
24+
}
25+
}
26+
27+
type columnMap struct {
28+
mutex sync.Mutex
29+
cache map[reflect.Type][]string
30+
mapper *reflectx.Mapper
31+
}
32+
33+
func (m *columnMap) Columns(subject any) []string {
34+
m.mutex.Lock()
35+
defer m.mutex.Unlock()
36+
37+
t, ok := subject.(reflect.Type)
38+
if !ok {
39+
t = reflect.TypeOf(subject)
40+
}
41+
42+
columns, ok := m.cache[t]
43+
if !ok {
44+
columns = m.getColumns(t)
45+
m.cache[t] = columns
46+
}
47+
48+
return columns
49+
}
50+
51+
func (m *columnMap) getColumns(t reflect.Type) []string {
52+
fields := m.mapper.TypeMap(t).Names
53+
columns := make([]string, 0, len(fields))
54+
55+
FieldLoop:
56+
for _, f := range fields {
57+
// If one of the parent fields implements the driver.Valuer interface, the field can be ignored.
58+
for parent := f.Parent; parent != nil && parent.Zero.IsValid(); parent = parent.Parent {
59+
// Check for pointer types.
60+
if _, ok := reflect.New(parent.Field.Type).Interface().(driver.Valuer); ok {
61+
continue FieldLoop
62+
}
63+
// Check for non-pointer types.
64+
if _, ok := reflect.Zero(parent.Field.Type).Interface().(driver.Valuer); ok {
65+
continue FieldLoop
66+
}
67+
}
68+
69+
columns = append(columns, f.Path)
70+
}
71+
72+
// Shrink/reduce slice length and capacity:
73+
// For a three-index slice (slice[a:b:c]), the length of the returned slice is b-a and the capacity is c-a.
74+
return columns[0:len(columns):len(columns)]
75+
}

database/db.go

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"golang.org/x/sync/semaphore"
2323
"net"
2424
"net/url"
25-
"reflect"
2625
"strconv"
2726
"strings"
2827
"sync"
@@ -37,6 +36,7 @@ type DB struct {
3736
Options *Options
3837

3938
addr string
39+
columnMap ColumnMap
4040
logger *logging.Logger
4141
tableSemaphores map[string]*semaphore.Weighted
4242
tableSemaphoresMu sync.Mutex
@@ -215,6 +215,7 @@ func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks Retry
215215
return &DB{
216216
DB: db,
217217
Options: &c.Options,
218+
columnMap: NewColumnMap(db.Mapper),
218219
addr: addr,
219220
logger: logger,
220221
tableSemaphores: make(map[string]*semaphore.Weighted),
@@ -226,20 +227,6 @@ func (db *DB) GetAddr() string {
226227
return db.addr
227228
}
228229

229-
// BuildColumns returns all columns of the given struct.
230-
func (db *DB) BuildColumns(subject interface{}) []string {
231-
fields := db.Mapper.TypeMap(reflect.TypeOf(subject)).Names
232-
columns := make([]string, 0, len(fields))
233-
for _, f := range fields {
234-
if f.Field.Tag == "" {
235-
continue
236-
}
237-
columns = append(columns, f.Name)
238-
}
239-
240-
return columns
241-
}
242-
243230
// BuildDeleteStmt returns a DELETE statement for the given struct.
244231
func (db *DB) BuildDeleteStmt(from interface{}) string {
245232
return fmt.Sprintf(
@@ -250,7 +237,7 @@ func (db *DB) BuildDeleteStmt(from interface{}) string {
250237

251238
// BuildInsertStmt returns an INSERT INTO statement for the given struct.
252239
func (db *DB) BuildInsertStmt(into interface{}) (string, int) {
253-
columns := db.BuildColumns(into)
240+
columns := db.columnMap.Columns(into)
254241

255242
return fmt.Sprintf(
256243
`INSERT INTO "%s" ("%s") VALUES (%s)`,
@@ -264,7 +251,7 @@ func (db *DB) BuildInsertStmt(into interface{}) (string, int) {
264251
// which the database ignores rows that have already been inserted.
265252
func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) {
266253
table := TableName(into)
267-
columns := db.BuildColumns(into)
254+
columns := db.columnMap.Columns(into)
268255
var clause string
269256

270257
switch db.DriverName() {
@@ -289,7 +276,7 @@ func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) {
289276
func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string {
290277
q := fmt.Sprintf(
291278
`SELECT "%s" FROM "%s"`,
292-
strings.Join(db.BuildColumns(columns), `", "`),
279+
strings.Join(db.columnMap.Columns(columns), `", "`),
293280
TableName(table),
294281
)
295282

@@ -303,7 +290,7 @@ func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string {
303290

304291
// BuildUpdateStmt returns an UPDATE statement for the given struct.
305292
func (db *DB) BuildUpdateStmt(update interface{}) (string, int) {
306-
columns := db.BuildColumns(update)
293+
columns := db.columnMap.Columns(update)
307294
set := make([]string, 0, len(columns))
308295

309296
for _, col := range columns {
@@ -319,12 +306,12 @@ func (db *DB) BuildUpdateStmt(update interface{}) (string, int) {
319306

320307
// BuildUpsertStmt returns an upsert statement for the given struct.
321308
func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders int) {
322-
insertColumns := db.BuildColumns(subject)
309+
insertColumns := db.columnMap.Columns(subject)
323310
table := TableName(subject)
324311
var updateColumns []string
325312

326313
if upserter, ok := subject.(Upserter); ok {
327-
updateColumns = db.BuildColumns(upserter.Upsert())
314+
updateColumns = db.columnMap.Columns(upserter.Upsert())
328315
} else {
329316
updateColumns = insertColumns
330317
}
@@ -358,7 +345,7 @@ func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders in
358345
// BuildWhere returns a WHERE clause with named placeholder conditions built from the specified struct
359346
// combined with the AND operator.
360347
func (db *DB) BuildWhere(subject interface{}) (string, int) {
361-
columns := db.BuildColumns(subject)
348+
columns := db.columnMap.Columns(subject)
362349
where := make([]string, 0, len(columns))
363350
for _, col := range columns {
364351
where = append(where, fmt.Sprintf(`"%s" = :%s`, col, col))

0 commit comments

Comments
 (0)