@@ -13,7 +13,7 @@ import (
13
13
// for an existing sql.DB connection.
14
14
// argFmt is the format string for argument placeholders like "?" or "$%d"
15
15
// that will be replaced error messages to format a complete query.
16
- func NewGenericConnection (ctx context.Context , db * sql.DB , config * sqldb.Config , listener sqldb.Listener , validateColumnName func (string ) error , argFmt string ) sqldb.Connection {
16
+ func NewGenericConnection (ctx context.Context , db * sql.DB , config * sqldb.Config , listener sqldb.Listener , structFieldMapper sqldb. StructFieldMapper , validateColumnName func (string ) error , argFmt string ) sqldb.Connection {
17
17
if listener == nil {
18
18
listener = sqldb .UnsupportedListener ()
19
19
}
@@ -22,9 +22,9 @@ func NewGenericConnection(ctx context.Context, db *sql.DB, config *sqldb.Config,
22
22
db : db ,
23
23
config : config ,
24
24
listener : listener ,
25
- structFieldNamer : sqldb .DefaultStructFieldMapping ,
26
- argFmt : argFmt ,
25
+ structFieldMapper : structFieldMapper ,
27
26
validateColumnName : validateColumnName ,
27
+ argFmt : argFmt ,
28
28
}
29
29
}
30
30
@@ -33,9 +33,9 @@ type genericConn struct {
33
33
db * sql.DB
34
34
config * sqldb.Config
35
35
listener sqldb.Listener
36
- structFieldNamer sqldb.StructFieldMapper
37
- argFmt string
36
+ structFieldMapper sqldb.StructFieldMapper
38
37
validateColumnName func (string ) error
38
+ argFmt string
39
39
40
40
tx * sql.Tx
41
41
txOptions * sql.TxOptions
@@ -58,14 +58,18 @@ func (conn *genericConn) WithContext(ctx context.Context) sqldb.Connection {
58
58
return c
59
59
}
60
60
61
- func (conn * genericConn ) WithStructFieldMapper (namer sqldb.StructFieldMapper ) sqldb.Connection {
61
+ func (conn * genericConn ) WithStructFieldMapper (mapper sqldb.StructFieldMapper ) sqldb.Connection {
62
62
c := conn .clone ()
63
- c .structFieldNamer = namer
63
+ c .structFieldMapper = mapper
64
64
return c
65
65
}
66
66
67
67
func (conn * genericConn ) StructFieldMapper () sqldb.StructFieldMapper {
68
- return conn .structFieldNamer
68
+ return conn .structFieldMapper
69
+ }
70
+
71
+ func (conn * genericConn ) ValidateColumnName (name string ) error {
72
+ return conn .validateColumnName (name )
69
73
}
70
74
71
75
func (conn * genericConn ) Ping (timeout time.Duration ) error {
@@ -86,40 +90,50 @@ func (conn *genericConn) Config() *sqldb.Config {
86
90
return conn .config
87
91
}
88
92
89
- func (conn * genericConn ) ValidateColumnName (name string ) error {
90
- return conn .validateColumnName (name )
91
- }
92
-
93
93
func (conn * genericConn ) Now () (time.Time , error ) {
94
94
return Now (conn )
95
95
}
96
96
97
+ func (conn * genericConn ) execer () Execer {
98
+ if conn .tx != nil {
99
+ return conn .tx
100
+ }
101
+ return conn .db
102
+ }
103
+
104
+ func (conn * genericConn ) queryer () Queryer {
105
+ if conn .tx != nil {
106
+ return conn .tx
107
+ }
108
+ return conn .db
109
+ }
110
+
97
111
func (conn * genericConn ) Exec (query string , args ... any ) error {
98
- return Exec (conn .ctx , conn .db , conn .argFmt , query , args )
112
+ return Exec (conn .ctx , conn .execer () , conn .argFmt , query , args )
99
113
}
100
114
101
115
func (conn * genericConn ) Insert (table string , columValues sqldb.Values ) error {
102
- return Insert (conn , table , conn .argFmt , columValues )
116
+ return Insert (conn . ctx , conn . execer () , table , conn .argFmt , columValues )
103
117
}
104
118
105
119
func (conn * genericConn ) InsertUnique (table string , values sqldb.Values , onConflict string ) (inserted bool , err error ) {
106
- return InsertUnique (conn , table , conn .argFmt , values , onConflict )
120
+ return InsertUnique (conn . ctx , conn . queryer () , conn .structFieldMapper , conn . argFmt , table , values , onConflict )
107
121
}
108
122
109
123
func (conn * genericConn ) InsertReturning (table string , values sqldb.Values , returning string ) sqldb.RowScanner {
110
124
return InsertReturning (conn , table , conn .argFmt , values , returning )
111
125
}
112
126
113
127
func (conn * genericConn ) InsertStruct (table string , rowStruct any , ignoreColumns ... sqldb.ColumnFilter ) error {
114
- return InsertStruct (conn , table , rowStruct , conn .structFieldNamer , conn .argFmt , ignoreColumns )
128
+ return InsertStruct (conn , table , rowStruct , conn .structFieldMapper , conn .argFmt , ignoreColumns )
115
129
}
116
130
117
131
func (conn * genericConn ) InsertStructs (table string , rowStructs any , ignoreColumns ... sqldb.ColumnFilter ) error {
118
132
return InsertStructs (conn , table , rowStructs , ignoreColumns ... )
119
133
}
120
134
121
135
func (conn * genericConn ) InsertUniqueStruct (table string , rowStruct any , onConflict string , ignoreColumns ... sqldb.ColumnFilter ) (inserted bool , err error ) {
122
- return InsertUniqueStruct (conn , table , rowStruct , onConflict , conn .structFieldNamer , conn .argFmt , ignoreColumns )
136
+ return InsertUniqueStruct (conn , table , rowStruct , onConflict , conn .structFieldMapper , conn .argFmt , ignoreColumns )
123
137
}
124
138
125
139
func (conn * genericConn ) Update (table string , values sqldb.Values , where string , args ... any ) error {
@@ -135,19 +149,31 @@ func (conn *genericConn) UpdateReturningRows(table string, values sqldb.Values,
135
149
}
136
150
137
151
func (conn * genericConn ) UpdateStruct (table string , rowStruct any , ignoreColumns ... sqldb.ColumnFilter ) error {
138
- return UpdateStruct (conn , table , rowStruct , conn .structFieldNamer , conn .argFmt , ignoreColumns )
152
+ return UpdateStruct (conn , table , rowStruct , conn .structFieldMapper , conn .argFmt , ignoreColumns )
139
153
}
140
154
141
155
func (conn * genericConn ) UpsertStruct (table string , rowStruct any , ignoreColumns ... sqldb.ColumnFilter ) error {
142
- return UpsertStruct (conn , table , rowStruct , conn .structFieldNamer , conn .argFmt , ignoreColumns )
156
+ return UpsertStruct (conn , table , rowStruct , conn .structFieldMapper , conn .argFmt , ignoreColumns )
143
157
}
144
158
145
159
func (conn * genericConn ) QueryRow (query string , args ... any ) sqldb.RowScanner {
146
- return QueryRow (conn .ctx , conn .db , conn .structFieldNamer , conn .argFmt , query , args )
160
+ var queryer Queryer
161
+ if conn .tx != nil {
162
+ queryer = conn .tx
163
+ } else {
164
+ queryer = conn .db
165
+ }
166
+ return QueryRow (conn .ctx , queryer , conn .structFieldMapper , conn .argFmt , query , args )
147
167
}
148
168
149
169
func (conn * genericConn ) QueryRows (query string , args ... any ) sqldb.RowsScanner {
150
- return QueryRows (conn .ctx , conn .db , conn .structFieldNamer , conn .argFmt , query , args )
170
+ var queryer Queryer
171
+ if conn .tx != nil {
172
+ queryer = conn .tx
173
+ } else {
174
+ queryer = conn .db
175
+ }
176
+ return QueryRows (conn .ctx , queryer , conn .structFieldMapper , conn .argFmt , query , args )
151
177
}
152
178
153
179
func (conn * genericConn ) IsTransaction () bool {
@@ -169,10 +195,10 @@ func (conn *genericConn) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection
169
195
}
170
196
return & genericConn {
171
197
ctx : conn .ctx ,
172
- db : nil ,
198
+ db : conn . db , // needed for PingContext, Stats
173
199
config : conn .config ,
174
200
listener : conn .listener ,
175
- structFieldNamer : conn .structFieldNamer ,
201
+ structFieldMapper : conn .structFieldMapper ,
176
202
argFmt : conn .argFmt ,
177
203
validateColumnName : conn .validateColumnName ,
178
204
tx : tx ,
0 commit comments