Skip to content

Commit 5c3a217

Browse files
committed
pqconn uses impl.NewGenericConnection
1 parent b38a83a commit 5c3a217

16 files changed

+160
-489
lines changed

connection.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ type Connection interface {
3131
// StructFieldMapper used by methods of this Connection.
3232
StructFieldMapper() StructFieldMapper
3333

34+
// ValidateColumnName returns an error
35+
// if the passed name is not valid for a
36+
// column of the connection's database.
37+
ValidateColumnName(name string) error
38+
3439
// Ping returns an error if the database
3540
// does not answer on this connection
3641
// with an optional timeout.
@@ -45,11 +50,6 @@ type Connection interface {
4550
// to create this connection.
4651
Config() *Config
4752

48-
// ValidateColumnName returns an error
49-
// if the passed name is not valid for a
50-
// column of the connection's database.
51-
ValidateColumnName(name string) error
52-
5353
// Now returns the result of the SQL now()
5454
// function for the current connection.
5555
// Useful for getting the timestamp of a

errors.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (e connectionWithError) WithContext(ctx context.Context) Connection {
8181
return connectionWithError{ctx: ctx, err: e.err}
8282
}
8383

84-
func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection {
84+
func (e connectionWithError) WithStructFieldMapper(StructFieldMapper) Connection {
8585
return e
8686
}
8787

impl/genericconnection.go

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
// for an existing sql.DB connection.
1414
// argFmt is the format string for argument placeholders like "?" or "$%d"
1515
// that will be replaced error messages to format a complete query.
16-
func NewGenericConnection(ctx context.Context, db *sql.DB, config *sqldb.Config, listener sqldb.Listener, validateColumnName func(string) error, argFmt string) sqldb.Connection {
16+
func NewGenericConnection(ctx context.Context, db *sql.DB, config *sqldb.Config, listener sqldb.Listener, structFieldMapper sqldb.StructFieldMapper, validateColumnName func(string) error, argFmt string) sqldb.Connection {
1717
if listener == nil {
1818
listener = sqldb.UnsupportedListener()
1919
}
@@ -22,9 +22,9 @@ func NewGenericConnection(ctx context.Context, db *sql.DB, config *sqldb.Config,
2222
db: db,
2323
config: config,
2424
listener: listener,
25-
structFieldNamer: sqldb.DefaultStructFieldMapping,
26-
argFmt: argFmt,
25+
structFieldMapper: structFieldMapper,
2726
validateColumnName: validateColumnName,
27+
argFmt: argFmt,
2828
}
2929
}
3030

@@ -33,9 +33,9 @@ type genericConn struct {
3333
db *sql.DB
3434
config *sqldb.Config
3535
listener sqldb.Listener
36-
structFieldNamer sqldb.StructFieldMapper
37-
argFmt string
36+
structFieldMapper sqldb.StructFieldMapper
3837
validateColumnName func(string) error
38+
argFmt string
3939

4040
tx *sql.Tx
4141
txOptions *sql.TxOptions
@@ -58,14 +58,18 @@ func (conn *genericConn) WithContext(ctx context.Context) sqldb.Connection {
5858
return c
5959
}
6060

61-
func (conn *genericConn) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection {
61+
func (conn *genericConn) WithStructFieldMapper(mapper sqldb.StructFieldMapper) sqldb.Connection {
6262
c := conn.clone()
63-
c.structFieldNamer = namer
63+
c.structFieldMapper = mapper
6464
return c
6565
}
6666

6767
func (conn *genericConn) StructFieldMapper() sqldb.StructFieldMapper {
68-
return conn.structFieldNamer
68+
return conn.structFieldMapper
69+
}
70+
71+
func (conn *genericConn) ValidateColumnName(name string) error {
72+
return conn.validateColumnName(name)
6973
}
7074

7175
func (conn *genericConn) Ping(timeout time.Duration) error {
@@ -86,40 +90,50 @@ func (conn *genericConn) Config() *sqldb.Config {
8690
return conn.config
8791
}
8892

89-
func (conn *genericConn) ValidateColumnName(name string) error {
90-
return conn.validateColumnName(name)
91-
}
92-
9393
func (conn *genericConn) Now() (time.Time, error) {
9494
return Now(conn)
9595
}
9696

97+
func (conn *genericConn) execer() Execer {
98+
if conn.tx != nil {
99+
return conn.tx
100+
}
101+
return conn.db
102+
}
103+
104+
func (conn *genericConn) queryer() Queryer {
105+
if conn.tx != nil {
106+
return conn.tx
107+
}
108+
return conn.db
109+
}
110+
97111
func (conn *genericConn) Exec(query string, args ...any) error {
98-
return Exec(conn.ctx, conn.db, conn.argFmt, query, args)
112+
return Exec(conn.ctx, conn.execer(), conn.argFmt, query, args)
99113
}
100114

101115
func (conn *genericConn) Insert(table string, columValues sqldb.Values) error {
102-
return Insert(conn, table, conn.argFmt, columValues)
116+
return Insert(conn.ctx, conn.execer(), table, conn.argFmt, columValues)
103117
}
104118

105119
func (conn *genericConn) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) {
106-
return InsertUnique(conn, table, conn.argFmt, values, onConflict)
120+
return InsertUnique(conn.ctx, conn.queryer(), conn.structFieldMapper, conn.argFmt, table, values, onConflict)
107121
}
108122

109123
func (conn *genericConn) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner {
110124
return InsertReturning(conn, table, conn.argFmt, values, returning)
111125
}
112126

113127
func (conn *genericConn) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error {
114-
return InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns)
128+
return InsertStruct(conn, table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns)
115129
}
116130

117131
func (conn *genericConn) InsertStructs(table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error {
118132
return InsertStructs(conn, table, rowStructs, ignoreColumns...)
119133
}
120134

121135
func (conn *genericConn) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) {
122-
return InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.argFmt, ignoreColumns)
136+
return InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldMapper, conn.argFmt, ignoreColumns)
123137
}
124138

125139
func (conn *genericConn) Update(table string, values sqldb.Values, where string, args ...any) error {
@@ -135,19 +149,31 @@ func (conn *genericConn) UpdateReturningRows(table string, values sqldb.Values,
135149
}
136150

137151
func (conn *genericConn) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error {
138-
return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns)
152+
return UpdateStruct(conn, table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns)
139153
}
140154

141155
func (conn *genericConn) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error {
142-
return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns)
156+
return UpsertStruct(conn, table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns)
143157
}
144158

145159
func (conn *genericConn) QueryRow(query string, args ...any) sqldb.RowScanner {
146-
return QueryRow(conn.ctx, conn.db, conn.structFieldNamer, conn.argFmt, query, args)
160+
var queryer Queryer
161+
if conn.tx != nil {
162+
queryer = conn.tx
163+
} else {
164+
queryer = conn.db
165+
}
166+
return QueryRow(conn.ctx, queryer, conn.structFieldMapper, conn.argFmt, query, args)
147167
}
148168

149169
func (conn *genericConn) QueryRows(query string, args ...any) sqldb.RowsScanner {
150-
return QueryRows(conn.ctx, conn.db, conn.structFieldNamer, conn.argFmt, query, args)
170+
var queryer Queryer
171+
if conn.tx != nil {
172+
queryer = conn.tx
173+
} else {
174+
queryer = conn.db
175+
}
176+
return QueryRows(conn.ctx, queryer, conn.structFieldMapper, conn.argFmt, query, args)
151177
}
152178

153179
func (conn *genericConn) IsTransaction() bool {
@@ -169,10 +195,10 @@ func (conn *genericConn) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection
169195
}
170196
return &genericConn{
171197
ctx: conn.ctx,
172-
db: nil,
198+
db: conn.db, // needed for PingContext, Stats
173199
config: conn.config,
174200
listener: conn.listener,
175-
structFieldNamer: conn.structFieldNamer,
201+
structFieldMapper: conn.structFieldMapper,
176202
argFmt: conn.argFmt,
177203
validateColumnName: conn.validateColumnName,
178204
tx: tx,

impl/insert.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package impl
22

33
import (
4+
"context"
45
"fmt"
56
"reflect"
67
"strings"
@@ -9,7 +10,7 @@ import (
910
)
1011

1112
// Insert a new row into table using the values.
12-
func Insert(conn sqldb.Connection, table, argFmt string, values sqldb.Values) error {
13+
func Insert(ctx context.Context, conn Execer, table, argFmt string, values sqldb.Values) error {
1314
if len(values) == 0 {
1415
return fmt.Errorf("Insert into table %s: no values", table)
1516
}
@@ -19,15 +20,15 @@ func Insert(conn sqldb.Connection, table, argFmt string, values sqldb.Values) er
1920
writeInsertQuery(&b, table, argFmt, names)
2021
query := b.String()
2122

22-
err := conn.Exec(query, vals...)
23+
_, err := conn.ExecContext(ctx, query, vals...)
2324

2425
return WrapNonNilErrorWithQuery(err, query, argFmt, vals)
2526
}
2627

2728
// InsertUnique inserts a new row into table using the passed values
2829
// or does nothing if the onConflict statement applies.
2930
// Returns if a row was inserted.
30-
func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Values, onConflict string) (inserted bool, err error) {
31+
func InsertUnique(ctx context.Context, conn Queryer, mapper sqldb.StructFieldMapper, argFmt, table string, values sqldb.Values, onConflict string) (inserted bool, err error) {
3132
if len(values) == 0 {
3233
return false, fmt.Errorf("InsertUnique into table %s: no values", table)
3334
}
@@ -41,7 +42,7 @@ func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Valu
4142
writeInsertQuery(&query, table, argFmt, names)
4243
fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict)
4344

44-
err = conn.QueryRow(query.String(), vals...).Scan(&inserted)
45+
err = QueryRow(ctx, conn, mapper, argFmt, query.String(), vals).Scan(&inserted)
4546

4647
err = sqldb.ReplaceErrNoRows(err, nil)
4748
err = WrapNonNilErrorWithQuery(err, query.String(), argFmt, vals)
@@ -86,8 +87,8 @@ func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string)
8687
// InsertStruct inserts a new row into table using the connection's
8788
// StructFieldMapper to map struct fields to column names.
8889
// Optional ColumnFilter can be passed to ignore mapped columns.
89-
func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error {
90-
columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns)
90+
func InsertStruct(conn sqldb.Connection, table string, rowStruct any, mapper sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error {
91+
columns, vals, err := insertStructValues(table, rowStruct, mapper, ignoreColumns)
9192
if err != nil {
9293
return err
9394
}
@@ -106,8 +107,8 @@ func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld
106107
// Optional ColumnFilter can be passed to ignore mapped columns.
107108
// Does nothing if the onConflict statement applies
108109
// and returns if a row was inserted.
109-
func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onConflict string, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) (inserted bool, err error) {
110-
columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns)
110+
func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onConflict string, mapper sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) (inserted bool, err error) {
111+
columns, vals, err := insertStructValues(table, rowStruct, mapper, ignoreColumns)
111112
if err != nil {
112113
return false, err
113114
}
@@ -127,7 +128,7 @@ func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onCo
127128
return inserted, WrapNonNilErrorWithQuery(err, query, argFmt, vals)
128129
}
129130

130-
func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) {
131+
func insertStructValues(table string, rowStruct any, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) {
131132
v := reflect.ValueOf(rowStruct)
132133
for v.Kind() == reflect.Ptr && !v.IsNil() {
133134
v = v.Elem()
@@ -139,7 +140,7 @@ func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapp
139140
return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct)
140141
}
141142

142-
columns, _, vals = ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly))
143+
columns, _, vals = ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly))
143144
return columns, vals, nil
144145
}
145146

impl/query.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,20 @@ type Queryer interface {
1111
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
1212
}
1313

14-
func QueryRow(ctx context.Context, conn Queryer, namer sqldb.StructFieldMapper, argFmt, query string, args []any) sqldb.RowScanner {
14+
func QueryRow(ctx context.Context, conn Queryer, mapper sqldb.StructFieldMapper, argFmt, query string, args []any) sqldb.RowScanner {
1515
rows, err := conn.QueryContext(ctx, query, args...)
1616
if err != nil {
1717
err = WrapNonNilErrorWithQuery(err, query, argFmt, args)
1818
return sqldb.RowScannerWithError(err)
1919
}
20-
return NewRowScanner(rows, namer, query, argFmt, args)
20+
return NewRowScanner(rows, mapper, query, argFmt, args)
2121
}
2222

23-
func QueryRows(ctx context.Context, conn Queryer, namer sqldb.StructFieldMapper, argFmt, query string, args []any) sqldb.RowsScanner {
23+
func QueryRows(ctx context.Context, conn Queryer, mapper sqldb.StructFieldMapper, argFmt, query string, args []any) sqldb.RowsScanner {
2424
rows, err := conn.QueryContext(ctx, query, args...)
2525
if err != nil {
2626
err = WrapNonNilErrorWithQuery(err, query, argFmt, args)
2727
return sqldb.RowsScannerWithError(err)
2828
}
29-
return NewRowsScanner(ctx, rows, namer, query, argFmt, args)
29+
return NewRowsScanner(ctx, rows, mapper, query, argFmt, args)
3030
}

impl/reflectstruct.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@ import (
1111
"github.com/domonda/go-sqldb"
1212
)
1313

14-
func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) {
14+
func ReflectStructValues(structVal reflect.Value, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) {
1515
for i := 0; i < structVal.NumField(); i++ {
1616
fieldType := structVal.Type().Field(i)
17-
_, column, flags, use := namer.MapStructField(fieldType)
17+
_, column, flags, use := mapper.MapStructField(fieldType)
1818
if !use {
1919
continue
2020
}
2121
fieldValue := structVal.Field(i)
2222

2323
if column == "" {
2424
// Embedded struct field
25-
columnsEmbed, pkColsEmbed, valuesEmbed := ReflectStructValues(fieldValue, namer, ignoreColumns)
25+
columnsEmbed, pkColsEmbed, valuesEmbed := ReflectStructValues(fieldValue, mapper, ignoreColumns)
2626
for _, pkCol := range pkColsEmbed {
2727
pkCols = append(pkCols, pkCol+len(columns))
2828
}
@@ -44,12 +44,12 @@ func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper,
4444
return columns, pkCols, values
4545
}
4646

47-
func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string) (pointers []any, err error) {
47+
func ReflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFieldMapper, columns []string) (pointers []any, err error) {
4848
if len(columns) == 0 {
4949
return nil, errors.New("no columns")
5050
}
5151
pointers = make([]any, len(columns))
52-
err = reflectStructColumnPointers(structVal, namer, columns, pointers)
52+
err = reflectStructColumnPointers(structVal, mapper, columns, pointers)
5353
if err != nil {
5454
return nil, err
5555
}
@@ -72,19 +72,19 @@ func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel
7272
return pointers, nil
7373
}
7474

75-
func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string, pointers []any) error {
75+
func reflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFieldMapper, columns []string, pointers []any) error {
7676
structType := structVal.Type()
7777
for i := 0; i < structType.NumField(); i++ {
7878
field := structType.Field(i)
79-
_, column, _, use := namer.MapStructField(field)
79+
_, column, _, use := mapper.MapStructField(field)
8080
if !use {
8181
continue
8282
}
8383
fieldValue := structVal.Field(i)
8484

8585
if column == "" {
8686
// Embedded struct field
87-
err := reflectStructColumnPointers(fieldValue, namer, columns, pointers)
87+
err := reflectStructColumnPointers(fieldValue, mapper, columns, pointers)
8888
if err != nil {
8989
return err
9090
}

impl/scanstruct.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
sqldb "github.com/domonda/go-sqldb"
88
)
99

10-
func ScanStruct(srcRow Row, destStruct any, namer sqldb.StructFieldMapper) error {
10+
func ScanStruct(srcRow Row, destStruct any, mapper sqldb.StructFieldMapper) error {
1111
v := reflect.ValueOf(destStruct)
1212
for v.Kind() == reflect.Ptr && !v.IsNil() {
1313
v = v.Elem()
@@ -35,7 +35,7 @@ func ScanStruct(srcRow Row, destStruct any, namer sqldb.StructFieldMapper) error
3535
return err
3636
}
3737

38-
fieldPointers, err := ReflectStructColumnPointers(v, namer, columns)
38+
fieldPointers, err := ReflectStructColumnPointers(v, mapper, columns)
3939
if err != nil {
4040
return fmt.Errorf("ScanStruct: %w", err)
4141
}

0 commit comments

Comments
 (0)