Skip to content

Commit 2d1ca47

Browse files
committed
added Connection.ValidateColumnName, rename db.QueryRowStruct to QueryRow, rename Connection.StructFieldNamer to StructFieldMapper
1 parent 5e1ff3a commit 2d1ca47

18 files changed

+136
-58
lines changed

connection.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ type Connection interface {
2424
// context for its operations.
2525
WithContext(ctx context.Context) Connection
2626

27-
// WithStructFieldNamer returns a copy of the connection
28-
// that will use the passed StructFieldNamer.
29-
WithStructFieldNamer(namer StructFieldMapper) Connection
27+
// WithStructFieldMapper returns a copy of the connection
28+
// that will use the passed StructFieldMapper.
29+
WithStructFieldMapper(StructFieldMapper) Connection
3030

31-
// StructFieldNamer used by methods of this Connection.
32-
StructFieldNamer() StructFieldMapper
31+
// StructFieldMapper used by methods of this Connection.
32+
StructFieldMapper() StructFieldMapper
3333

3434
// Ping returns an error if the database
3535
// does not answer on this connection
@@ -45,6 +45,11 @@ type Connection interface {
4545
// to create this connection.
4646
Config() *Config
4747

48+
// ValidateColumnName returns an error
49+
// if the passed name is not valid for a
50+
// column of the connection's database.
51+
ValidateColumnName(name string) error
52+
4853
// Now returns the result of the SQL now()
4954
// function for the current connection.
5055
// Useful for getting the timestamp of a
@@ -67,12 +72,12 @@ type Connection interface {
6772
InsertReturning(table string, values Values, returning string) RowScanner
6873

6974
// InsertStruct inserts a new row into table using the connection's
70-
// StructFieldNamer to map struct fields to column names.
75+
// StructFieldMapper to map struct fields to column names.
7176
// Optional ColumnFilter can be passed to ignore mapped columns.
7277
InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error
7378

7479
// InsertUniqueStruct inserts a new row into table using the connection's
75-
// StructFieldNamer to map struct fields to column names.
80+
// StructFieldMapper to map struct fields to column names.
7681
// Optional ColumnFilter can be passed to ignore mapped columns.
7782
// Does nothing if the onConflict statement applies
7883
// and returns if a row was inserted.

db/queryrow.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@ import (
99
"github.com/domonda/go-sqldb"
1010
)
1111

12-
// QueryRowStruct uses the passed pkValues to query a table row
12+
// QueryStruct uses the passed pkValues to query a table row
1313
// and scan it into a struct of type S that must have tagged fields
1414
// with primary key flags to identify the primary key column names
1515
// for the passed pkValues and a table name.
16-
func QueryRowStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error) {
16+
func QueryStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error) {
1717
if len(pkValues) == 0 {
18-
return nil, errors.New("no primaryKeyValues passed")
18+
return nil, errors.New("missing primary key values")
1919
}
2020
t := reflect.TypeOf(row).Elem()
2121
if t.Kind() != reflect.Struct {
2222
return nil, fmt.Errorf("expected struct template type instead of %s", t)
2323
}
2424
conn := Conn(ctx)
25-
table, pkColumns, err := pkColumnsOfStruct(t, conn.StructFieldNamer())
25+
table, pkColumns, err := pkColumnsOfStruct(conn, t)
2626
if err != nil {
2727
return nil, err
2828
}
@@ -40,7 +40,8 @@ func QueryRowStruct[S any](ctx context.Context, pkValues ...any) (row *S, err er
4040
return row, nil
4141
}
4242

43-
func pkColumnsOfStruct(t reflect.Type, mapper sqldb.StructFieldMapper) (table string, columns []string, err error) {
43+
func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) {
44+
mapper := conn.StructFieldMapper()
4445
for i := 0; i < t.NumField(); i++ {
4546
field := t.Field(i)
4647
fieldTable, column, flags, ok := mapper.MapStructField(field)
@@ -55,7 +56,7 @@ func pkColumnsOfStruct(t reflect.Type, mapper sqldb.StructFieldMapper) (table st
5556
}
5657

5758
if column == "" {
58-
fieldTable, columnsEmbed, err := pkColumnsOfStruct(field.Type, mapper)
59+
fieldTable, columnsEmbed, err := pkColumnsOfStruct(conn, field.Type)
5960
if err != nil {
6061
return "", nil, err
6162
}
@@ -67,6 +68,9 @@ func pkColumnsOfStruct(t reflect.Type, mapper sqldb.StructFieldMapper) (table st
6768
}
6869
columns = append(columns, columnsEmbed...)
6970
} else if flags.PrimaryKey() {
71+
if err = conn.ValidateColumnName(column); err != nil {
72+
return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name)
73+
}
7074
columns = append(columns, column)
7175
}
7276
}

errors.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ func (e connectionWithError) WithContext(ctx context.Context) Connection {
8080
return connectionWithError{ctx: ctx, err: e.err}
8181
}
8282

83-
func (e connectionWithError) WithStructFieldNamer(namer StructFieldMapper) Connection {
83+
func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection {
8484
return e
8585
}
8686

87-
func (e connectionWithError) StructFieldNamer() StructFieldMapper {
87+
func (e connectionWithError) StructFieldMapper() StructFieldMapper {
8888
return DefaultStructFieldMapping
8989
}
9090

@@ -100,6 +100,10 @@ func (e connectionWithError) Config() *Config {
100100
return &Config{Err: e.err}
101101
}
102102

103+
func (e connectionWithError) ValidateColumnName(name string) error {
104+
return e.err
105+
}
106+
103107
func (e connectionWithError) Now() (time.Time, error) {
104108
return time.Time{}, e.err
105109
}

examples/user_demo/user_demo.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func main() {
4545
panic(err)
4646
}
4747

48-
conn = conn.WithStructFieldNamer(&sqldb.TaggedStructFieldMapping{
48+
conn = conn.WithStructFieldMapper(&sqldb.TaggedStructFieldMapping{
4949
NameTag: "col",
5050
Ignore: "ignore",
5151
UntaggedNameFunc: sqldb.ToSnakeCase,

impl/connection.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,24 @@ import (
1313
// for an existing sql.DB connection.
1414
// argFmt is the format string for argument placeholders like "?" or "$%d"
1515
// that will be replaced error messages to format a complete query.
16-
func Connection(ctx context.Context, db *sql.DB, config *sqldb.Config, argFmt string) sqldb.Connection {
16+
func Connection(ctx context.Context, db *sql.DB, config *sqldb.Config, validateColumnName func(string) error, argFmt string) sqldb.Connection {
1717
return &connection{
18-
ctx: ctx,
19-
db: db,
20-
config: config,
21-
structFieldNamer: sqldb.DefaultStructFieldMapping,
22-
argFmt: argFmt,
18+
ctx: ctx,
19+
db: db,
20+
config: config,
21+
structFieldNamer: sqldb.DefaultStructFieldMapping,
22+
argFmt: argFmt,
23+
validateColumnName: validateColumnName,
2324
}
2425
}
2526

2627
type connection struct {
27-
ctx context.Context
28-
db *sql.DB
29-
config *sqldb.Config
30-
structFieldNamer sqldb.StructFieldMapper
31-
argFmt string
28+
ctx context.Context
29+
db *sql.DB
30+
config *sqldb.Config
31+
structFieldNamer sqldb.StructFieldMapper
32+
argFmt string
33+
validateColumnName func(string) error
3234
}
3335

3436
func (conn *connection) clone() *connection {
@@ -47,13 +49,13 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection {
4749
return c
4850
}
4951

50-
func (conn *connection) WithStructFieldNamer(namer sqldb.StructFieldMapper) sqldb.Connection {
52+
func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection {
5153
c := conn.clone()
5254
c.structFieldNamer = namer
5355
return c
5456
}
5557

56-
func (conn *connection) StructFieldNamer() sqldb.StructFieldMapper {
58+
func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper {
5759
return conn.structFieldNamer
5860
}
5961

@@ -75,6 +77,10 @@ func (conn *connection) Config() *sqldb.Config {
7577
return conn.config
7678
}
7779

80+
func (conn *connection) ValidateColumnName(name string) error {
81+
return conn.validateColumnName(name)
82+
}
83+
7884
func (conn *connection) Now() (time.Time, error) {
7985
return Now(conn)
8086
}

impl/insert.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string)
8484
}
8585

8686
// InsertStruct inserts a new row into table using the connection's
87-
// StructFieldNamer to map struct fields to column names.
87+
// StructFieldMapper to map struct fields to column names.
8888
// Optional ColumnFilter can be passed to ignore mapped columns.
8989
func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error {
9090
columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns)
@@ -102,7 +102,7 @@ func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld
102102
}
103103

104104
// InsertUniqueStruct inserts a new row into table using the connection's
105-
// StructFieldNamer to map struct fields to column names.
105+
// StructFieldMapper to map struct fields to column names.
106106
// Optional ColumnFilter can be passed to ignore mapped columns.
107107
// Does nothing if the onConflict statement applies
108108
// and returns if a row was inserted.

impl/rowscanner.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,16 @@ func (s *RowScanner) Columns() ([]string, error) {
7777

7878
// CurrentRowScanner calls Rows.Scan without Rows.Next and Rows.Close
7979
type CurrentRowScanner struct {
80-
Rows Rows
81-
StructFieldNamer sqldb.StructFieldMapper
80+
Rows Rows
81+
StructFieldMapper sqldb.StructFieldMapper
8282
}
8383

8484
func (s CurrentRowScanner) Scan(dest ...any) error {
8585
return s.Rows.Scan(dest...)
8686
}
8787

8888
func (s CurrentRowScanner) ScanStruct(dest any) error {
89-
return ScanStruct(s.Rows, dest, s.StructFieldNamer)
89+
return ScanStruct(s.Rows, dest, s.StructFieldMapper)
9090
}
9191

9292
func (s CurrentRowScanner) ScanValues() ([]any, error) {
@@ -103,16 +103,16 @@ func (s CurrentRowScanner) Columns() ([]string, error) {
103103

104104
// SingleRowScanner always uses the same Row
105105
type SingleRowScanner struct {
106-
Row Row
107-
StructFieldNamer sqldb.StructFieldMapper
106+
Row Row
107+
StructFieldMapper sqldb.StructFieldMapper
108108
}
109109

110110
func (s SingleRowScanner) Scan(dest ...any) error {
111111
return s.Row.Scan(dest...)
112112
}
113113

114114
func (s SingleRowScanner) ScanStruct(dest any) error {
115-
return ScanStruct(s.Row, dest, s.StructFieldNamer)
115+
return ScanStruct(s.Row, dest, s.StructFieldMapper)
116116
}
117117

118118
func (s SingleRowScanner) ScanValues() ([]any, error) {

impl/scanstruct_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ package impl
3131

3232
// type args struct {
3333
// t reflect.Type
34-
// namer sqldb.StructFieldNamer
34+
// namer sqldb.StructFieldMapper
3535
// }
3636
// tests := []struct {
3737
// name string

impl/transaction.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,24 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection {
4343
return newTransaction(parent, conn.tx, conn.opts)
4444
}
4545

46-
func (conn *transaction) WithStructFieldNamer(namer sqldb.StructFieldMapper) sqldb.Connection {
46+
func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection {
4747
c := conn.clone()
4848
c.structFieldNamer = namer
4949
return c
5050
}
5151

52-
func (conn *transaction) StructFieldNamer() sqldb.StructFieldMapper {
52+
func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper {
5353
return conn.structFieldNamer
5454
}
5555

5656
func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) }
5757
func (conn *transaction) Stats() sql.DBStats { return conn.parent.Stats() }
5858
func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() }
5959

60+
func (conn *transaction) ValidateColumnName(name string) error {
61+
return conn.parent.validateColumnName(name)
62+
}
63+
6064
func (conn *transaction) Now() (time.Time, error) {
6165
return Now(conn)
6266
}

mockconn/connection.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection {
4949
}
5050
}
5151

52-
func (conn *connection) WithStructFieldNamer(namer sqldb.StructFieldMapper) sqldb.Connection {
52+
func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection {
5353
return &connection{
5454
ctx: conn.ctx,
5555
queryWriter: conn.queryWriter,
@@ -60,7 +60,7 @@ func (conn *connection) WithStructFieldNamer(namer sqldb.StructFieldMapper) sqld
6060
}
6161
}
6262

63-
func (conn *connection) StructFieldNamer() sqldb.StructFieldMapper {
63+
func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper {
6464
return conn.structFieldNamer
6565
}
6666

@@ -72,6 +72,10 @@ func (conn *connection) Config() *sqldb.Config {
7272
return &sqldb.Config{Driver: "mockconn", Host: "localhost", Database: "mock"}
7373
}
7474

75+
func (conn *connection) ValidateColumnName(name string) error {
76+
return validateColumnName(name)
77+
}
78+
7579
func (conn *connection) Ping(time.Duration) error {
7680
return nil
7781
}

0 commit comments

Comments
 (0)