diff --git a/anyvalue.go b/anyvalue.go index 9c660aa..39fa86b 100644 --- a/anyvalue.go +++ b/anyvalue.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "slices" "unicode/utf8" ) @@ -14,23 +15,24 @@ var ( _ fmt.GoStringer = AnyValue{} ) -// AnyValue wraps a driver.Value and is useful for +// AnyValue can hold any value and is useful for // generic code that can handle unknown column types. // // AnyValue implements the following interfaces: -// database/sql.Scanner -// database/sql/driver.Valuer -// fmt.Stringer -// fmt.GoStringer +// - database/sql.Scanner +// - database/sql/driver.Valuer +// - fmt.Stringer +// - fmt.GoStringer // -// When scanned, Val can have one of the following underlying types: -// int64 -// float64 -// bool -// []byte -// string -// time.Time -// nil - for NULL values +// When scanned with the Scan method +// Val will have one of the following types: +// - int64 +// - float64 +// - bool +// - []byte +// - string +// - time.Time +// - nil (for SQL NULL values) type AnyValue struct { Val any } @@ -39,7 +41,7 @@ type AnyValue struct { func (any *AnyValue) Scan(val any) error { if b, ok := val.([]byte); ok { // Copy bytes because they won't be valid after this method call - any.Val = append([]byte(nil), b...) + any.Val = slices.Clone(b) } else { any.Val = val } @@ -52,7 +54,7 @@ func (any AnyValue) Value() (driver.Value, error) { } // String returns the value formatted as string using fmt.Sprint -// except when it's of type []byte and valid UTF-8, +// except when it is of type []byte and valid UTF-8, // then it is directly converted into a string. func (any AnyValue) String() string { if b, ok := any.Val.([]byte); ok && utf8.Valid(b) { @@ -64,7 +66,7 @@ func (any AnyValue) String() string { // GoString returns a Go representation of the wrapped value. func (any AnyValue) GoString() string { if b, ok := any.Val.([]byte); ok && utf8.Valid(b) { - return fmt.Sprintf("[]byte(%q)", b) + return fmt.Sprintf("[]byte(%#v)", string(b)) } return fmt.Sprintf("%#v", any.Val) } diff --git a/impl/arrays.go b/arrays.go similarity index 50% rename from impl/arrays.go rename to arrays.go index d83fb7a..42e5f64 100644 --- a/impl/arrays.go +++ b/arrays.go @@ -1,22 +1,56 @@ -package impl +package sqldb import ( "database/sql" - "database/sql/driver" "reflect" - - "github.com/lib/pq" ) -func WrapForArray(a interface{}) interface { - driver.Valuer - sql.Scanner -} { - return pq.Array(a) +// func WrapForArray(a any) interface { +// driver.Valuer +// sql.Scanner +// } { +// return pq.Array(a) +// } + +type ArrayHandler interface { + AsArrayScanner(dest any) sql.Scanner +} + +func MakeArrayScannable(dest []any, arrayHandler ArrayHandler) []any { + if arrayHandler == nil { + return dest + } + var wrappedDest []any + for i, d := range dest { + if ShouldWrapForArrayScanning(reflect.ValueOf(d).Elem()) { + if wrappedDest == nil { + // Allocate new slice for wrapped element + wrappedDest = make([]any, len(dest)) + // Copy previous elements + for h := 0; h < i; h++ { + wrappedDest[h] = dest[h] + } + } + wrappedDest[i] = arrayHandler.AsArrayScanner(d) + } else if wrappedDest != nil { + wrappedDest[i] = d + } + } + if wrappedDest != nil { + return wrappedDest + } + return dest } -func ShouldWrapForArray(v reflect.Value) bool { +func ShouldWrapForArrayScanning(v reflect.Value) bool { t := v.Type() + if t.Implements(typeOfSQLScanner) { + return false + } + if t.Kind() == reflect.Pointer && !v.IsNil() { + v = v.Elem() + t = v.Type() + } switch t.Kind() { case reflect.Slice: if t.Elem() == typeOfByte { @@ -29,6 +63,64 @@ func ShouldWrapForArray(v reflect.Value) bool { return false } +// IsSliceOrArray returns true if passed value is a slice or array, +// or a pointer to a slice or array and in case of a slice +// not of type []byte. +func IsSliceOrArray(value any) bool { + if value == nil { + return false + } + v := reflect.ValueOf(value) + if v.Kind() == reflect.Pointer { + if v.IsNil() { + return false + } + v = v.Elem() + } + t := v.Type() + k := t.Kind() + return k == reflect.Slice && t != typeOfByteSlice || k == reflect.Array +} + +// IsNonDriverValuerSliceOrArrayType returns true if passed type +// does not implement driver.Valuer and is a slice or array, +// or a pointer to a slice or array and in case of a slice +// not of type []byte. +func IsNonDriverValuerSliceOrArrayType(t reflect.Type) bool { + if t == nil || t.Implements(typeOfDriverValuer) { + return false + } + k := t.Kind() + if k == reflect.Pointer { + t = t.Elem() + k = t.Kind() + } + return k == reflect.Slice && t != typeOfByteSlice || k == reflect.Array +} + +// func FormatArrays(args []any) []any { +// var wrappedArgs []any +// for i, arg := range args { +// if ShouldFormatArray(arg) { +// if wrappedArgs == nil { +// // Allocate new slice for wrapped element +// wrappedArgs = make([]any, len(args)) +// // Copy previous elements +// for h := 0; h < i; h++ { +// wrappedArgs[h] = args[h] +// } +// } +// wrappedArgs[i], _ = pq.Array(arg).Value() +// } else if wrappedArgs != nil { +// wrappedArgs[i] = arg +// } +// } +// if wrappedArgs != nil { +// return wrappedArgs +// } +// return args +// } + // type ArrayScanner struct { // Dest reflect.Value // } @@ -81,7 +173,7 @@ func ShouldWrapForArray(v reflect.Value) bool { // } else { // newDest = reflect.New(a.Dest.Type()).Elem() // } -// if reflect.PtrTo(elemType).Implements(typeOfSQLScanner) { +// if reflect.PointerTo(elemType).Implements(typeOfSQLScanner) { // for i, elemStr := range elems { // err = newDest.Index(i).Addr().Interface().(sql.Scanner).Scan(elemStr) // if err != nil { diff --git a/arrays_test.go b/arrays_test.go new file mode 100644 index 0000000..56f435d --- /dev/null +++ b/arrays_test.go @@ -0,0 +1,94 @@ +package sqldb + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "reflect" + "testing" + + "github.com/lib/pq" + "github.com/stretchr/testify/assert" + + "github.com/domonda/go-types/nullable" +) + +func defaultAsArrayScanner(a any) sql.Scanner { + return pq.Array(a) // TODO replace with own implementation +} + +func TestShouldWrapForArrayScanning(t *testing.T) { + tests := []struct { + v reflect.Value + want bool + }{ + {v: reflect.ValueOf([]byte(nil)), want: false}, + {v: reflect.ValueOf([]byte{}), want: false}, + {v: reflect.ValueOf(""), want: false}, + {v: reflect.ValueOf(0), want: false}, + {v: reflect.ValueOf(json.RawMessage([]byte("null"))), want: false}, + {v: reflect.ValueOf(nullable.JSON([]byte("null"))), want: false}, + {v: reflect.ValueOf(new(sql.NullInt64)).Elem(), want: false}, + {v: reflect.ValueOf(defaultAsArrayScanner([]int{0, 1})), want: false}, + + {v: reflect.ValueOf(new([3]string)).Elem(), want: true}, + {v: reflect.ValueOf(new([]string)).Elem(), want: true}, + {v: reflect.ValueOf(new([]sql.NullString)).Elem(), want: true}, + } + for _, tt := range tests { + got := ShouldWrapForArrayScanning(tt.v) + assert.Equal(t, tt.want, got) + } +} + +func TestIsNonDriverValuerSliceOrArrayType(t *testing.T) { + tests := []struct { + t reflect.Type + want bool + }{ + {t: reflect.TypeOf(nil), want: false}, + {t: reflect.TypeOf(0), want: false}, + {t: reflect.TypeOf(new(int)), want: false}, + {t: reflect.TypeOf("string"), want: false}, + {t: reflect.TypeOf([]byte("string")), want: false}, + {t: reflect.TypeOf(new([]byte)), want: false}, + {t: reflect.TypeOf(pq.BoolArray{true}), want: false}, + {t: reflect.TypeOf(new(pq.BoolArray)), want: false}, + {t: reflect.TypeOf(new(*[]int)), want: false}, // pointer to a pointer to a slice + {t: reflect.TypeOf((*driver.Valuer)(nil)), want: false}, + {t: reflect.TypeOf((*driver.Valuer)(nil)).Elem(), want: false}, + + {t: reflect.TypeOf([3]int{1, 2, 3}), want: true}, + {t: reflect.TypeOf((*[3]int)(nil)), want: true}, + {t: reflect.TypeOf([]int{1, 2, 3}), want: true}, + {t: reflect.TypeOf((*[]int)(nil)), want: true}, + {t: reflect.TypeOf((*[][]byte)(nil)), want: true}, + } + for _, tt := range tests { + got := IsNonDriverValuerSliceOrArrayType(tt.t) + assert.Equalf(t, tt.want, got, "IsNonDriverValuerSliceOrArrayType(%s)", tt.t) + } +} + +// func TestWrapArgsForArrays(t *testing.T) { +// tests := []struct { +// args []any +// want []any +// }{ +// {args: nil, want: nil}, +// {args: []any{}, want: []any{}}, +// {args: []any{0}, want: []any{0}}, +// {args: []any{nil}, want: []any{nil}}, +// {args: []any{new(int)}, want: []any{new(int)}}, +// {args: []any{0, []int{0, 1}, "string"}, want: []any{0, wrapArgForArray([]int{0, 1}), "string"}}, +// {args: []any{wrapArgForArray([]int{0, 1})}, want: []any{wrapArgForArray([]int{0, 1})}}, +// {args: []any{[]byte("don't wrap []byte")}, want: []any{[]byte("don't wrap []byte")}}, +// {args: []any{pq.BoolArray{true}}, want: []any{pq.BoolArray{true}}}, +// {args: []any{[3]int{1, 2, 3}}, want: []any{wrapArgForArray([3]int{1, 2, 3})}}, +// {args: []any{wrapArgForArray([3]int{1, 2, 3})}, want: []any{wrapArgForArray([3]int{1, 2, 3})}}, +// } +// for _, tt := range tests { +// got := WrapArgsForArrays(tt.args) +// assert.Equal(t, tt.want, got) +// } +// } diff --git a/config.go b/config.go index 071fa7e..6b90649 100644 --- a/config.go +++ b/config.go @@ -47,18 +47,12 @@ type Config struct { // // If ConnMaxLifetime <= 0, connections are not closed due to a connection's age. ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"` - - DefaultIsolationLevel sql.IsolationLevel `json:"-"` - Err error `json:"-"` } // Validate returns Config.Err if it is not nil // or an error if the Config does not have // a Driver, Host, or Database. func (c *Config) Validate() error { - if c.Err != nil { - return c.Err - } if c.Driver == "" { return fmt.Errorf("missing sqldb.Config.Driver") } diff --git a/connection.go b/connection.go index be4fbae..a9c63bc 100644 --- a/connection.go +++ b/connection.go @@ -1,180 +1,111 @@ package sqldb import ( - "context" "database/sql" + "errors" "time" + + "golang.org/x/net/context" ) -type ( - // OnNotifyFunc is a callback type passed to Connection.ListenOnChannel - OnNotifyFunc func(channel, payload string) +var ( + globalConnection Connection = ErrorConnection(errors.New("sqldb not initialized, use sqldb.SetGlobalConnection!")) - // OnUnlistenFunc is a callback type passed to Connection.ListenOnChannel - OnUnlistenFunc func(channel string) + connectionCtxKey int ) -// Connection represents a database connection or transaction -type Connection interface { - // Context that all connection operations use. - // See also WithContext. - Context() context.Context +// GlobalConnection returns the global connection +// that will never be nil but an ErrorConnection +// if not initialized with SetGlobalConnection. +func GlobalConnection() Connection { + return globalConnection +} - // WithContext returns a connection that uses the passed - // context for its operations. - WithContext(ctx context.Context) Connection +// SetGlobalConnection sets the global connection +// that will be returned by ContextConnection +// if the context has no other connection. +// +// This function is not thread-safe becaues the global connection +// is not expected to change between threads. +func SetGlobalConnection(conn Connection) { + if conn == nil { + panic(" Connection") + } + globalConnection = conn +} - // WithStructFieldMapper returns a copy of the connection - // that will use the passed StructFieldMapper. - WithStructFieldMapper(StructFieldMapper) Connection +// ContextConnection returns the connection added +// to the context or the global connection +// if the context has no connection. +// +// See ContextWithConnection and SetGlobalConnection. +func ContextConnection(ctx context.Context) Connection { + return ContextConnectionOr(ctx, globalConnection) +} - // StructFieldMapper used by methods of this Connection. - StructFieldMapper() StructFieldMapper +// ContextConnectionOr returns the connection added +// to the context or the passed defaultConn +// if the context has no connection. +// +// See ContextWithConnection and SetGlobalConnection. +func ContextConnectionOr(ctx context.Context, defaultConn Connection) Connection { + if ctxConn, ok := ctx.Value(&connectionCtxKey).(Connection); ok { + return ctxConn + } + return globalConnection +} - // Ping returns an error if the database - // does not answer on this connection - // with an optional timeout. - // The passed timeout has to be greater zero - // to be considered. - Ping(timeout time.Duration) error +// ContextWithConnection returns a new context with the passed Connection +// added as value so it can be retrieved again using ContextConnection(ctx). +func ContextWithConnection(ctx context.Context, conn Connection) context.Context { + if conn == nil { + panic(" Connection") + } + return context.WithValue(ctx, &connectionCtxKey, conn) +} - // Stats returns the sql.DBStats of this connection. - Stats() sql.DBStats +type FullyFeaturedConnection interface { + Connection + TxConnection + NotificationConnection +} + +// Connection represents a database connection or transaction +type Connection interface { + QueryFormatter + StructFieldMapper + + // String returns information about the connection + String() string + + // DatabaseKind returns the vendor name of the database kind + DatabaseKind() string + + // Err returns an error if the connection + // is in some non-working state. + Err() error // Config returns the configuration used // to create this connection. Config() *Config - // ValidateColumnName returns an error - // if the passed name is not valid for a - // column of the connection's database. - ValidateColumnName(name string) error - - // Now returns the result of the SQL now() - // function for the current connection. - // Useful for getting the timestamp of a - // SQL transaction for use in Go code. - Now() (time.Time, error) - - // Exec executes a query with optional args. - Exec(query string, args ...any) error - - // Insert a new row into table using the values. - Insert(table string, values Values) error - - // InsertUnique inserts a new row into table using the passed values - // or does nothing if the onConflict statement applies. - // Returns if a row was inserted. - InsertUnique(table string, values Values, onConflict string) (inserted bool, err error) - - // InsertReturning inserts a new row into table using values - // and returns values from the inserted row listed in returning. - InsertReturning(table string, values Values, returning string) RowScanner - - // InsertStruct inserts a new row into table using the connection's - // StructFieldMapper to map struct fields to column names. - // Optional ColumnFilter can be passed to ignore mapped columns. - InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // InsertStructs inserts a slice or array of structs - // as new rows into table using the connection's - // StructFieldMapper to map struct fields to column names. - // Optional ColumnFilter can be passed to ignore mapped columns. - // - // TODO optimized version with single query if possible - // split into multiple queries depending or maxArgs for query - InsertStructs(table string, rowStructs any, ignoreColumns ...ColumnFilter) error - - // InsertUniqueStruct inserts a new row into table using the connection's - // StructFieldMapper to map struct fields to column names. - // Optional ColumnFilter can be passed to ignore mapped columns. - // Does nothing if the onConflict statement applies - // and returns if a row was inserted. - InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...ColumnFilter) (inserted bool, err error) - - // Update table rows(s) with values using the where statement with passed in args starting at $1. - Update(table string, values Values, where string, args ...any) error - - // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 - // and returning a single row with the columns specified in returning argument. - UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner - - // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 - // and returning multiple rows with the columns specified in returning argument. - UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner - - // UpdateStruct updates a row in a table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // UpsertStruct upserts a row to table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - // If inserting conflicts on the primary key column(s), then an update is performed. - UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // QueryRow queries a single row and returns a RowScanner for the results. - QueryRow(query string, args ...any) RowScanner - - // QueryRows queries multiple rows and returns a RowsScanner for the results. - QueryRows(query string, args ...any) RowsScanner + // DBStats returns the sql.DBStats of this connection. + DBStats() sql.DBStats // IsTransaction returns if the connection is a transaction IsTransaction() bool - // TransactionNo returns the globally unique number of the transaction - // or zero if the connection is not a transaction. - // Implementations should use the package function NextTransactionNo - // to aquire a new number in a threadsafe way. - TransactionNo() uint64 - - // TransactionOptions returns the sql.TxOptions of the - // current transaction and true as second result value, - // or false if the connection is not a transaction. - TransactionOptions() (*sql.TxOptions, bool) - - // Begin a new transaction. - // If the connection is already a transaction then a brand - // new transaction will begin on the parent's connection. - // The passed no will be returnd from the transaction's - // Connection.TransactionNo method. - // Implementations should use the package function NextTransactionNo - // to aquire a new number in a threadsafe way. - Begin(opts *sql.TxOptions, no uint64) (Connection, error) - - // Commit the current transaction. - // Returns ErrNotWithinTransaction if the connection - // is not within a transaction. - Commit() error - - // Rollback the current transaction. - // Returns ErrNotWithinTransaction if the connection - // is not within a transaction. - Rollback() error - - // ListenOnChannel will call onNotify for every channel notification - // and onUnlisten if the channel gets unlistened - // or the listener connection gets closed for some reason. - // It is valid to pass nil for onNotify or onUnlisten to not get those callbacks. - // Note that the callbacks are called in sequence from a single go routine, - // so callbacks should offload long running or potentially blocking code to other go routines. - // Panics from callbacks will be recovered and logged. - ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error - - // UnlistenChannel will stop listening on the channel. - // An error is returned, when the channel was not listened to - // or the listener connection is closed. - UnlistenChannel(channel string) error - - // IsListeningOnChannel returns if a channel is listened to. - IsListeningOnChannel(channel string) bool + // Ping returns an error if the database + // does not answer on this connection + // with an optional timeout. + // The passed timeout has to be greater zero + // to be considered. + Ping(ctx context.Context, timeout time.Duration) error + + // Exec executes a query with optional args. + Exec(ctx context.Context, query string, args ...any) error + + Query(ctx context.Context, query string, args ...any) (Rows, error) // Close the connection. // Transactions will be rolled back. diff --git a/connectionimpl.go b/connectionimpl.go new file mode 100644 index 0000000..62612f6 --- /dev/null +++ b/connectionimpl.go @@ -0,0 +1,196 @@ +package sqldb + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "time" +) + +var ( + _ Connection = new(ConnectionImpl) + _ TxConnection = new(TxConnectionImpl) +) + +type ConnectionImpl struct { + StructFieldMapper + QueryFormatter + ArrayHandler + Kind string + DB *sql.DB + Conf *Config + ValueConverter driver.ValueConverter +} + +func (conn *ConnectionImpl) String() string { + return fmt.Sprintf("%s connection: %s", conn.DatabaseKind(), conn.Config().ConnectURL()) +} + +func (conn *ConnectionImpl) DatabaseKind() string { + return conn.Kind +} + +func (conn *ConnectionImpl) Config() *Config { + return conn.Conf +} + +func (conn *ConnectionImpl) Err() error { + return nil +} + +func (conn *ConnectionImpl) Ping(ctx context.Context, timeout time.Duration) error { + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + return conn.DB.PingContext(ctx) +} + +func (conn *ConnectionImpl) DBStats() sql.DBStats { + return conn.DB.Stats() +} + +func (conn *ConnectionImpl) IsTransaction() bool { + return false +} + +func (conn *ConnectionImpl) Exec(ctx context.Context, query string, args ...any) (err error) { + if conn.ValueConverter != nil { + for i, value := range args { + args[i], err = conn.ValueConverter.ConvertValue(value) + if err != nil { + return WrapErrorWithQuery(err, query, args, conn.QueryFormatter) + } + } + } + _, err = conn.DB.ExecContext(ctx, query, args...) + if err != nil { + return WrapErrorWithQuery(err, query, args, conn.QueryFormatter) + } + return nil +} + +func (conn *ConnectionImpl) Query(ctx context.Context, query string, args ...any) (rows Rows, err error) { + if conn.ValueConverter != nil { + for i, value := range args { + args[i], err = conn.ValueConverter.ConvertValue(value) + if err != nil { + return nil, WrapErrorWithQuery(err, query, args, conn.QueryFormatter) + } + } + } + rows, err = conn.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, WrapErrorWithQuery(err, query, args, conn.QueryFormatter) + } + return rows, nil +} + +func (conn *ConnectionImpl) Close() error { + return conn.DB.Close() +} + +type TxConnectionImpl struct { + ConnectionImpl + + DefaultLevel sql.IsolationLevel + Tx *sql.Tx + TxNo uint64 + TxOpts *sql.TxOptions +} + +func (conn *TxConnectionImpl) Exec(ctx context.Context, query string, args ...any) (err error) { + if conn.ValueConverter != nil { + for i, value := range args { + args[i], err = conn.ValueConverter.ConvertValue(value) + if err != nil { + return WrapErrorWithQuery(err, query, args, conn.QueryFormatter) + } + } + } + if conn.Tx != nil { + _, err = conn.Tx.ExecContext(ctx, query, args...) + } else { + _, err = conn.DB.ExecContext(ctx, query, args...) + } + if err != nil { + return WrapErrorWithQuery(err, query, args, conn.QueryFormatter) + } + return nil +} + +func (conn *TxConnectionImpl) Query(ctx context.Context, query string, args ...any) (rows Rows, err error) { + if conn.ValueConverter != nil { + for i, value := range args { + args[i], err = conn.ValueConverter.ConvertValue(value) + if err != nil { + return nil, WrapErrorWithQuery(err, query, args, conn.QueryFormatter) + } + } + } + if conn.Tx != nil { + rows, err = conn.Tx.QueryContext(ctx, query, args...) + } else { + rows, err = conn.DB.QueryContext(ctx, query, args...) + } + if err != nil { + return nil, WrapErrorWithQuery(err, query, args, conn.QueryFormatter) + } + return rows, nil +} + +func (conn *TxConnectionImpl) IsTransaction() bool { + return conn.Tx != nil +} + +func (conn *TxConnectionImpl) DefaultIsolationLevel() sql.IsolationLevel { + return conn.DefaultLevel +} + +func (conn *TxConnectionImpl) TxNumber() uint64 { + return conn.TxNo +} + +func (conn *TxConnectionImpl) TxOptions() (*sql.TxOptions, bool) { + return conn.TxOpts, conn.Tx != nil +} + +func (conn *TxConnectionImpl) Begin(ctx context.Context, opts *sql.TxOptions, no uint64) (TxConnection, error) { + if conn.Tx != nil { + return nil, ErrWithinTransaction + } + tx, err := conn.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &TxConnectionImpl{ + ConnectionImpl: conn.ConnectionImpl, + DefaultLevel: conn.DefaultLevel, + Tx: tx, + TxNo: no, + TxOpts: opts, + }, nil +} + +func (conn *TxConnectionImpl) Commit() error { + if conn.Tx == nil { + return ErrNotWithinTransaction + } + return conn.Tx.Commit() +} + +func (conn *TxConnectionImpl) Rollback() error { + if conn.Tx == nil { + return ErrNotWithinTransaction + } + return conn.Tx.Rollback() +} + +func (conn *TxConnectionImpl) Close() error { + if conn.Tx != nil { + return conn.Tx.Rollback() + } + return conn.ConnectionImpl.Close() +} diff --git a/db/config.go b/db/config.go index bd6c1b4..3a49c63 100644 --- a/db/config.go +++ b/db/config.go @@ -1,23 +1 @@ package db - -import ( - "context" - "errors" - - "github.com/domonda/go-sqldb" -) - -var ( - // Number of retries used for a SerializedTransaction - // before it fails - SerializedTransactionRetries = 10 -) - -var ( - globalConn = sqldb.ConnectionWithError( - context.Background(), - errors.New("database connection not initialized"), - ) - globalConnCtxKey int - serializedTransactionCtxKey int -) diff --git a/db/conn.go b/db/conn.go deleted file mode 100644 index 281e248..0000000 --- a/db/conn.go +++ /dev/null @@ -1,54 +0,0 @@ -package db - -import ( - "context" - - "github.com/domonda/go-sqldb" -) - -// SetConn sets the global connection returned by Conn -// if there is no other connection in the context passed to Conn. -func SetConn(c sqldb.Connection) { - if c == nil { - panic("must not set nil sqldb.Connection") - } - globalConn = c -} - -// Conn returns a non nil sqldb.Connection from ctx -// or the global connection set with SetConn. -// The returned connection will use the passed context. -// See sqldb.Connection.WithContext -func Conn(ctx context.Context) sqldb.Connection { - return ConnDefault(ctx, globalConn) -} - -// ConnDefault returns a non nil sqldb.Connection from ctx -// or the passed defaultConn. -// The returned connection will use the passed context. -// See sqldb.Connection.WithContext -func ConnDefault(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connection { - c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection) - if c == nil { - c = defaultConn - } - if c.Context() == ctx { - return c - } - return c.WithContext(ctx) -} - -// ContextWithConn returns a new context with the passed sqldb.Connection -// added as value so it can be retrieved again using Conn(ctx). -// Passing a nil connection causes Conn(ctx) -// to return the global connection set with SetConn. -func ContextWithConn(ctx context.Context, conn sqldb.Connection) context.Context { - return context.WithValue(ctx, &globalConnCtxKey, conn) -} - -// IsTransaction indicates if the connection from the context, -// or the default connection if the context has none, -// is a transaction. -func IsTransaction(ctx context.Context) bool { - return Conn(ctx).IsTransaction() -} diff --git a/db/query.go b/db/query.go index d99c4e0..c5fbb21 100644 --- a/db/query.go +++ b/db/query.go @@ -2,181 +2,10 @@ package db import ( "context" - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "time" "github.com/domonda/go-sqldb" ) -// Now returns the result of the SQL now() -// function for the current connection. -// Useful for getting the timestamp of a -// SQL transaction for use in Go code. -func Now(ctx context.Context) (time.Time, error) { - return Conn(ctx).Now() -} - -// Exec executes a query with optional args. -func Exec(ctx context.Context, query string, args ...any) error { - return Conn(ctx).Exec(query, args...) -} - -// QueryRow queries a single row and returns a RowScanner for the results. -func QueryRow(ctx context.Context, query string, args ...any) sqldb.RowScanner { - return Conn(ctx).QueryRow(query, args...) -} - -// QueryRows queries multiple rows and returns a RowsScanner for the results. -func QueryRows(ctx context.Context, query string, args ...any) sqldb.RowsScanner { - return Conn(ctx).QueryRows(query, args...) -} - -// QueryValue queries a single value of type T. -func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) - if err != nil { - var zero T - return zero, err - } - return value, nil -} - -// QueryValueOr queries a single value of type T -// or returns the passed defaultValue in case of sql.ErrNoRows. -func QueryValueOr[T any](ctx context.Context, defaultValue T, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return defaultValue, nil - } - var zero T - return zero, err - } - return value, err -} - -// QueryRowStruct queries a row and scans it as struct. -func QueryRowStruct[S any](ctx context.Context, query string, args ...any) (row *S, err error) { - err = Conn(ctx).QueryRow(query, args...).ScanStruct(&row) - if err != nil { - return nil, err - } - return row, nil -} - -// QueryRowStructOrNil queries a row and scans it as struct -// or returns nil in case of sql.ErrNoRows. -func QueryRowStructOrNil[S any](ctx context.Context, query string, args ...any) (row *S, err error) { - err = Conn(ctx).QueryRow(query, args...).ScanStruct(&row) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, err - } - return row, nil -} - -// GetRow uses the passed pkValue+pkValues to query a table row -// and scan it into a struct of type S that must have tagged fields -// with primary key flags to identify the primary key column names -// for the passed pkValue+pkValues and a table name. -func GetRow[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { - // Using explicit first pkValue value - // to not be able to compile without any value - pkValues = append([]any{pkValue}, pkValues...) - t := reflect.TypeOf(row).Elem() - if t.Kind() != reflect.Struct { - return nil, fmt.Errorf("expected struct template type instead of %s", t) - } - conn := Conn(ctx) - table, pkColumns, err := pkColumnsOfStruct(conn, t) - if err != nil { - return nil, err - } - if len(pkColumns) != len(pkValues) { - return nil, fmt.Errorf("got %d primary key values, but struct %s has %d primary key fields", len(pkValues), t, len(pkColumns)) - } - var query strings.Builder - fmt.Fprintf(&query, `SELECT * FROM %s WHERE "%s" = $1`, table, pkColumns[0]) //#nosec G104 - for i := 1; i < len(pkColumns); i++ { - fmt.Fprintf(&query, ` AND "%s" = $%d`, pkColumns[i], i+1) //#nosec G104 - } - err = conn.QueryRow(query.String(), pkValues...).ScanStruct(&row) - if err != nil { - return nil, err - } - return row, nil -} - -// GetRowOrNil uses the passed pkValue+pkValues to query a table row -// and scan it into a struct of type S that must have tagged fields -// with primary key flags to identify the primary key column names -// for the passed pkValue+pkValues and a table name. -// Returns nil as row and error if no row could be found with the -// passed pkValue+pkValues. -func GetRowOrNil[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { - row, err = GetRow[S](ctx, pkValue, pkValues...) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, err - } - return row, nil -} - -func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) { - mapper := conn.StructFieldMapper() - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - fieldTable, column, flags, ok := mapper.MapStructField(field) - if !ok { - continue - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } - - if column == "" { - fieldTable, columnsEmbed, err := pkColumnsOfStruct(conn, field.Type) - if err != nil { - return "", nil, err - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } - columns = append(columns, columnsEmbed...) - } else if flags.PrimaryKey() { - if err = conn.ValidateColumnName(column); err != nil { - return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) - } - columns = append(columns, column) - } - } - return table, columns, nil -} - -// QueryStructSlice returns queried rows as slice of the generic type S -// which must be a struct or a pointer to a struct. -func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []S, err error) { - err = Conn(ctx).QueryRows(query, args...).ScanStructSlice(&rows) - if err != nil { - return nil, err - } - return rows, nil -} - // InsertStruct inserts a new row into table using the connection's // StructFieldMapper to map struct fields to column names. // Optional ColumnFilter can be passed to ignore mapped columns. diff --git a/db/transaction.go b/db/transaction.go deleted file mode 100644 index fb85419..0000000 --- a/db/transaction.go +++ /dev/null @@ -1,204 +0,0 @@ -package db - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/hex" - "errors" - "fmt" - "strings" - - "github.com/domonda/go-sqldb" -) - -// ValidateWithinTransaction returns sqldb.ErrNotWithinTransaction -// if the database connection from the context is not a transaction. -func ValidateWithinTransaction(ctx context.Context) error { - conn := Conn(ctx) - if err := conn.Config().Err; err != nil { - return err - } - if !conn.IsTransaction() { - return sqldb.ErrNotWithinTransaction - } - return nil -} - -// ValidateNotWithinTransaction returns sqldb.ErrWithinTransaction -// if the database connection from the context is a transaction. -func ValidateNotWithinTransaction(ctx context.Context) error { - conn := Conn(ctx) - if err := conn.Config().Err; err != nil { - return err - } - if conn.IsTransaction() { - return sqldb.ErrWithinTransaction - } - return nil -} - -// DebugNoTransaction executes nonTxFunc without a database transaction. -// Useful to temporarely replace Transaction to debug the same code without using a transaction. -func DebugNoTransaction(ctx context.Context, nonTxFunc func(context.Context) error) error { - return nonTxFunc(ctx) -} - -// IsolatedTransaction executes txFunc within a database transaction that is passed in to txFunc as tx Connection. -// IsolatedTransaction returns all errors from txFunc or transaction commit errors happening after txFunc. -// If parentConn is already a transaction, a brand new transaction will begin on the parent's connection. -// Errors and panics from txFunc will rollback the transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. -func IsolatedTransaction(ctx context.Context, txFunc func(context.Context) error) error { - return sqldb.IsolatedTransaction(Conn(ctx), nil, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) - }) -} - -// Transaction executes txFunc within a database transaction that is passed in to txFunc via the context. -// Use db.Conn(ctx) to get the transaction connection within txFunc. -// Transaction returns all errors from txFunc or transaction commit errors happening after txFunc. -// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx sqldb.Connection -// and no parentConn.Begin, Commit, or Rollback calls will occour within this Transaction call. -// Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. -func Transaction(ctx context.Context, txFunc func(context.Context) error) error { - return sqldb.Transaction(Conn(ctx), nil, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) - }) -} - -// SerializedTransaction executes txFunc "serially" within a database transaction that is passed in to txFunc via the context. -// Use db.Conn(ctx) to get the transaction connection within txFunc. -// Transaction returns all errors from txFunc or transaction commit errors happening after txFunc. -// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx sqldb.Connection -// and no parentConn.Begin, Commit, or Rollback calls will occour within this Transaction call. -// Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. -// -// Serialized transactions are typically necessary when an insert depends on a previous select within -// the transaction, but that pre-insert select can't lock the table like it's possible with SELECT FOR UPDATE. -// During transaction execution, the isolation level "Serializable" is set. This does not mean -// that the transaction will be run in series. On the contrary, it actually means that Postgres will -// track read/write dependencies and will report an error in case other concurrent transactions -// have altered the results of the statements within this transaction. If no serialisation is possible, -// raw Postgres error will be: -// ``` -// ERROR: could not serialize access due to read/write dependencies among transactions -// HINT: The transaction might succeed if retried. -// ``` -// or -// ``` -// ERROR: could not serialize access due to concurrent update -// HINT: The transaction might succeed if retried. -// ``` -// In this case, retry the whole transaction (as Postgres hints). This works simply -// because if you run the transaction for the second (or Nth) time, the queries will -// yield different results therefore altering the end result. -// -// SerializedTransaction calls can be nested, in which case nested calls just execute the -// txFunc within the parent's serialized transaction. -// It's not valid to nest a SerializedTransaction within a normal Transaction function -// because in this case serialization retries can't be delegated up to the -// partent transaction that doesn't know anything about serialization. -// -// Because of the retryable nature, please be careful with the size of the transaction and the retry cost. -func SerializedTransaction(ctx context.Context, txFunc func(context.Context) error) error { - // Pass nested serialized transactions through - if Conn(ctx).IsTransaction() { - if ctx.Value(&serializedTransactionCtxKey) == nil { - return errors.New("SerializedTransaction called from within a non-serialized transaction") - } - return txFunc(ctx) - } - - // Add value to context to check for nested serialized transactions - ctx = context.WithValue(ctx, &serializedTransactionCtxKey, struct{}{}) - - opts := sql.TxOptions{Isolation: sql.LevelSerializable} - for i := 0; i < SerializedTransactionRetries; i++ { - err := TransactionOpts(ctx, &opts, txFunc) - if err == nil || !strings.Contains(err.Error(), "could not serialize access") { - return err // nil or err - } - } - - return errors.New("SerializedTransaction retried too many times") -} - -// TransactionOpts executes txFunc within a database transaction with sql.TxOptions that is passed in to txFunc via the context. -// Use db.Conn(ctx) to get the transaction connection within txFunc. -// TransactionOpts returns all errors from txFunc or transaction commit errors happening after txFunc. -// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx sqldb.Connection -// and no parentConn.Begin, Commit, or Rollback calls will occour within this TransactionOpts call. -// Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. -func TransactionOpts(ctx context.Context, opts *sql.TxOptions, txFunc func(context.Context) error) error { - return sqldb.Transaction(Conn(ctx), opts, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) - }) -} - -// TransactionReadOnly executes txFunc within a read-only database transaction that is passed in to txFunc via the context. -// Use db.Conn(ctx) to get the transaction connection within txFunc. -// TransactionReadOnly returns all errors from txFunc or transaction commit errors happening after txFunc. -// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx sqldb.Connection -// and no parentConn.Begin, Commit, or Rollback calls will occour within this TransactionReadOnly call. -// Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. -func TransactionReadOnly(ctx context.Context, txFunc func(context.Context) error) error { - opts := sql.TxOptions{ReadOnly: true} - return sqldb.Transaction(Conn(ctx), &opts, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) - }) -} - -// TransactionSavepoint executes txFunc within a database transaction or uses savepoints for rollback. -// If the passed context already has a database transaction connection, -// then a savepoint with a random name is created before the execution of txFunc. -// If txFunc returns an error, then the transaction is rolled back to the savepoint -// but the transaction from the context is not rolled back. -// If the passed context does not have a database transaction connection, -// then Transaction(ctx, txFunc) is called without savepoints. -// Use db.Conn(ctx) to get the transaction connection within txFunc. -// TransactionSavepoint returns all errors from txFunc, transaction, savepoint, and rollback errors. -// Panics from txFunc are not recovered to rollback to the savepoint, -// they should behandled by the parent Transaction function. -func TransactionSavepoint(ctx context.Context, txFunc func(context.Context) error) error { - conn := Conn(ctx) - if !conn.IsTransaction() { - // If not already in a transaction, then execute txFunc - // within a as transaction instead of using savepoints: - return Transaction(ctx, txFunc) - } - - savepoint, err := randomSavepoint() - if err != nil { - return err - } - err = conn.Exec("savepoint " + savepoint) - if err != nil { - return err - } - - err = txFunc(ctx) - if err != nil { - e := conn.Exec("rollback to " + savepoint) - if e != nil && !errors.Is(e, sql.ErrTxDone) { - // Double error situation, wrap err with e so it doesn't get lost - err = fmt.Errorf("TransactionSavepoint error (%s) from rollback after error: %w", e, err) - } - return err - } - - return conn.Exec("release savepoint " + savepoint) -} - -func randomSavepoint() (string, error) { - b := make([]byte, 8) - _, err := rand.Read(b) - if err != nil { - return "", err - } - return "SP" + hex.EncodeToString(b), nil -} diff --git a/db/transaction_test.go b/db/transaction_test.go deleted file mode 100644 index a6e16ac..0000000 --- a/db/transaction_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package db - -import ( - "context" - "errors" - "os" - "testing" - - "github.com/domonda/go-sqldb/mockconn" -) - -func TestSerializedTransaction(t *testing.T) { - globalConn = mockconn.New(context.Background(), os.Stdout, nil) - - expectSerialized := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { - panic("not in transaction") - } - if ctx.Value(&serializedTransactionCtxKey) == nil { - panic("no SerializedTransaction") - } - return nil - } - - expectSerializedWithError := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { - panic("not in transaction") - } - if ctx.Value(&serializedTransactionCtxKey) == nil { - panic("no SerializedTransaction") - } - return errors.New("expected error") - } - - nestedSerializedTransaction := func(ctx context.Context) error { - return SerializedTransaction(ctx, expectSerialized) - } - - okNestedTransaction := func(ctx context.Context) error { - return Transaction(ctx, nestedSerializedTransaction) - } - - type args struct { - ctx context.Context - txFunc func(context.Context) error - } - tests := []struct { - name string - args args - wantErr bool - }{ - {name: "flat call", args: args{ctx: context.Background(), txFunc: expectSerialized}, wantErr: false}, - {name: "expect error", args: args{ctx: context.Background(), txFunc: expectSerializedWithError}, wantErr: true}, - {name: "nested call", args: args{ctx: context.Background(), txFunc: nestedSerializedTransaction}, wantErr: false}, - {name: "nested tx call", args: args{ctx: context.Background(), txFunc: okNestedTransaction}, wantErr: false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := SerializedTransaction(tt.args.ctx, tt.args.txFunc); (err != nil) != tt.wantErr { - t.Errorf("SerializedTransaction() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestTransaction(t *testing.T) { - globalConn = mockconn.New(context.Background(), os.Stdout, nil) - - expectNonSerialized := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { - panic("not in transaction") - } - if ctx.Value(&serializedTransactionCtxKey) != nil { - panic("SerializedTransaction") - } - return nil - } - - expectNonSerializedWithError := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { - panic("not in transaction") - } - if ctx.Value(&serializedTransactionCtxKey) != nil { - panic("SerializedTransaction") - } - return errors.New("expected error") - } - - nestedTransaction := func(ctx context.Context) error { - return Transaction(ctx, expectNonSerialized) - } - - nestedSerializedTransaction := func(ctx context.Context) error { - return SerializedTransaction(ctx, nestedTransaction) - } - - type args struct { - ctx context.Context - txFunc func(context.Context) error - } - tests := []struct { - name string - args args - wantErr bool - }{ - {name: "flat call", args: args{ctx: context.Background(), txFunc: expectNonSerialized}, wantErr: false}, - {name: "expected error", args: args{ctx: context.Background(), txFunc: expectNonSerializedWithError}, wantErr: true}, - {name: "nested call", args: args{ctx: context.Background(), txFunc: nestedTransaction}, wantErr: false}, - {name: "nested serialized", args: args{ctx: context.Background(), txFunc: nestedSerializedTransaction}, wantErr: true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := Transaction(tt.args.ctx, tt.args.txFunc); (err != nil) != tt.wantErr { - t.Errorf("Transaction() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/db/utils.go b/db/utils.go index 10f59c4..3385936 100644 --- a/db/utils.go +++ b/db/utils.go @@ -30,7 +30,7 @@ func IsOtherThanErrNoRows(err error) bool { // and the current time of the database using `select now()` // or an error if the time could not be queried. func DebugPrintConn(ctx context.Context, args ...any) { - opts, isTx := Conn(ctx).TransactionOptions() + opts, isTx := Conn(ctx).TxOptions() if isTx { args = append(args, "SQL-Transaction") if optsStr := TxOptionsString(opts); optsStr != "" { diff --git a/errorconnection.go b/errorconnection.go new file mode 100644 index 0000000..94d8bc2 --- /dev/null +++ b/errorconnection.go @@ -0,0 +1,129 @@ +package sqldb + +import ( + "context" + "database/sql" + "errors" + "reflect" + "time" +) + +// ErrorConnection returns a FullyFeaturedConnection +// where all methods that return errors +// return the passed error. +func ErrorConnection(err error) FullyFeaturedConnection { + if err == nil { + panic("nil error for ErrorConnection") + } + return errorConnection{err} +} + +type errorConnection struct { + err error +} + +func (e errorConnection) Err() error { + return e.err +} + +func (e errorConnection) String() string { + return e.err.Error() +} + +func (e errorConnection) DatabaseKind() string { + return e.err.Error() +} + +func (e errorConnection) StringLiteral(s string) string { + return defaultQueryFormatter{}.StringLiteral(s) +} + +func (e errorConnection) ArrayLiteral(array any) (string, error) { + return "", e.err +} + +func (e errorConnection) ParameterPlaceholder(index int) string { + return defaultQueryFormatter{}.ParameterPlaceholder(index) +} + +func (e errorConnection) ValidateColumnName(name string) error { + return e.err +} + +func (e errorConnection) MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) { + return "", "", 0, false +} + +func (e errorConnection) MaxParameters() int { return 0 } + +func (e errorConnection) DBStats() sql.DBStats { + return sql.DBStats{} +} + +func (e errorConnection) Config() *Config { + return &Config{Driver: "ErrorConnection"} +} + +func (e errorConnection) IsTransaction() bool { + return false +} + +func (e errorConnection) Ping(ctx context.Context, timeout time.Duration) error { + return errors.Join(e.err, ctx.Err()) +} + +func (e errorConnection) Exec(ctx context.Context, query string, args ...any) error { + return errors.Join(e.err, ctx.Err()) +} + +func (e errorConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) { + return nil, errors.Join(e.err, ctx.Err()) +} + +func (e errorConnection) DefaultIsolationLevel() sql.IsolationLevel { + return sql.LevelDefault +} + +func (e errorConnection) TxNumber() uint64 { + return 0 +} + +func (ce errorConnection) TxOptions() (*sql.TxOptions, bool) { + return nil, false +} + +func (e errorConnection) Begin(ctx context.Context, opts *sql.TxOptions, no uint64) (TxConnection, error) { + return nil, errors.Join(e.err, ctx.Err()) +} + +func (e errorConnection) Commit() error { + return e.err +} + +func (e errorConnection) Rollback() error { + return e.err +} + +func (e errorConnection) ListenChannel(ctx context.Context, channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) (cancel func() error, err error) { + return func() error { return e.err }, errors.Join(e.err, ctx.Err()) +} + +func (e errorConnection) UnlistenChannel(ctx context.Context, channel string, onNotify OnNotifyFunc) error { + return errors.Join(e.err, ctx.Err()) +} + +func (e errorConnection) IsListeningChannel(ctx context.Context, channel string) bool { + return false +} + +func (e errorConnection) ListeningChannels(ctx context.Context) ([]string, error) { + return nil, errors.Join(e.err, ctx.Err()) +} + +func (e errorConnection) NotifyChannel(ctx context.Context, channel, payload string) error { + return errors.Join(e.err, ctx.Err()) +} + +func (e errorConnection) Close() error { + return e.err +} diff --git a/errors.go b/errors.go index 6bbdaf6..9f20cb2 100644 --- a/errors.go +++ b/errors.go @@ -1,16 +1,10 @@ package sqldb import ( - "context" "database/sql" "errors" - "time" -) - -var ( - _ Connection = connectionWithError{} - _ RowScanner = rowScannerWithError{} - _ RowsScanner = rowsScannerWithError{} + "fmt" + "reflect" ) // ReplaceErrNoRows returns the passed replacement error @@ -29,6 +23,41 @@ func IsOtherThanErrNoRows(err error) bool { return err != nil && !errors.Is(err, sql.ErrNoRows) } +// WrapErrorWithQuery wraps non nil errors with a formatted query +// if the error was not already wrapped with a query. +// +// If the passed error is nil, then nil will be returned. +func WrapErrorWithQuery(err error, query string, args []any, formatter QueryFormatter) error { + var wrapped errWithQuery + if err == nil || errors.As(err, &wrapped) { + return err + } + return errWithQuery{err, query, args, formatter} +} + +// WrapResultErrorWithQuery wraps non nil errors referenced by errPtr +// with a formatted query if the error was not already wrapped with a query. +// +// If the passed error is nil, then nil will be returned. +func WrapResultErrorWithQuery(errPtr *error, query string, args []any, formatter QueryFormatter) { + if *errPtr != nil { + *errPtr = WrapErrorWithQuery(*errPtr, query, args, formatter) + } +} + +type errWithQuery struct { + err error + query string + args []any + formatter QueryFormatter +} + +func (e errWithQuery) Unwrap() error { return e.err } + +func (e errWithQuery) Error() string { + return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.args, e.formatter)) +} + // sentinelError implements the error interface for a string // and is meant to be used to declare const sentinel errors. // @@ -57,6 +86,30 @@ const ( ErrNullValueNotAllowed sentinelError = "null value not allowed" ) +// ErrColumnsWithoutStructFields + +type ErrColumnsWithoutStructFields struct { + Columns []string + Struct reflect.Value +} + +func (e ErrColumnsWithoutStructFields) Error() string { + return fmt.Sprintf("columns %#v has no mapped struct field in %s", e.Columns, e.Struct.Type()) +} + +// ErrStructFieldHasNoColumn + +type ErrStructFieldHasNoColumn struct { + StructField reflect.StructField + Columns []string +} + +func (e ErrStructFieldHasNoColumn) Error() string { + return fmt.Sprintf("struct field %s has no mapped column in %#v", e.StructField.Name, e.Columns) +} + +// ErrRaisedException + type ErrRaisedException struct { Message string } @@ -65,6 +118,8 @@ func (e ErrRaisedException) Error() string { return "raised exception: " + e.Message } +// ErrIntegrityConstraintViolation + type ErrIntegrityConstraintViolation struct { Constraint string } @@ -76,6 +131,8 @@ func (e ErrIntegrityConstraintViolation) Error() string { return "integrity constraint violation of constraint: " + e.Constraint } +// ErrRestrictViolation + type ErrRestrictViolation struct { Constraint string } @@ -91,6 +148,8 @@ func (e ErrRestrictViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } +// ErrNotNullViolation + type ErrNotNullViolation struct { Constraint string } @@ -106,6 +165,8 @@ func (e ErrNotNullViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } +// ErrForeignKeyViolation + type ErrForeignKeyViolation struct { Constraint string } @@ -121,6 +182,8 @@ func (e ErrForeignKeyViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } +// ErrUniqueViolation + type ErrUniqueViolation struct { Constraint string } @@ -136,6 +199,8 @@ func (e ErrUniqueViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } +// ErrCheckViolation + type ErrCheckViolation struct { Constraint string } @@ -151,6 +216,8 @@ func (e ErrCheckViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } +// ErrExclusionViolation + type ErrExclusionViolation struct { Constraint string } @@ -165,221 +232,3 @@ func (e ErrExclusionViolation) Error() string { func (e ErrExclusionViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } - -// ConnectionWithError - -// ConnectionWithError returns a dummy Connection -// where all methods return the passed error. -func ConnectionWithError(ctx context.Context, err error) Connection { - if err == nil { - panic("ConnectionWithError needs an error") - } - return connectionWithError{ctx, err} -} - -type connectionWithError struct { - ctx context.Context - err error -} - -func (e connectionWithError) Context() context.Context { return e.ctx } - -func (e connectionWithError) WithContext(ctx context.Context) Connection { - return connectionWithError{ctx: ctx, err: e.err} -} - -func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection { - return e -} - -func (e connectionWithError) StructFieldMapper() StructFieldMapper { - return DefaultStructFieldMapping -} - -func (e connectionWithError) Ping(time.Duration) error { - return e.err -} - -func (e connectionWithError) Stats() sql.DBStats { - return sql.DBStats{} -} - -func (e connectionWithError) Config() *Config { - return &Config{Err: e.err} -} - -func (e connectionWithError) ValidateColumnName(name string) error { - return e.err -} - -func (e connectionWithError) Now() (time.Time, error) { - return time.Time{}, e.err -} - -func (e connectionWithError) Exec(query string, args ...any) error { - return e.err -} - -func (e connectionWithError) Insert(table string, values Values) error { - return e.err -} - -func (e connectionWithError) InsertUnique(table string, values Values, onConflict string) (inserted bool, err error) { - return false, e.err -} - -func (e connectionWithError) InsertReturning(table string, values Values, returning string) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) InsertStructs(table string, rowStructs any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...ColumnFilter) (inserted bool, err error) { - return false, e.err -} - -func (e connectionWithError) Update(table string, values Values, where string, args ...any) error { - return e.err -} - -func (e connectionWithError) UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) IsTransaction() bool { - return false -} - -func (e connectionWithError) TransactionNo() uint64 { - return 0 -} - -func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (e connectionWithError) Begin(opts *sql.TxOptions, no uint64) (Connection, error) { - return nil, e.err -} - -func (e connectionWithError) Commit() error { - return e.err -} - -func (e connectionWithError) Rollback() error { - return e.err -} - -func (e connectionWithError) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { - return e.err -} - -func (e connectionWithError) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { - return e.err -} - -func (e connectionWithError) UnlistenChannel(channel string) error { - return e.err -} - -func (e connectionWithError) IsListeningOnChannel(channel string) bool { - return false -} - -func (e connectionWithError) Close() error { - return e.err -} - -// RowScannerWithError - -// RowScannerWithError returns a dummy RowScanner -// where all methods return the passed error. -func RowScannerWithError(err error) RowScanner { - return rowScannerWithError{err} -} - -type rowScannerWithError struct { - err error -} - -func (e rowScannerWithError) Scan(dest ...any) error { - return e.err -} - -func (e rowScannerWithError) ScanStruct(dest any) error { - return e.err -} - -func (e rowScannerWithError) ScanValues() ([]any, error) { - return nil, e.err -} - -func (e rowScannerWithError) ScanStrings() ([]string, error) { - return nil, e.err -} - -func (e rowScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -// RowsScannerWithError - -// RowsScannerWithError returns a dummy RowsScanner -// where all methods return the passed error. -func RowsScannerWithError(err error) RowsScanner { - return rowsScannerWithError{err} -} - -type rowsScannerWithError struct { - err error -} - -func (e rowsScannerWithError) ScanSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) ScanStructSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -func (e rowsScannerWithError) ScanAllRowsAsStrings(headerRow bool) ([][]string, error) { - return nil, e.err -} - -func (e rowsScannerWithError) ForEachRow(callback func(RowScanner) error) error { - return e.err -} - -func (e rowsScannerWithError) ForEachRowCall(callback any) error { - return e.err -} diff --git a/impl/errors_test.go b/errors_test.go similarity index 86% rename from impl/errors_test.go rename to errors_test.go index 6e544e5..c8229df 100644 --- a/impl/errors_test.go +++ b/errors_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql" @@ -7,7 +7,7 @@ import ( "testing" ) -func TestWrapNonNilErrorWithQuery(t *testing.T) { +func TestWrapErrorWithQuery(t *testing.T) { type args struct { err error query string @@ -33,7 +33,7 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := WrapNonNilErrorWithQuery(tt.args.err, tt.args.query, tt.args.argFmt, tt.args.args) + err := WrapErrorWithQuery(tt.args.err, tt.args.query, tt.args.args, defaultQueryFormatter{}) if tt.wantError == "" && err != nil || tt.wantError != "" && (err == nil || err.Error() != tt.wantError) { t.Errorf("WrapNonNilErrorWithQuery() error = %v, wantErr %v", err, tt.wantError) } diff --git a/impl/format.go b/format.go similarity index 55% rename from impl/format.go rename to format.go index 5f78ba3..1111899 100644 --- a/impl/format.go +++ b/format.go @@ -1,20 +1,84 @@ -package impl +package sqldb import ( "database/sql/driver" "encoding/hex" + "errors" "fmt" "reflect" "strings" "time" "unicode" "unicode/utf8" + + "github.com/lib/pq" ) const timeFormat = "'2006-01-02 15:04:05.999999Z07:00:00'" +// type StringFormatter interface { +// StringLiteral(string) string +// } + +// type StringFormatterFunc func(string) string + +// func (f StringFormatterFunc) StringLiteral(s string) string { +// return f(s) +// } + +type QueryFormatter interface { + StringLiteral(str string) string + ArrayLiteral(array any) (string, error) + ValidateColumnName(name string) error + ParameterPlaceholder(index int) string + MaxParameters() int +} + +type defaultQueryFormatter struct{} + +func (defaultQueryFormatter) StringLiteral(str string) string { + return defaultStringLiteral(str) +} + +func (defaultQueryFormatter) ArrayLiteral(array any) (string, error) { + value, err := pq.Array(array).Value() + if err != nil { + return "", fmt.Errorf("can't format %T as SQL array because: %w", array, err) + } + return fmt.Sprintf("'%s'", value), nil +} + +func (defaultQueryFormatter) ValidateColumnName(name string) error { + if name == `` || name == `""` { + return errors.New("empty column name") + } + if strings.ContainsFunc(name, unicode.IsSpace) { + return fmt.Errorf("column name %q contains whitespace", name) + } + if strings.ContainsFunc(name, unicode.IsControl) { + return fmt.Errorf("column name %q contains control characters", name) + } + return nil +} + +func (defaultQueryFormatter) ParameterPlaceholder(index int) string { + return fmt.Sprintf("$%d", index+1) +} + +func (defaultQueryFormatter) MaxParameters() int { return 1024 } + +// AlwaysFormatValue formats a value for debugging or logging SQL statements. +// In case of any problems fmt.Sprint(val) is returned. +func AlwaysFormatValue(val any, formatter QueryFormatter) string { + str, err := FormatValue(val, formatter) + if err != nil { + return fmt.Sprint(val) + } + return str +} + // FormatValue formats a value for debugging or logging SQL statements. -func FormatValue(val any) (string, error) { +func FormatValue(val any, formatter QueryFormatter) (string, error) { if val == nil { return "NULL", nil } @@ -23,7 +87,7 @@ func FormatValue(val any) (string, error) { switch x := val.(type) { case driver.Valuer: - if v.Kind() == reflect.Ptr && v.IsNil() { + if v.Kind() == reflect.Pointer && v.IsNil() { // Assume nil pointer implementing driver.Valuer is NULL // because if the method Value is implemented by value // the nil pointer will still implement driver.Valuer @@ -34,18 +98,18 @@ func FormatValue(val any) (string, error) { if err != nil { return "", err } - return FormatValue(value) + return FormatValue(value, formatter) case time.Time: return x.Format(timeFormat), nil } switch v.Kind() { - case reflect.Ptr: + case reflect.Pointer: if v.IsNil() { return "NULL", nil } - return FormatValue(v.Elem().Interface()) + return FormatValue(v.Elem().Interface(), formatter) case reflect.Bool: if v.Bool() { @@ -57,9 +121,10 @@ func FormatValue(val any) (string, error) { case reflect.String: s := v.String() if l := len(s); l >= 2 && (s[0] == '{' && s[l-1] == '}' || s[0] == '[' && s[l-1] == ']') { + // String is already an array literal, just quote it return `'` + s + `'`, nil } - return QuoteLiteral(s), nil + return formatter.StringLiteral(s), nil case reflect.Slice: if v.Type().Elem().Kind() == reflect.Uint8 { @@ -70,23 +135,27 @@ func FormatValue(val any) (string, error) { if l := len(b); l >= 2 && (b[0] == '{' && b[l-1] == '}' || b[0] == '[' && b[l-1] == ']') { return `'` + string(b) + `'`, nil } - return QuoteLiteral(string(b)), nil + return formatter.StringLiteral(string(b)), nil } + return formatter.ArrayLiteral(v.Interface()) + + case reflect.Array: + return formatter.ArrayLiteral(v.Interface()) } return fmt.Sprint(val), nil } -func FormatQuery(query, argFmt string, args ...any) string { +func FormatQuery(query string, args []any, naming QueryFormatter) string { + // Replace placeholders with formatted args for i := len(args) - 1; i >= 0; i-- { - placeholder := fmt.Sprintf(argFmt, i+1) - value, err := FormatValue(args[i]) - if err != nil { - value = "FORMATERROR:" + err.Error() - } - query = strings.ReplaceAll(query, placeholder, value) + placeholder := naming.ParameterPlaceholder(i) + formattedArg := AlwaysFormatValue(args[i], naming) + query = strings.ReplaceAll(query, placeholder, formattedArg) } + // Line endings and indentation: + lines := strings.Split(query, "\n") if len(lines) == 1 { return strings.TrimSpace(query) @@ -124,17 +193,7 @@ func FormatQuery(query, argFmt string, args ...any) string { return strings.Join(lines, "\n") } -// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal -// to DDL and other statements that do not accept parameters) to be used as part -// of an SQL statement. For example: -// -// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") -// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) -// -// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be -// replaced by two backslashes (i.e. "\\") and the C-style escape identifier -// that PostgreSQL provides ('E') will be prepended to the string. -func QuoteLiteral(literal string) string { +func defaultStringLiteral(literal string) string { // This follows the PostgreSQL internal algorithm for handling quoted literals // from libpq, which can be found in the "PQEscapeStringInternal" function, // which is found in the libpq/fe-exec.c source file: diff --git a/impl/format_test.go b/format_test.go similarity index 93% rename from impl/format_test.go rename to format_test.go index 8d6ab7c..1747337 100644 --- a/impl/format_test.go +++ b/format_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql/driver" @@ -36,7 +36,7 @@ func TestFormatValue(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := FormatValue(tt.val) + got, err := FormatValue(tt.val, defaultQueryFormatter{}) if (err != nil) != tt.wantErr { t.Errorf("FormatValue() error = %v, wantErr %v", err, tt.wantErr) return @@ -49,6 +49,8 @@ func TestFormatValue(t *testing.T) { } func TestFormatQuery(t *testing.T) { + formatter := defaultQueryFormatter{} + query1 := ` SELECT * @@ -90,7 +92,7 @@ WHERE } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := FormatQuery(tt.query, tt.argFmt, tt.args...); got != tt.want { + if got := FormatQuery(tt.query, tt.args, formatter); got != tt.want { t.Errorf("FormatQuery():\n%q\nWant:\n%q", got, tt.want) } }) diff --git a/go.mod b/go.mod index 0cabf0d..e541dd0 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module github.com/domonda/go-sqldb go 1.21 require ( - github.com/domonda/go-errs v0.0.0-20230810132956-1b6272f9fc8f - github.com/domonda/go-types v0.0.0-20230829145420-30f9974e0bc7 - github.com/go-sql-driver/mysql v1.7.1 + github.com/domonda/go-errs v0.0.0-20230920094343-6b122da4d22f + github.com/domonda/go-types v0.0.0-20230926122236-a75565adcd2b github.com/lib/pq v1.10.9 github.com/stretchr/testify v1.8.4 + golang.org/x/net v0.15.0 ) require ( diff --git a/go.sum b/go.sum index c4df604..025c18b 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/domonda/go-errs v0.0.0-20230810132956-1b6272f9fc8f h1:OQaXlKXZc52Vsz7iH23NhddeMr0niW0tetB8Fq3k4yQ= -github.com/domonda/go-errs v0.0.0-20230810132956-1b6272f9fc8f/go.mod h1:DYkFE3rxUGhTCMmR5MpQ2NTtoCPiORdjBATGkIEeGKM= +github.com/domonda/go-errs v0.0.0-20230920094343-6b122da4d22f h1:ECYzMHlxXTmVwOYKqAZf7wEri1QvMv7AX9u25zotV0k= +github.com/domonda/go-errs v0.0.0-20230920094343-6b122da4d22f/go.mod h1:DYkFE3rxUGhTCMmR5MpQ2NTtoCPiORdjBATGkIEeGKM= github.com/domonda/go-pretty v0.0.0-20230810130018-8920f571470a h1:b3a6MwwMrHR9dw6585e3Ky51T50OKuD3fRuLyh8ziEw= github.com/domonda/go-pretty v0.0.0-20230810130018-8920f571470a/go.mod h1:3QkM8UJdyJMeKZiIo7hYzSkQBpRS3k0gOHw4ysyEIB4= -github.com/domonda/go-types v0.0.0-20230829145420-30f9974e0bc7 h1:riEK9SQ1O0ADGI66P1Rz2zqB+g3qlREm/wF7DINj7RI= -github.com/domonda/go-types v0.0.0-20230829145420-30f9974e0bc7/go.mod h1:qMSeU/23ZUopt+1kY0pJ27iqNRtsY1jATQklyCyLRAU= -github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= -github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/domonda/go-types v0.0.0-20230926122236-a75565adcd2b h1:CZyyHCaLPkcGXXxuC2qFcwsEnJX8MJ3jNZBlCHoIEa0= +github.com/domonda/go-types v0.0.0-20230926122236-a75565adcd2b/go.mod h1:2mMAkLzvuxdGvjJQ8o8I0QTJGRiCaBV/ZleKeoRbgJ0= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -23,6 +21,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/ungerik/go-reflection v0.0.0-20230810134712-a63435f6bc7e h1:BPksMeVdgSD8L4yXHYSY3HpdJ/5z2Ok5lF6PxHIVgEQ= github.com/ungerik/go-reflection v0.0.0-20230810134712-a63435f6bc7e/go.mod h1:1Q14POg/xa/P6/hWKfnUexqUhW1X6jgw+6gG7lOne1E= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/go.work b/go.work new file mode 100644 index 0000000..94b4da1 --- /dev/null +++ b/go.work @@ -0,0 +1,8 @@ +go 1.21 + +use ( + . + ./mysqlconn + ./pqconn + ./pqconn/tests +) diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000..38b886d --- /dev/null +++ b/go.work.sum @@ -0,0 +1,47 @@ +github.com/cention-sany/utf7 v0.0.0-20170124080048-26cad61bd60a h1:MISbI8sU/PSK/ztvmWKFcI7UGb5/HQT7B+i3a2myKgI= +github.com/cention-sany/utf7 v0.0.0-20170124080048-26cad61bd60a/go.mod h1:2GxOXOlEPAMFPfp014mK1SWq8G8BN8o7/dfYqJrVGn8= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f h1:3BSP1Tbs2djlpprl7wCLuiqMaUh5SJkkzI2gDs+FgLs= +github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f/go.mod h1:Pcatq5tYkCW2Q6yrR2VRHlbHpZ/R4/7qyL1TCF7vl14= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/jaytaylor/html2text v0.0.0-20211105163654-bc68cce691ba h1:QFQpJdgbON7I0jr2hYW7Bs+XV0qjc3d5tZoDnRFnqTg= +github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056/go.mod h1:CVKlgaMiht+LXvHG173ujK6JUhZXKb2u/BQtjPDIvyk= +github.com/jhillyerd/enmime v0.10.1 h1:3VP8gFhK7R948YJBrna5bOgnTXEuPAoICo79kKkBKfA= +github.com/jhillyerd/enmime v1.0.0/go.mod h1:EktNOa/V6ka9yCrfoB2uxgefp1lno6OVdszW0iQ5LnM= +github.com/jhillyerd/enmime v1.0.1/go.mod h1:LMMbm6oTlzWHghPavqHtOrP/NosVv3l42CUrZjn03/Q= +github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/rivo/uniseg v0.4.3 h1:utMvzDsuh3suAEnhH0RdHmoPbU648o6CvXxTx4SBMOw= +github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo= +github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/teamwork/test v0.0.0-20200108114543-02621bae84ad/go.mod h1:TIbx7tx6WHBjQeLRM4eWQZBL7kmBZ7/KI4x4v7Y5YmA= +github.com/teamwork/tnef v0.0.0-20200108124832-7deabccfdb32 h1:j15wq0XPAY/HR/0+dtwUrIrF2ZTKbk7QIES2p4dAG+k= +github.com/teamwork/tnef v0.0.0-20200108124832-7deabccfdb32/go.mod h1:v7dFaQrF/4+curx7UTH9rqTkHTgXqghfI3thANW150o= +github.com/teamwork/utils v0.0.0-20220314153103-637fa45fa6cc/go.mod h1:3Fn0qxFeRNpvsg/9T1+btOOOKkd1qG2nPYKKcOmNpcs= +github.com/ungerik/go-fs v0.0.0-20230206141012-abb864f815e3 h1:IEm9Je1L3HAIEwiXOuGgJuUPjXXHzf/1e704VyIbcGc= +github.com/ungerik/go-fs v0.0.0-20230810132455-f7ff27f6fa2b h1:hZ/Tp1sn1oRwYqIZfjpfUp8N+5e3LGk32O8OAMh9VOk= +github.com/ungerik/go-fs v0.0.0-20230810132455-f7ff27f6fa2b/go.mod h1:P8k1DG+Ox0KP4MFNTSPd8ojoDUwXjrWdGjsssF6vT/g= +github.com/ungerik/go-fs v0.0.0-20230828210517-6ca798932ba7 h1:tEV8EZoXxANOeDk1jVZfT8NUNnX5TtxcHSIo5nBlo4c= +github.com/ungerik/go-fs v0.0.0-20230828210517-6ca798932ba7/go.mod h1:P8k1DG+Ox0KP4MFNTSPd8ojoDUwXjrWdGjsssF6vT/g= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +gotest.tools/v3 v3.2.0/go.mod h1:Mcr9QNxkg0uMvy/YElmo4SpXgJKWgQvYrT7Kw5RzJ1A= +mvdan.cc/xurls/v2 v2.4.0 h1:tzxjVAj+wSBmDcF6zBB7/myTy3gX9xvi8Tyr28AuQgc= +mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE= diff --git a/impl/arrays_test.go b/impl/arrays_test.go deleted file mode 100644 index 2072fc2..0000000 --- a/impl/arrays_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package impl - -import ( - "database/sql" - "encoding/json" - "reflect" - "testing" - - "github.com/domonda/go-types/nullable" -) - -func TestShouldWrapForArray(t *testing.T) { - tests := []struct { - v reflect.Value - want bool - }{ - {v: reflect.ValueOf([]byte(nil)), want: false}, - {v: reflect.ValueOf([]byte{}), want: false}, - {v: reflect.ValueOf(""), want: false}, - {v: reflect.ValueOf(0), want: false}, - {v: reflect.ValueOf(json.RawMessage([]byte("null"))), want: false}, - {v: reflect.ValueOf(nullable.JSON([]byte("null"))), want: false}, - {v: reflect.ValueOf(new(sql.NullInt64)).Elem(), want: false}, - - {v: reflect.ValueOf(new([3]string)).Elem(), want: true}, - {v: reflect.ValueOf(new([]string)).Elem(), want: true}, - {v: reflect.ValueOf(new([]sql.NullString)).Elem(), want: true}, - } - for _, tt := range tests { - if got := ShouldWrapForArray(tt.v); got != tt.want { - t.Errorf("shouldWrapArray() = %v, want %v", got, tt.want) - } - } -} diff --git a/impl/connection.go b/impl/connection.go index e3ca281..02b45e8 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -1,11 +1,9 @@ package impl import ( - "context" "database/sql" "errors" "fmt" - "time" "github.com/domonda/go-sqldb" ) @@ -14,148 +12,6 @@ import ( // for an existing sql.DB connection. // argFmt is the format string for argument placeholders like "?" or "$%d" // that will be replaced error messages to format a complete query. -func Connection(ctx context.Context, db *sql.DB, config *sqldb.Config, validateColumnName func(string) error, argFmt string) sqldb.Connection { - return &connection{ - ctx: ctx, - db: db, - config: config, - structFieldNamer: sqldb.DefaultStructFieldMapping, - argFmt: argFmt, - validateColumnName: validateColumnName, - } -} - -type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldNamer sqldb.StructFieldMapper - argFmt string - validateColumnName func(string) error -} - -func (conn *connection) clone() *connection { - c := *conn - return &c -} - -func (conn *connection) Context() context.Context { return conn.ctx } - -func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.ctx { - return conn - } - c := conn.clone() - c.ctx = ctx - return c -} - -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - -func (conn *connection) Ping(timeout time.Duration) error { - ctx := conn.ctx - if timeout > 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - return conn.db.PingContext(ctx) -} - -func (conn *connection) Stats() sql.DBStats { - return conn.db.Stats() -} - -func (conn *connection) Config() *sqldb.Config { - return conn.config -} - -func (conn *connection) ValidateColumnName(name string) error { - return conn.validateColumnName(name) -} - -func (conn *connection) Now() (time.Time, error) { - return Now(conn) -} - -func (conn *connection) Exec(query string, args ...any) error { - _, err := conn.db.ExecContext(conn.ctx, query, args...) - return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) -} - -func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return Insert(conn, table, conn.argFmt, columValues) -} - -func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return InsertUnique(conn, table, conn.argFmt, values, onConflict) -} - -func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return InsertReturning(conn, table, conn.argFmt, values, returning) -} - -func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) InsertStructs(table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error { - return InsertStructs(conn, table, rowStructs, ignoreColumns...) -} - -func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowScannerWithError(err) - } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.argFmt, args) -} - -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowsScannerWithError(err) - } - return NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, conn.argFmt, args) -} - -func (conn *connection) IsTransaction() bool { - return false -} func (conn *connection) TransactionNo() uint64 { return 0 @@ -181,7 +37,7 @@ func (conn *connection) Rollback() error { return sqldb.ErrNotWithinTransaction } -func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { +func (conn *connection) ListenChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { return fmt.Errorf("notifications %w", errors.ErrUnsupported) } diff --git a/impl/errors.go b/impl/errors.go deleted file mode 100644 index 787e149..0000000 --- a/impl/errors.go +++ /dev/null @@ -1,30 +0,0 @@ -package impl - -import ( - "errors" - "fmt" -) - -// WrapNonNilErrorWithQuery wraps non nil errors with a formatted query -// if the error was not already wrapped with a query. -// If the passed error is nil, then nil will be returned. -func WrapNonNilErrorWithQuery(err error, query, argFmt string, args []any) error { - var wrapped errWithQuery - if err == nil || errors.As(err, &wrapped) { - return err - } - return errWithQuery{err, query, argFmt, args} -} - -type errWithQuery struct { - err error - query string - argFmt string - args []any -} - -func (e errWithQuery) Unwrap() error { return e.err } - -func (e errWithQuery) Error() string { - return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.argFmt, e.args...)) -} diff --git a/impl/foreachrow.go b/impl/foreachrow.go index 82898b2..8e0ffc7 100644 --- a/impl/foreachrow.go +++ b/impl/foreachrow.go @@ -2,21 +2,10 @@ package impl import ( "context" - "database/sql" "fmt" "reflect" - "time" - sqldb "github.com/domonda/go-sqldb" -) - -var ( - typeOfError = reflect.TypeOf((*error)(nil)).Elem() - typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem() - typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() - typeOfTime = reflect.TypeOf(time.Time{}) - typeOfByte = reflect.TypeOf(byte(0)) - typeOfByteSlice = reflect.TypeOf((*[]byte)(nil)).Elem() + "github.com/domonda/go-sqldb" ) // ForEachRowCallFunc will call the passed callback with scanned values or a struct for every row. @@ -48,7 +37,7 @@ func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScan structArg := false for i := firstArg; i < typ.NumIn(); i++ { t := typ.In(i) - for t.Kind() == reflect.Ptr { + for t.Kind() == reflect.Pointer { t = t.Elem() } if t == typeOfTime { @@ -56,7 +45,7 @@ func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScan } switch t.Kind() { case reflect.Struct: - if t.Implements(typeOfSQLScanner) || reflect.PtrTo(t).Implements(typeOfSQLScanner) { + if t.Implements(typeOfSQLScanner) || reflect.PointerTo(t).Implements(typeOfSQLScanner) { continue } if structArg { diff --git a/impl/genericconnection.go b/impl/genericconnection.go new file mode 100644 index 0000000..c5c5514 --- /dev/null +++ b/impl/genericconnection.go @@ -0,0 +1,227 @@ +package impl + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "time" + + "github.com/domonda/go-sqldb" +) + +// NewGenericConnection returns a generic sqldb.Connection implementation +// for an existing sql.DB connection. +// argFmt is the format string for argument placeholders like "?" or "$%d" +// that will be replaced error messages to format a complete query. +func NewGenericConnection(ctx context.Context, db *sql.DB, config *sqldb.Config, listener sqldb.Listener, structFieldMapper sqldb.StructFieldMapper, validateColumnName func(string) error, converter driver.ValueConverter, argFmt string) sqldb.Connection { + if listener == nil { + listener = sqldb.UnsupportedListener() + } + return &genericConn{ + ctx: ctx, + db: db, + config: config, + listener: listener, + structFieldMapper: structFieldMapper, + validateColumnName: validateColumnName, + converter: converter, + argFmt: argFmt, + } +} + +type genericConn struct { + ctx context.Context + db *sql.DB + config *sqldb.Config + listener sqldb.Listener + structFieldMapper sqldb.StructFieldMapper + validateColumnName func(string) error + converter driver.ValueConverter + argFmt string + + tx *sql.Tx + txOptions *sql.TxOptions + txNo uint64 +} + +func (conn *genericConn) clone() *genericConn { + c := *conn + return &c +} + +func (conn *genericConn) Context() context.Context { return conn.ctx } + +func (conn *genericConn) WithContext(ctx context.Context) sqldb.Connection { + if ctx == conn.ctx { + return conn + } + c := conn.clone() + c.ctx = ctx + return c +} + +func (conn *genericConn) WithStructFieldMapper(mapper sqldb.StructFieldMapper) sqldb.Connection { + c := conn.clone() + c.structFieldMapper = mapper + return c +} + +func (conn *genericConn) StructFieldMapper() sqldb.StructFieldMapper { + return conn.structFieldMapper +} + +func (conn *genericConn) ValidateColumnName(name string) error { + return conn.validateColumnName(name) +} + +func (conn *genericConn) Ping(timeout time.Duration) error { + ctx := conn.ctx + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + return conn.db.PingContext(ctx) +} + +func (conn *genericConn) Stats() sql.DBStats { + return conn.db.Stats() +} + +func (conn *genericConn) Config() *sqldb.Config { + return conn.config +} + +func (conn *genericConn) Now() (time.Time, error) { + return QueryNow(conn) +} + +func (conn *genericConn) execer() Execer { + if conn.tx != nil { + return conn.tx + } + return conn.db +} + +func (conn *genericConn) queryer() Queryer { + if conn.tx != nil { + return conn.tx + } + return conn.db +} + +func (conn *genericConn) Exec(query string, args ...any) error { + return Exec(conn.ctx, conn.execer(), query, args, conn.converter, conn.argFmt) +} + +func (conn *genericConn) QueryRow(query string, args ...any) sqldb.RowScanner { + return QueryRow(conn.ctx, conn.queryer(), query, args, conn.converter, conn.argFmt, conn.structFieldMapper) +} + +func (conn *genericConn) QueryRows(query string, args ...any) sqldb.RowsScanner { + return QueryRows(conn.ctx, conn.queryer(), query, args, conn.converter, conn.argFmt, conn.structFieldMapper) +} + +func (conn *genericConn) Insert(table string, columValues sqldb.Values) error { + return Insert(conn.ctx, conn.execer(), table, columValues, conn.converter, conn.argFmt) +} + +func (conn *genericConn) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { + return InsertUnique(conn.ctx, conn.queryer(), table, values, onConflict, conn.converter, conn.argFmt, conn.structFieldMapper) +} + +func (conn *genericConn) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { + return InsertReturning(conn.ctx, conn.queryer(), table, values, returning, conn.converter, conn.argFmt, conn.structFieldMapper) +} + +func (conn *genericConn) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { + return InsertStruct(conn.ctx, conn.execer(), table, rowStruct, conn.structFieldMapper, ignoreColumns, conn.converter, conn.argFmt) +} + +func (conn *genericConn) InsertStructs(table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error { + // TODO optimized version with single query if possible, split into multiple queries depending or maxArgs for query + return InsertStructs(conn, table, rowStructs, ignoreColumns...) +} + +func (conn *genericConn) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { + return InsertUniqueStruct(conn.ctx, conn.queryer(), conn.structFieldMapper, table, rowStruct, onConflict, ignoreColumns, conn.converter, conn.argFmt) +} + +func (conn *genericConn) Update(table string, values sqldb.Values, where string, args ...any) error { + return Update(conn.ctx, conn.execer(), table, values, where, args, conn.converter, conn.argFmt) +} + +func (conn *genericConn) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { + return UpdateReturningRow(conn.ctx, conn.queryer(), table, values, returning, where, args, conn.converter, conn.argFmt, conn.structFieldMapper) +} + +func (conn *genericConn) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { + return UpdateReturningRows(conn.ctx, conn.queryer(), table, values, returning, where, args, conn.converter, conn.argFmt, conn.structFieldMapper) +} + +func (conn *genericConn) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { + return UpdateStruct(conn.ctx, conn.execer(), table, rowStruct, conn.structFieldMapper, ignoreColumns, conn.converter, conn.argFmt) +} + +func (conn *genericConn) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { + return UpsertStruct(conn.ctx, conn.execer(), table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns) +} + +func (conn *genericConn) IsTransaction() bool { + return conn.tx != nil +} + +func (conn *genericConn) TransactionNo() uint64 { + return conn.txNo +} + +func (conn *genericConn) TransactionOptions() (*sql.TxOptions, bool) { + return conn.txOptions, conn.tx != nil +} + +func (conn *genericConn) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { + tx, err := conn.db.BeginTx(conn.ctx, opts) + if err != nil { + return nil, err + } + txConn := conn.clone() + txConn.tx = tx + txConn.txOptions = opts + txConn.txNo = no + return txConn, nil +} + +func (conn *genericConn) Commit() error { + if conn.tx == nil { + return sqldb.ErrNotWithinTransaction + } + return conn.tx.Commit() +} + +func (conn *genericConn) Rollback() error { + if conn.tx == nil { + return sqldb.ErrNotWithinTransaction + } + return conn.tx.Rollback() +} + +func (conn *genericConn) ListenChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { + return conn.listener.ListenChannel(conn, channel, onNotify, onUnlisten) +} + +func (conn *genericConn) UnlistenChannel(channel string) (err error) { + return conn.listener.UnlistenChannel(conn, channel) +} + +func (conn *genericConn) IsListeningOnChannel(channel string) bool { + return conn.listener.IsListeningOnChannel(conn, channel) +} + +func (conn *genericConn) Close() error { + err := conn.listener.Close(conn) + if conn.tx != nil { + return errors.Join(err, conn.tx.Rollback()) + } + return errors.Join(err, conn.db.Close()) +} diff --git a/impl/insert.go b/impl/insert.go index 47689de..4227826 100644 --- a/impl/insert.go +++ b/impl/insert.go @@ -1,6 +1,8 @@ package impl import ( + "context" + "database/sql/driver" "fmt" "reflect" "strings" @@ -9,25 +11,22 @@ import ( ) // Insert a new row into table using the values. -func Insert(conn sqldb.Connection, table, argFmt string, values sqldb.Values) error { +func Insert(ctx context.Context, conn Execer, table string, values sqldb.Values, converter driver.ValueConverter, argFmt string) error { if len(values) == 0 { return fmt.Errorf("Insert into table %s: no values", table) } - names, vals := values.Sorted() - b := strings.Builder{} - writeInsertQuery(&b, table, argFmt, names) - query := b.String() - - err := conn.Exec(query, vals...) + names, args := values.Sorted() + query := strings.Builder{} + writeInsertQuery(&query, table, argFmt, names) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return Exec(ctx, conn, query.String(), args, converter, argFmt) } // InsertUnique inserts a new row into table using the passed values // or does nothing if the onConflict statement applies. // Returns if a row was inserted. -func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Values, onConflict string) (inserted bool, err error) { +func InsertUnique(ctx context.Context, conn Queryer, table string, values sqldb.Values, onConflict string, converter driver.ValueConverter, argFmt string, mapper sqldb.StructFieldMapper) (inserted bool, err error) { if len(values) == 0 { return false, fmt.Errorf("InsertUnique into table %s: no values", table) } @@ -41,16 +40,13 @@ func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Valu writeInsertQuery(&query, table, argFmt, names) fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - err = conn.QueryRow(query.String(), vals...).Scan(&inserted) - - err = sqldb.ReplaceErrNoRows(err, nil) - err = WrapNonNilErrorWithQuery(err, query.String(), argFmt, vals) - return inserted, err + err = QueryRow(ctx, conn, query.String(), vals, converter, argFmt, mapper).Scan(&inserted) + return inserted, sqldb.ReplaceErrNoRows(err, nil) } // InsertReturning inserts a new row into table using values // and returns values from the inserted row listed in returning. -func InsertReturning(conn sqldb.Connection, table, argFmt string, values sqldb.Values, returning string) sqldb.RowScanner { +func InsertReturning(ctx context.Context, conn Queryer, table string, values sqldb.Values, returning string, converter driver.ValueConverter, argFmt string, mapper sqldb.StructFieldMapper) sqldb.RowScanner { if len(values) == 0 { return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) } @@ -60,45 +56,23 @@ func InsertReturning(conn sqldb.Connection, table, argFmt string, values sqldb.V writeInsertQuery(&query, table, argFmt, names) query.WriteString(" RETURNING ") query.WriteString(returning) - return conn.QueryRow(query.String(), vals...) -} -func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) { - fmt.Fprintf(w, `INSERT INTO %s(`, table) - for i, name := range names { - if i > 0 { - w.WriteByte(',') - } - w.WriteByte('"') - w.WriteString(name) - w.WriteByte('"') - } - w.WriteString(`) VALUES(`) - for i := range names { - if i > 0 { - w.WriteByte(',') - } - fmt.Fprintf(w, argFmt, i+1) - } - w.WriteByte(')') + return QueryRow(ctx, conn, query.String(), vals, converter, argFmt, mapper) } // InsertStruct inserts a new row into table using the connection's // StructFieldMapper to map struct fields to column names. // Optional ColumnFilter can be passed to ignore mapped columns. -func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns) +func InsertStruct(ctx context.Context, conn Execer, table string, rowStruct any, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter, converter driver.ValueConverter, argFmt string) error { + columns, vals, err := insertStructValues(table, rowStruct, mapper, ignoreColumns) if err != nil { return err } - var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) - query := b.String() - - err = conn.Exec(query, vals...) + var query strings.Builder + writeInsertQuery(&query, table, argFmt, columns) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return Exec(ctx, conn, query.String(), vals, converter, argFmt) } // InsertUniqueStruct inserts a new row into table using the connection's @@ -106,8 +80,8 @@ func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld // Optional ColumnFilter can be passed to ignore mapped columns. // Does nothing if the onConflict statement applies // and returns if a row was inserted. -func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onConflict string, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) (inserted bool, err error) { - columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns) +func InsertUniqueStruct(ctx context.Context, conn Queryer, mapper sqldb.StructFieldMapper, table string, rowStruct any, onConflict string, ignoreColumns []sqldb.ColumnFilter, converter driver.ValueConverter, argFmt string) (inserted bool, err error) { + columns, vals, err := insertStructValues(table, rowStruct, mapper, ignoreColumns) if err != nil { return false, err } @@ -116,30 +90,27 @@ func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onCo onConflict = onConflict[1 : len(onConflict)-1] } - var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) - fmt.Fprintf(&b, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - query := b.String() - - err = conn.QueryRow(query, vals...).Scan(&inserted) - err = sqldb.ReplaceErrNoRows(err, nil) + var query strings.Builder + writeInsertQuery(&query, table, argFmt, columns) + fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - return inserted, WrapNonNilErrorWithQuery(err, query, argFmt, vals) + err = QueryRow(ctx, conn, query.String(), vals, converter, argFmt, mapper).Scan(&inserted) + return inserted, sqldb.ReplaceErrNoRows(err, nil) } -func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { +func insertStructValues(table string, rowStruct any, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { + for v.Kind() == reflect.Pointer && !v.IsNil() { v = v.Elem() } switch { - case v.Kind() == reflect.Ptr && v.IsNil(): + case v.Kind() == reflect.Pointer && v.IsNil(): return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table) case v.Kind() != reflect.Struct: return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) } - columns, _, vals = ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + columns, _, vals = ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) return columns, vals, nil } @@ -148,6 +119,7 @@ func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapp // The inserts are performed within a new transaction // if the passed conn is not already a transaction. func InsertStructs(conn sqldb.Connection, table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error { + // TODO optimized version with single query if possible, split into multiple queries depending or maxArgs for query v := reflect.ValueOf(rowStructs) if k := v.Type().Kind(); k != reflect.Slice && k != reflect.Array { return fmt.Errorf("InsertStructs expects a slice or array as rowStructs, got %T", rowStructs) diff --git a/impl/now.go b/impl/now.go deleted file mode 100644 index 4a1bd2f..0000000 --- a/impl/now.go +++ /dev/null @@ -1,15 +0,0 @@ -package impl - -import ( - "time" - - "github.com/domonda/go-sqldb" -) - -func Now(conn sqldb.Connection) (now time.Time, err error) { - err = conn.QueryRow(`select now()`).Scan(&now) - if err != nil { - return time.Time{}, err - } - return now, nil -} diff --git a/impl/query.go b/impl/query.go new file mode 100644 index 0000000..253add8 --- /dev/null +++ b/impl/query.go @@ -0,0 +1,55 @@ +package impl + +import ( + "context" + "database/sql" + "database/sql/driver" + + "github.com/domonda/go-sqldb" +) + +type Execer interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +func Exec(ctx context.Context, conn Execer, query string, args []any, converter driver.ValueConverter, argFmt string) error { + err := convertValuesInPlace(args, converter) + if err != nil { + err = WrapNonNilErrorWithQuery(err, query, argFmt, args) + return err + } + _, err = conn.ExecContext(ctx, query, args...) + return WrapNonNilErrorWithQuery(err, query, argFmt, args) +} + +type Queryer interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +func QueryRow(ctx context.Context, conn Queryer, query string, args []any, converter driver.ValueConverter, argFmt string, mapper sqldb.StructFieldMapper) sqldb.RowScanner { + err := convertValuesInPlace(args, converter) + if err != nil { + err = WrapNonNilErrorWithQuery(err, query, argFmt, args) + return sqldb.RowScannerWithError(err) + } + rows, err := conn.QueryContext(ctx, query, args...) + if err != nil { + err = WrapNonNilErrorWithQuery(err, query, argFmt, args) + return sqldb.RowScannerWithError(err) + } + return NewRowScanner(rows, mapper, query, argFmt, args) +} + +func QueryRows(ctx context.Context, conn Queryer, query string, args []any, converter driver.ValueConverter, argFmt string, mapper sqldb.StructFieldMapper) sqldb.RowsScanner { + err := convertValuesInPlace(args, converter) + if err != nil { + err = WrapNonNilErrorWithQuery(err, query, argFmt, args) + return sqldb.RowsScannerWithError(err) + } + rows, err := conn.QueryContext(ctx, query, args...) + if err != nil { + err = WrapNonNilErrorWithQuery(err, query, argFmt, args) + return sqldb.RowsScannerWithError(err) + } + return NewRowsScanner(ctx, rows, mapper, query, argFmt, args) +} diff --git a/impl/reflectstruct.go b/impl/reflectstruct.go index ddb41c4..59eb9ca 100644 --- a/impl/reflectstruct.go +++ b/impl/reflectstruct.go @@ -5,15 +5,14 @@ import ( "fmt" "reflect" "slices" - "strings" "github.com/domonda/go-sqldb" ) -func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { +func ReflectStructValues(structVal reflect.Value, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { for i := 0; i < structVal.NumField(); i++ { fieldType := structVal.Type().Field(i) - _, column, flags, use := namer.MapStructField(fieldType) + _, column, flags, use := mapper.MapStructField(fieldType) if !use { continue } @@ -21,7 +20,7 @@ func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, if column == "" { // Embedded struct field - columnsEmbed, pkColsEmbed, valuesEmbed := ReflectStructValues(fieldValue, namer, ignoreColumns) + columnsEmbed, pkColsEmbed, valuesEmbed := ReflectStructValues(fieldValue, mapper, ignoreColumns) for _, pkCol := range pkColsEmbed { pkCols = append(pkCols, pkCol+len(columns)) } @@ -43,39 +42,49 @@ func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, return columns, pkCols, values } -func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { +// ReflectStructColumnPointers uses the passed mapper +// to find the passed columns as fields of the passed struct +// and returns a pointer to a struct field for every mapped column. +// +// If columns and struct fields could not be mapped 1:1 then +// an ErrColumnsWithoutStructFields or ErrStructFieldHasNoColumn +// error is returned together with the successfully mapped pointers. +func ReflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { + if structVal.Kind() != reflect.Struct { + return nil, fmt.Errorf("got %s instead of a struct", structVal) + } + if !structVal.CanAddr() { + return nil, errors.New("struct can't be addressed") + } if len(columns) == 0 { return nil, errors.New("no columns") } pointers = make([]any, len(columns)) - err = reflectStructColumnPointers(structVal, namer, columns, pointers) + err = reflectStructColumnPointers(structVal, mapper, columns, pointers) if err != nil { return nil, err } - for _, ptr := range pointers { - if ptr != nil { - continue + // Check if any column could not be mapped onto the struct, + // indicated by having a nil struct field pointer. + var nilCols sqldb.ErrColumnsWithoutStructFields + for i, ptr := range pointers { + if ptr == nil { + nilCols.Columns = append(nilCols.Columns, columns[i]) + nilCols.Struct = structVal } - nilCols := new(strings.Builder) - for i, ptr := range pointers { - if ptr != nil { - continue - } - if nilCols.Len() > 0 { - nilCols.WriteString(", ") - } - fmt.Fprintf(nilCols, "column=%s, index=%d", columns[i], i) - } - return nil, fmt.Errorf("columns have no mapped struct fields in %s: %s", structVal.Type(), nilCols) + } + if len(nilCols.Columns) > 0 { + pointers = slices.DeleteFunc(pointers, func(e any) bool { return e == nil }) + return pointers, nilCols } return pointers, nil } -func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string, pointers []any) error { +func reflectStructColumnPointers(structVal reflect.Value, mapper sqldb.StructFieldMapper, columns []string, pointers []any) error { structType := structVal.Type() for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) - _, column, _, use := namer.MapStructField(field) + _, column, _, use := mapper.MapStructField(field) if !use { continue } @@ -83,7 +92,7 @@ func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel if column == "" { // Embedded struct field - err := reflectStructColumnPointers(fieldValue, namer, columns, pointers) + err := reflectStructColumnPointers(fieldValue, mapper, columns, pointers) if err != nil { return err } @@ -103,8 +112,8 @@ func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel // If field is a slice or array that does not implement sql.Scanner // and it's not a string scannable []byte type underneath // then wrap it with WrapForArray to make it scannable - if ShouldWrapForArray(fieldValue) { - pointer = WrapForArray(pointer) + if ShouldWrapForArrayScanning(fieldValue) { + pointer = WrapForArrayScanning(pointer) } pointers[colIndex] = pointer } diff --git a/impl/reflectstruct_test.go b/impl/reflectstruct_test.go new file mode 100644 index 0000000..39c6dcd --- /dev/null +++ b/impl/reflectstruct_test.go @@ -0,0 +1,113 @@ +package impl + +import ( + "reflect" + "testing" + + "github.com/domonda/go-sqldb" + "github.com/stretchr/testify/require" +) + +func TestReflectStructColumnPointers(t *testing.T) { + type DeepEmbeddedStruct struct { + DeeperEmbInt int `db:"deep_emb_int"` + } + type embeddedStruct struct { + DeepEmbeddedStruct + EmbInt int `db:"emb_int"` + } + type Struct struct { + ID string `db:"id,pk"` + Int int `db:"int"` + Ignore int `db:"-"` + embeddedStruct + UntaggedField int + Struct struct { + InlineStructInt int `db:"inline_struct_int"` + } `db:"-"` // TODO enable access to named embedded fields? + NilPtr *byte `db:"nil_ptr"` + } + var ( + structPtr = new(Struct) + structFieldPtrs = []any{ + &structPtr.ID, + &structPtr.Int, + &structPtr.DeeperEmbInt, + &structPtr.EmbInt, + // &structPtr.Struct.InlineStructInt, + &structPtr.NilPtr, + } + structCols = []string{"id", "int", "deep_emb_int", "emb_int" /*"inline_struct_int",*/, "nil_ptr"} + ) + + type args struct { + structVal reflect.Value + mapper sqldb.StructFieldMapper + columns []string + } + tests := []struct { + name string + args args + wantPointers []any + wantErr bool + }{ + { + name: "ok", + args: args{ + structVal: reflect.ValueOf(structPtr).Elem(), + mapper: sqldb.NewTaggedStructFieldMapping(), + columns: structCols, + }, + wantPointers: structFieldPtrs, + }, + + // Errors: + { + name: "no columns", + args: args{ + structVal: reflect.ValueOf(structPtr).Elem(), + mapper: sqldb.NewTaggedStructFieldMapping(), + columns: []string{}, + }, + wantErr: true, + }, + { + name: "not a struct", + args: args{ + structVal: reflect.ValueOf(structPtr), + mapper: sqldb.NewTaggedStructFieldMapping(), + columns: structCols, + }, + wantErr: true, + }, + { + name: "extra columns", + args: args{ + structVal: reflect.ValueOf(structPtr).Elem(), + mapper: sqldb.NewTaggedStructFieldMapping(), + columns: append(structCols, "some_column_not_found_at_struct"), + }, + wantErr: true, + }, + { + name: "not enough columns", + args: args{ + structVal: reflect.ValueOf(structPtr).Elem(), + mapper: sqldb.NewTaggedStructFieldMapping(), + columns: structCols[1:], + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPointers, err := ReflectStructColumnPointers(tt.args.structVal, tt.args.mapper, tt.args.columns) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.wantPointers, gotPointers) + }) + } +} diff --git a/impl/row.go b/impl/row.go index a3b34c7..78375e8 100644 --- a/impl/row.go +++ b/impl/row.go @@ -11,3 +11,24 @@ type Row interface { // number of columns in Rows. Scan(dest ...any) error } + +type RowWithArrays interface { + Row + + ScanWithArrays(dest []any) error +} + +// func AsRowWithArrays(row Row) RowWithArrays { +// if r, ok := row.(RowWithArrays); ok { +// return r +// } +// return rowWithArrays{row} +// } + +// type rowWithArrays struct { +// Row +// } + +// func (r rowWithArrays) ScanWithArrays(dest []any) error { +// return ScanRowWithArrays(r.Row, dest) +// } diff --git a/impl/rowscanner.go b/impl/rowscanner.go index 6bc5826..c6c68a1 100644 --- a/impl/rowscanner.go +++ b/impl/rowscanner.go @@ -1,81 +1,14 @@ package impl import ( - "database/sql" - "errors" - sqldb "github.com/domonda/go-sqldb" ) var ( - _ sqldb.RowScanner = &RowScanner{} _ sqldb.RowScanner = CurrentRowScanner{} - _ sqldb.RowScanner = SingleRowScanner{} + // _ sqldb.RowScanner = SingleRowScanner{} ) -// RowScanner implements sqldb.RowScanner for a sql.Row -type RowScanner struct { - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowScanner(rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowScanner { - return &RowScanner{rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowScanner) Scan(dest ...any) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return s.rows.Scan(dest...) -} - -func (s *RowScanner) ScanStruct(dest any) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return ScanStruct(s.rows, dest, s.structFieldNamer) -} - -func (s *RowScanner) ScanValues() ([]any, error) { - return ScanValues(s.rows) -} - -func (s *RowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.rows) -} - -func (s *RowScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - // CurrentRowScanner calls Rows.Scan without Rows.Next and Rows.Close type CurrentRowScanner struct { Rows Rows @@ -83,7 +16,7 @@ type CurrentRowScanner struct { } func (s CurrentRowScanner) Scan(dest ...any) error { - return s.Rows.Scan(dest...) + return ScanRowWithArrays(s.Rows, dest) } func (s CurrentRowScanner) ScanStruct(dest any) error { @@ -103,27 +36,27 @@ func (s CurrentRowScanner) Columns() ([]string, error) { } // SingleRowScanner always uses the same Row -type SingleRowScanner struct { - Row Row - StructFieldMapper sqldb.StructFieldMapper -} - -func (s SingleRowScanner) Scan(dest ...any) error { - return s.Row.Scan(dest...) -} - -func (s SingleRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Row, dest, s.StructFieldMapper) -} - -func (s SingleRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Row) -} - -func (s SingleRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Row) -} - -func (s SingleRowScanner) Columns() ([]string, error) { - return s.Row.Columns() -} +// type SingleRowScanner struct { +// Row Row +// StructFieldMapper sqldb.StructFieldMapper +// } + +// func (s SingleRowScanner) Scan(dest ...any) error { +// return s.Row.Scan(dest...) +// } + +// func (s SingleRowScanner) ScanStruct(dest any) error { +// return ScanStruct(s.Row, dest, s.StructFieldMapper) +// } + +// func (s SingleRowScanner) ScanValues() ([]any, error) { +// return ScanValues(s.Row) +// } + +// func (s SingleRowScanner) ScanStrings() ([]string, error) { +// return ScanStrings(s.Row) +// } + +// func (s SingleRowScanner) Columns() ([]string, error) { +// return s.Row.Columns() +// } diff --git a/impl/rowsscanner.go b/impl/rowsscanner.go deleted file mode 100644 index 339a0e7..0000000 --- a/impl/rowsscanner.go +++ /dev/null @@ -1,154 +0,0 @@ -package impl - -import ( - "context" - "errors" - "fmt" - "reflect" - - sqldb "github.com/domonda/go-sqldb" -) - -var _ sqldb.RowsScanner = &RowsScanner{} - -// RowsScanner implements sqldb.RowsScanner with Rows -type RowsScanner struct { - ctx context.Context // ctx is checked for every row and passed through to callbacks - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowsScanner(ctx context.Context, rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowsScanner { - return &RowsScanner{ctx, rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowsScanner) ScanSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, nil) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) ScanStructSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldNamer) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -func (s *RowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { - cols, err := s.rows.Columns() - if err != nil { - return nil, err - } - if headerRow { - rows = [][]string{cols} - } - stringScannablePtrs := make([]any, len(cols)) - err = s.ForEachRow(func(rowScanner sqldb.RowScanner) error { - row := make([]string, len(cols)) - for i := range stringScannablePtrs { - stringScannablePtrs[i] = (*sqldb.StringScannable)(&row[i]) - } - err := rowScanner.Scan(stringScannablePtrs...) - if err != nil { - return err - } - rows = append(rows, row) - return nil - }) - return rows, err -} - -func (s *RowsScanner) ForEachRow(callback func(sqldb.RowScanner) error) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - for s.rows.Next() { - if s.ctx.Err() != nil { - return s.ctx.Err() - } - - err := callback(CurrentRowScanner{s.rows, s.structFieldNamer}) - if err != nil { - return err - } - } - return s.rows.Err() -} - -func (s *RowsScanner) ForEachRowCall(callback any) error { - forEachRowFunc, err := ForEachRowCallFunc(s.ctx, callback) - if err != nil { - return err - } - return s.ForEachRow(forEachRowFunc) -} - -// ScanRowsAsSlice scans all srcRows as slice into dest. -// The rows must either have only one column compatible with the element type of the slice, -// or if multiple columns are returned then the slice element type must me a struct or struction pointer -// so that every column maps on exactly one struct field using structFieldNamer. -// In case of single column rows, nil must be passed for structFieldNamer. -// ScanRowsAsSlice calls srcRows.Close(). -func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNamer sqldb.StructFieldMapper) error { - defer srcRows.Close() - - destVal := reflect.ValueOf(dest) - if destVal.Kind() != reflect.Ptr { - return fmt.Errorf("scan dest is not a pointer but %s", destVal.Type()) - } - if destVal.IsNil() { - return errors.New("scan dest is nil") - } - slice := destVal.Elem() - if slice.Kind() != reflect.Slice { - return fmt.Errorf("scan dest is not pointer to slice but %s", destVal.Type()) - } - sliceElemType := slice.Type().Elem() - - newSlice := reflect.MakeSlice(slice.Type(), 0, 32) - - for srcRows.Next() { - if ctx.Err() != nil { - return ctx.Err() - } - - newSlice = reflect.Append(newSlice, reflect.Zero(sliceElemType)) - target := newSlice.Index(newSlice.Len() - 1).Addr() - if structFieldNamer != nil { - err := ScanStruct(srcRows, target.Interface(), structFieldNamer) - if err != nil { - return err - } - } else { - err := srcRows.Scan(target.Interface()) - if err != nil { - return err - } - } - } - if srcRows.Err() != nil { - return srcRows.Err() - } - - // Assign newSlice if there were no errors - if newSlice.Len() == 0 { - slice.SetLen(0) - } else { - slice.Set(newSlice) - } - - return nil -} diff --git a/impl/scanresult.go b/impl/scanresult.go index d8a4bc3..4f9d22e 100644 --- a/impl/scanresult.go +++ b/impl/scanresult.go @@ -1,54 +1 @@ package impl - -import "github.com/domonda/go-sqldb" - -// ScanValues returns the values of a row exactly how they are -// passed from the database driver to an sql.Scanner. -// Byte slices will be copied. -func ScanValues(src Row) ([]any, error) { - cols, err := src.Columns() - if err != nil { - return nil, err - } - var ( - anys = make([]sqldb.AnyValue, len(cols)) - result = make([]any, len(cols)) - ) - // result elements hold pointer to sqldb.AnyValue for scanning - for i := range result { - result[i] = &anys[i] - } - err = src.Scan(result...) - if err != nil { - return nil, err - } - // don't return pointers to sqldb.AnyValue - // but what internal value has been scanned - for i := range result { - result[i] = anys[i].Val - } - return result, nil -} - -// ScanStrings scans the values of a row as strings. -// Byte slices will be interpreted as strings, -// nil (SQL NULL) will be converted to an empty string, -// all other types are converted with fmt.Sprint. -func ScanStrings(src Row) ([]string, error) { - cols, err := src.Columns() - if err != nil { - return nil, err - } - var ( - result = make([]string, len(cols)) - resultPtrs = make([]any, len(cols)) - ) - for i := range resultPtrs { - resultPtrs[i] = (*sqldb.StringScannable)(&result[i]) - } - err = src.Scan(resultPtrs...) - if err != nil { - return nil, err - } - return result, nil -} diff --git a/impl/scanstruct.go b/impl/scanstruct.go index ce2dd56..4f9d22e 100644 --- a/impl/scanstruct.go +++ b/impl/scanstruct.go @@ -1,53 +1 @@ package impl - -import ( - "fmt" - "reflect" - - sqldb "github.com/domonda/go-sqldb" -) - -func ScanStruct(srcRow Row, destStruct any, namer sqldb.StructFieldMapper) error { - v := reflect.ValueOf(destStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - - var ( - setDestStructPtr = false - destStructPtr reflect.Value - newStructPtr reflect.Value - ) - if v.Kind() == reflect.Ptr && v.IsNil() && v.CanSet() { - // Got a nil pointer that we can set with a newly allocated struct - setDestStructPtr = true - destStructPtr = v - newStructPtr = reflect.New(v.Type().Elem()) - // Continue with the newly allocated struct - v = newStructPtr.Elem() - } - if v.Kind() != reflect.Struct { - return fmt.Errorf("ScanStruct: expected struct but got %T", destStruct) - } - - columns, err := srcRow.Columns() - if err != nil { - return err - } - - fieldPointers, err := ReflectStructColumnPointers(v, namer, columns) - if err != nil { - return fmt.Errorf("ScanStruct: %w", err) - } - - err = srcRow.Scan(fieldPointers...) - if err != nil { - return err - } - - if setDestStructPtr { - destStructPtr.Set(newStructPtr) - } - - return nil -} diff --git a/impl/transaction.go b/impl/transaction.go index 298c3fd..b2183ae 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -3,8 +3,6 @@ package impl import ( "context" "database/sql" - "errors" - "fmt" "time" "github.com/domonda/go-sqldb" @@ -134,47 +132,3 @@ func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner } return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, conn.parent.argFmt, args) } - -func (conn *transaction) IsTransaction() bool { - return true -} - -func (conn *transaction) TransactionNo() uint64 { - return conn.no -} - -func (conn *transaction) TransactionOptions() (*sql.TxOptions, bool) { - return conn.opts, true -} - -func (conn *transaction) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { - tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) - if err != nil { - return nil, err - } - return newTransaction(conn.parent, tx, opts, no), nil -} - -func (conn *transaction) Commit() error { - return conn.tx.Commit() -} - -func (conn *transaction) Rollback() error { - return conn.tx.Rollback() -} - -func (conn *transaction) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *transaction) UnlistenChannel(channel string) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *transaction) IsListeningOnChannel(channel string) bool { - return false -} - -func (conn *transaction) Close() error { - return conn.Rollback() -} diff --git a/impl/types.go b/impl/types.go new file mode 100644 index 0000000..6f813cd --- /dev/null +++ b/impl/types.go @@ -0,0 +1,19 @@ +package impl + +import ( + "context" + "database/sql" + "database/sql/driver" + "reflect" + "time" +) + +var ( + typeOfError = reflect.TypeOf((*error)(nil)).Elem() + typeOfByte = reflect.TypeOf(byte(0)) + typeOfByteSlice = reflect.TypeOf((*[]byte)(nil)).Elem() + typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem() + typeOfTime = reflect.TypeOf(time.Time{}) + typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + typeOfDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +) diff --git a/impl/update.go b/impl/update.go index e5f6ca1..782b86b 100644 --- a/impl/update.go +++ b/impl/update.go @@ -1,6 +1,8 @@ package impl import ( + "context" + "database/sql/driver" "fmt" "reflect" "slices" @@ -10,41 +12,41 @@ import ( ) // Update table rows(s) with values using the where statement with passed in args starting at $1. -func Update(conn sqldb.Connection, table string, values sqldb.Values, where, argFmt string, args []any) error { +func Update(ctx context.Context, conn Execer, table string, values sqldb.Values, where string, args []any, converter driver.ValueConverter, argFmt string) error { if len(values) == 0 { return fmt.Errorf("Update table %s: no values passed", table) } - query, vals := buildUpdateQuery(table, values, where, args) - err := conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + query, args := buildUpdateQuery(table, values, where, args) + return Exec(ctx, conn, query, args, converter, argFmt) } // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 // and returning a single row with the columns specified in returning argument. -func UpdateReturningRow(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { +func UpdateReturningRow(ctx context.Context, conn Queryer, table string, values sqldb.Values, returning, where string, args []any, converter driver.ValueConverter, argFmt string, mapper sqldb.StructFieldMapper) sqldb.RowScanner { if len(values) == 0 { return sqldb.RowScannerWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) } query, vals := buildUpdateQuery(table, values, where, args) query += " RETURNING " + returning - return conn.QueryRow(query, vals...) + return QueryRow(ctx, conn, query, vals, converter, argFmt, mapper) } // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 // and returning multiple rows with the columns specified in returning argument. -func UpdateReturningRows(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { +func UpdateReturningRows(ctx context.Context, conn Queryer, table string, values sqldb.Values, returning, where string, args []any, converter driver.ValueConverter, argFmt string, mapper sqldb.StructFieldMapper) sqldb.RowsScanner { if len(values) == 0 { return sqldb.RowsScannerWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) } query, vals := buildUpdateQuery(table, values, where, args) query += " RETURNING " + returning - return conn.QueryRows(query, vals...) + return QueryRows(ctx, conn, query, vals, converter, argFmt, mapper) } func buildUpdateQuery(table string, values sqldb.Values, where string, args []any) (string, []any) { + // args = WrapArgsForArrays(args) names, vals := values.Sorted() var query strings.Builder @@ -65,25 +67,25 @@ func buildUpdateQuery(table string, values sqldb.Values, where string, args []an // Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. // If restrictToColumns are provided, then only struct fields with a `db` tag // matching any of the passed column names will be used. -func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { +func UpdateStruct(ctx context.Context, conn Execer, table string, rowStruct any, mapper sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter, converter driver.ValueConverter, argFmt string) error { v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { + for v.Kind() == reflect.Pointer && !v.IsNil() { v = v.Elem() } switch { - case v.Kind() == reflect.Ptr && v.IsNil(): + case v.Kind() == reflect.Pointer && v.IsNil(): return fmt.Errorf("UpdateStruct of table %s: can't insert nil", table) case v.Kind() != reflect.Struct: return fmt.Errorf("UpdateStruct of table %s: expected struct but got %T", table, rowStruct) } - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + columns, pkCols, vals := ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) if len(pkCols) == 0 { return fmt.Errorf("UpdateStruct of table %s: %s has no mapped primary key field", table, v.Type()) } - var b strings.Builder - fmt.Fprintf(&b, `UPDATE %s SET `, table) + var query strings.Builder + fmt.Fprintf(&query, `UPDATE %s SET `, table) first := true for i := range columns { if slices.Contains(pkCols, i) { @@ -92,22 +94,18 @@ func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld if first { first = false } else { - b.WriteByte(',') + query.WriteByte(',') } - fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) + fmt.Fprintf(&query, `"%s"=$%d`, columns[i], i+1) } - b.WriteString(` WHERE `) + query.WriteString(` WHERE `) for i, pkCol := range pkCols { if i > 0 { - b.WriteString(` AND `) + query.WriteString(` AND `) } - fmt.Fprintf(&b, `"%s"=$%d`, columns[pkCol], i+1) + fmt.Fprintf(&query, `"%s"=$%d`, columns[pkCol], i+1) } - query := b.String() - - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return Exec(ctx, conn, query.String(), vals, converter, argFmt) } diff --git a/impl/upsert.go b/impl/upsert.go index ebbbfbd..37821b3 100644 --- a/impl/upsert.go +++ b/impl/upsert.go @@ -1,12 +1,13 @@ package impl import ( + "context" "fmt" "reflect" "slices" "strings" - "github.com/domonda/go-sqldb" + sqldb "github.com/domonda/go-sqldb" ) // UpsertStruct upserts a row to table using the exported fields @@ -15,19 +16,19 @@ import ( // If restrictToColumns are provided, then only struct fields with a `db` tag // matching any of the passed column names will be used. // If inserting conflicts on pkColumn, then an update of the existing row is performed. -func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { +func UpsertStruct(ctx context.Context, conn Execer, table string, rowStruct any, mapper sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { + for v.Kind() == reflect.Pointer && !v.IsNil() { v = v.Elem() } switch { - case v.Kind() == reflect.Ptr && v.IsNil(): + case v.Kind() == reflect.Pointer && v.IsNil(): return fmt.Errorf("UpsertStruct to table %s: can't insert nil", table) case v.Kind() != reflect.Struct: return fmt.Errorf("UpsertStruct to table %s: expected struct but got %T", table, rowStruct) } - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + columns, pkCols, vals := ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) if len(pkCols) == 0 { return fmt.Errorf("UpsertStruct of table %s: %s has no mapped primary key field", table, v.Type()) } @@ -57,7 +58,7 @@ func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld } query := b.String() - err := conn.Exec(query, vals...) + _, err := conn.ExecContext(ctx, query, vals...) return WrapNonNilErrorWithQuery(err, query, argFmt, vals) } diff --git a/insert.go b/insert.go new file mode 100644 index 0000000..50bc00b --- /dev/null +++ b/insert.go @@ -0,0 +1,152 @@ +package sqldb + +import ( + "context" + "fmt" + "reflect" + "strings" +) + +func Insert(ctx context.Context, table string, rows any) error { + conn := ContextConnection(ctx) + + v := reflect.ValueOf(rows) + if v.Kind() == reflect.Pointer && !v.IsNil() { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice: + if v.Len() == 0 { + return nil + } + mapped, err := MapStructType(conn, v.Type().Elem()) + if err != nil { + return err + } + return insertRows(ctx, conn, mapped, table, reflect.ValueOf(rows)) + + case reflect.Struct: + columns, values, table, err := MapStructFieldValues(conn, rows) + if err != nil { + return err + } + query := createInsertQuery(table, columns, 1, conn) + return conn.Exec(ctx, query, values...) + + case reflect.Map: + if v.Type().Key().Kind() != reflect.String { + return fmt.Errorf("%T is not a map with a string key type", rows) + } + columns, values := mapKeysAndValues(v) + query := createInsertQuery(table, columns, 1, conn) + return conn.Exec(ctx, query, values...) + + default: + return fmt.Errorf("%T not supported as rows argument", rows) + } +} + +func InsertRow(ctx context.Context, row RowWithTableName) error { + conn := ContextConnection(ctx) + columns, values, table, err := MapStructFieldValues(conn, row) + if err != nil { + return err + } + query := createInsertQuery(table, columns, 1, conn) + return conn.Exec(ctx, query, values...) +} + +func InsertRows[R RowWithTableName](ctx context.Context, rows []R) error { + if len(rows) == 0 { + return nil + } + conn := ContextConnection(ctx) + + mapped, err := MapStructType(conn, reflect.TypeOf(rows[0])) + if err != nil { + return err + } + return insertRows(ctx, conn, mapped, mapped.Table, reflect.ValueOf(rows)) +} + +func insertRows(ctx context.Context, conn Connection, mapped *MappedStruct, table string, rows reflect.Value) error { + numRows := rows.Len() + numRowsRemaining := numRows + numCols := len(mapped.Fields) + maxParams := conn.MaxParameters() + + if maxParams > 0 && maxParams < numRows*numCols { + maxRowsPerInsert := maxParams / numCols + if maxRowsPerInsert == 0 { + return fmt.Errorf("%s has %d mapped struct fields which is greater than Connection.MaxParameters of %d", mapped.Type, numCols, conn.MaxParameters()) + } + numRowsPerInsert := numRows / maxRowsPerInsert + insertValues := make([]any, 0, numRowsPerInsert*numCols) + + for i := 0; i < numRows; i += numRowsPerInsert { + for r := 0; r < numRowsPerInsert; r++ { + rowValues, err := mapped.StructFieldValues(rows.Index(i + r)) + if err != nil { + return err + } + insertValues = append(insertValues, rowValues...) + } + query := createInsertQuery(table, mapped.Columns, numRowsPerInsert, conn) + err := conn.Exec(ctx, query, insertValues...) + if err != nil { + return err + } + + insertValues = insertValues[:0] + numRowsRemaining -= numRowsPerInsert + if numRowsRemaining < 0 { + panic("can't happen") + } + } + } + if numRowsRemaining == 0 { + return nil + } + + insertValues := make([]any, 0, numCols*numRowsRemaining) + for r := numRows - numRowsRemaining; r < numRows; r++ { + rowValues, err := mapped.StructFieldValues(rows.Index(r)) + if err != nil { + return err + } + insertValues = append(insertValues, rowValues...) + } + query := createInsertQuery(table, mapped.Columns, numRowsRemaining, conn) + return conn.Exec(ctx, query, insertValues...) +} + +func createInsertQuery(table string, columns []string, numRows int, formatter QueryFormatter) string { + var b strings.Builder + b.WriteString("INSERT INTO ") + b.WriteString(table) + b.WriteByte('(') + for i, column := range columns { + if i > 0 { + b.WriteByte(',') + } + b.WriteByte('"') + b.WriteString(column) + b.WriteByte('"') + } + b.WriteString(")\nVALUES") + for r := 0; r < numRows; r++ { + if r > 0 { + b.WriteString("\n , ") + } + b.WriteByte('(') + for c := range columns { + if c > 0 { + b.WriteByte(',') + } + b.WriteString(formatter.ParameterPlaceholder(r*len(columns) + c)) + } + b.WriteByte(')') + } + return b.String() +} diff --git a/insert_test.go b/insert_test.go new file mode 100644 index 0000000..4aa5042 --- /dev/null +++ b/insert_test.go @@ -0,0 +1,25 @@ +package sqldb + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestInsert(t *testing.T) { + var queryBuf QueryBuffer + conn := LogConnection(NullConnection(defaultQueryFormatter{}, DefaultStructFieldMapping), &queryBuf) + ctx := ContextWithConnection(context.Background(), conn) + + err := Insert(ctx, "test_table", Values{ + "a": "Hello", + "b": true, + "c": 666, + }) + require.NoError(t, err) + expectedQuery := `INSERT INTO test_table("a","b","c") +VALUES($1,$2,$3); +` + require.Equal(t, expectedQuery, queryBuf.String()) +} diff --git a/logconnection.go b/logconnection.go new file mode 100644 index 0000000..2ab201c --- /dev/null +++ b/logconnection.go @@ -0,0 +1,236 @@ +package sqldb + +import ( + "context" + "database/sql" + "errors" + "reflect" + "strings" + "time" +) + +type QueryLogger interface { + LogQuery(query string, args []any) +} + +type QueryLoggerFunc func(query string, args []any) + +func (f QueryLoggerFunc) LogQuery(query string, args []any) { + f(query, args) +} + +type QueryBuffer struct { + b strings.Builder +} + +func (b *QueryBuffer) LogQuery(query string, args []any) { + b.b.WriteString(query) + b.b.WriteString(";\n") +} + +func (b *QueryBuffer) String() string { + return b.b.String() +} + +func (b *QueryBuffer) Len() int { + return b.b.Len() +} + +func (b *QueryBuffer) Reset() { + b.b.Reset() +} + +func LogConnection(target Connection, queryLogger QueryLogger) FullyFeaturedConnection { + if target == nil { + panic(" target Connection") + } + if queryLogger == nil { + panic(" queryLogger") + } + return &logConnection{ + target: target, + queryLogger: queryLogger, + } +} + +type logConnection struct { + target Connection + queryLogger QueryLogger +} + +func (c *logConnection) Err() error { + return c.target.Err() +} + +func (c *logConnection) String() string { + return "LogConnection->" + c.target.String() +} + +func (c *logConnection) DatabaseKind() string { + return c.target.DatabaseKind() +} + +func (c *logConnection) StringLiteral(s string) string { + return c.target.StringLiteral(s) +} + +func (c *logConnection) ArrayLiteral(array any) (string, error) { + return c.target.ArrayLiteral(array) +} + +func (c *logConnection) ParameterPlaceholder(index int) string { + return c.target.ParameterPlaceholder(index) +} + +func (c *logConnection) ValidateColumnName(name string) error { + return c.target.ValidateColumnName(name) +} + +func (c *logConnection) MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) { + return c.target.MapStructField(field) +} + +func (c *logConnection) MaxParameters() int { + return c.target.MaxParameters() +} + +func (c *logConnection) DBStats() sql.DBStats { + return c.target.DBStats() +} + +func (c *logConnection) Config() *Config { + return c.target.Config() +} + +func (c *logConnection) IsTransaction() bool { + return c.target.IsTransaction() +} + +func (c *logConnection) Ping(ctx context.Context, timeout time.Duration) error { + return c.target.Ping(ctx, timeout) +} + +func (c *logConnection) Exec(ctx context.Context, query string, args ...any) error { + c.queryLogger.LogQuery(query, args) + return c.target.Exec(ctx, query, args...) +} + +func (c *logConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) { + c.queryLogger.LogQuery(query, args) + return c.target.Query(ctx, query, args...) +} + +func (c *logConnection) DefaultIsolationLevel() sql.IsolationLevel { + target, ok := c.target.(TxConnection) + if !ok { + return sql.LevelDefault + } + return target.DefaultIsolationLevel() +} + +func (c *logConnection) TxNumber() uint64 { + target, ok := c.target.(TxConnection) + if !ok { + return 0 + } + return target.TxNumber() +} + +func (c *logConnection) TxOptions() (*sql.TxOptions, bool) { + target, ok := c.target.(TxConnection) + if !ok { + return nil, false + } + return target.TxOptions() +} + +func (c *logConnection) Begin(ctx context.Context, opts *sql.TxOptions, no uint64) (TxConnection, error) { + target, ok := c.target.(TxConnection) + if !ok { + return nil, errors.ErrUnsupported + } + query := "BEGIN" + if opts != nil { + if opts.Isolation != sql.LevelDefault { + query += " " + strings.ToUpper(opts.Isolation.String()) + } + if opts.ReadOnly { + query += " READ ONLY" + } + } + c.queryLogger.LogQuery(query, nil) + tx, err := target.Begin(ctx, opts, no) + if err != nil { + return nil, err + } + return LogConnection(tx, c.queryLogger), nil +} + +func (c *logConnection) Commit() error { + target, ok := c.target.(TxConnection) + if !ok { + return errors.ErrUnsupported + } + c.queryLogger.LogQuery("COMMIT", nil) + return target.Commit() +} + +func (c *logConnection) Rollback() error { + target, ok := c.target.(TxConnection) + if !ok { + return errors.ErrUnsupported + } + c.queryLogger.LogQuery("ROLLBACK", nil) + return target.Commit() +} + +func (c *logConnection) ListenChannel(ctx context.Context, channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) (cancel func() error, err error) { + target, ok := c.target.(NotificationConnection) + if !ok { + return nil, errors.ErrUnsupported + } + c.queryLogger.LogQuery("LISTEN "+channel, nil) + return target.ListenChannel(ctx, channel, onNotify, onUnlisten) +} + +func (c *logConnection) UnlistenChannel(ctx context.Context, channel string, onNotify OnNotifyFunc) error { + target, ok := c.target.(NotificationConnection) + if !ok { + return errors.ErrUnsupported + } + c.queryLogger.LogQuery("UNLISTEN "+channel, nil) + return target.UnlistenChannel(ctx, channel, onNotify) +} + +func (c *logConnection) IsListeningChannel(ctx context.Context, channel string) bool { + target, ok := c.target.(NotificationConnection) + if !ok { + return false + } + return target.IsListeningChannel(ctx, channel) +} + +func (c *logConnection) ListeningChannels(ctx context.Context) ([]string, error) { + target, ok := c.target.(NotificationConnection) + if !ok { + return nil, nil + } + return target.ListeningChannels(ctx) +} + +func (c *logConnection) NotifyChannel(ctx context.Context, channel, payload string) error { + target, ok := c.target.(NotificationConnection) + if !ok { + return errors.ErrUnsupported + } + query := "NOTIFY " + channel + if payload != "" { + query += " " + c.StringLiteral(payload) + } + c.queryLogger.LogQuery(query, nil) + return target.NotifyChannel(ctx, channel, payload) +} + +func (c *logConnection) Close() error { + return c.target.Close() +} diff --git a/mockconn/connection.go b/mockconn/connection.go index 27ae29b..78004cd 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -3,6 +3,7 @@ package mockconn import ( "context" "database/sql" + "database/sql/driver" "fmt" "io" "time" @@ -15,22 +16,28 @@ var DefaultArgFmt = "$%d" func New(ctx context.Context, queryWriter io.Writer, rowsProvider RowsProvider) sqldb.Connection { return &connection{ - ctx: ctx, - queryWriter: queryWriter, - listening: newBoolMap(), - rowsProvider: rowsProvider, - structFieldNamer: sqldb.DefaultStructFieldMapping, - argFmt: DefaultArgFmt, + ctx: ctx, + queryWriter: queryWriter, + listening: newBoolMap(), + rowsProvider: rowsProvider, + structFieldMapper: sqldb.DefaultStructFieldMapping, + argFmt: DefaultArgFmt, } } type connection struct { - ctx context.Context - queryWriter io.Writer - listening *boolMap - rowsProvider RowsProvider - structFieldNamer sqldb.StructFieldMapper - argFmt string + ctx context.Context + queryWriter io.Writer + listening *boolMap + rowsProvider RowsProvider + structFieldMapper sqldb.StructFieldMapper + converter driver.ValueConverter + argFmt string +} + +func (conn *connection) clone() *connection { + c := *conn + return &c } func (conn *connection) Context() context.Context { return conn.ctx } @@ -39,37 +46,19 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { if ctx == conn.ctx { return conn } - return &connection{ - ctx: ctx, - queryWriter: conn.queryWriter, - listening: conn.listening, - rowsProvider: conn.rowsProvider, - structFieldNamer: conn.structFieldNamer, - argFmt: conn.argFmt, - } + c := conn.clone() + c.ctx = ctx + return c } -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - return &connection{ - ctx: conn.ctx, - queryWriter: conn.queryWriter, - listening: conn.listening, - rowsProvider: conn.rowsProvider, - structFieldNamer: namer, - argFmt: conn.argFmt, - } +func (conn *connection) WithStructFieldMapper(mapper sqldb.StructFieldMapper) sqldb.Connection { + c := conn.clone() + c.structFieldMapper = mapper + return c } func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - -func (conn *connection) Stats() sql.DBStats { - return sql.DBStats{} -} - -func (conn *connection) Config() *sqldb.Config { - return &sqldb.Config{Driver: "mockconn", Host: "localhost", Database: "mock"} + return conn.structFieldMapper } func (conn *connection) ValidateColumnName(name string) error { @@ -80,31 +69,43 @@ func (conn *connection) Ping(time.Duration) error { return nil } +func (conn *connection) Stats() sql.DBStats { + return sql.DBStats{} +} + +func (conn *connection) Config() *sqldb.Config { + return &sqldb.Config{Driver: "mockconn", Host: "localhost", Database: "mock"} +} + func (conn *connection) Now() (time.Time, error) { return time.Now(), nil } func (conn *connection) Exec(query string, args ...any) error { + return impl.Exec(conn.ctx, conn, query, args, conn.converter, conn.argFmt) +} + +func (conn *connection) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, query) } - return nil + return nil, nil } func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return impl.Insert(conn, table, conn.argFmt, columValues) + return impl.Insert(conn.ctx, conn, table, columValues, conn.converter, conn.argFmt) } func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return impl.InsertUnique(conn, table, conn.argFmt, values, onConflict) + return impl.InsertUnique(conn.ctx, conn, table, values, onConflict, conn.converter, conn.argFmt, conn.structFieldMapper) } func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, conn.argFmt, values, returning) + return impl.InsertReturning(conn.ctx, conn, table, values, returning, conn.converter, conn.argFmt, conn.structFieldMapper) } func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) + return impl.InsertStruct(conn.ctx, conn, table, rowStruct, conn.structFieldMapper, ignoreColumns, conn.converter, conn.argFmt) } func (conn *connection) InsertStructs(table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error { @@ -112,27 +113,31 @@ func (conn *connection) InsertStructs(table string, rowStructs any, ignoreColumn } func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.argFmt, ignoreColumns) + return impl.InsertUniqueStruct(conn.ctx, conn, conn.structFieldMapper, table, rowStruct, onConflict, ignoreColumns, conn.converter, conn.argFmt) } func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, conn.argFmt, args) + return impl.Update(conn.ctx, conn, table, values, where, args, conn.converter, conn.argFmt) } func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) + return impl.UpdateReturningRow(conn.ctx, conn, table, values, returning, where, args, conn.converter, conn.argFmt, conn.structFieldMapper) } func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) + return impl.UpdateReturningRows(conn.ctx, conn, table, values, returning, where, args, conn.converter, conn.argFmt, conn.structFieldMapper) } func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) + return impl.UpdateStruct(conn.ctx, conn, table, rowStruct, conn.structFieldMapper, ignoreColumns, conn.converter, conn.argFmt) } func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) + return impl.UpsertStruct(conn.ctx, conn, table, rowStruct, conn.structFieldMapper, conn.argFmt, ignoreColumns) +} + +func (conn *connection) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + panic("todo") } func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { @@ -145,7 +150,7 @@ func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { if conn.rowsProvider == nil { return sqldb.RowScannerWithError(nil) } - return conn.rowsProvider.QueryRow(conn.structFieldNamer, query, args...) + return conn.rowsProvider.QueryRow(conn.structFieldMapper, query, args...) } func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { @@ -158,7 +163,7 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { if conn.rowsProvider == nil { return sqldb.RowsScannerWithError(nil) } - return conn.rowsProvider.QueryRows(conn.structFieldNamer, query, args...) + return conn.rowsProvider.QueryRows(conn.structFieldMapper, query, args...) } func (conn *connection) IsTransaction() bool { @@ -188,7 +193,7 @@ func (conn *connection) Rollback() error { return sqldb.ErrNotWithinTransaction } -func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { +func (conn *connection) ListenChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { conn.listening.Set(channel, true) if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, "LISTEN "+channel) diff --git a/mockconn/row.go b/mockconn/row.go index 579f754..5e160d5 100644 --- a/mockconn/row.go +++ b/mockconn/row.go @@ -21,7 +21,7 @@ type Row struct { func NewRow(rowStruct any, columnNamer sqldb.StructFieldMapper) *Row { val := reflect.ValueOf(rowStruct) - for val.Kind() == reflect.Ptr { + for val.Kind() == reflect.Pointer { val = val.Elem() } return &Row{ @@ -215,7 +215,7 @@ func convertAssign(dest, src any) error { } dpv := reflect.ValueOf(dest) - if dpv.Kind() != reflect.Ptr { + if dpv.Kind() != reflect.Pointer { return errors.New("destination not a pointer") } if dpv.IsNil() { @@ -248,7 +248,7 @@ func convertAssign(dest, src any) error { // This also allows scanning into user defined types such as "type Int int64". // For symmetry, also check for string destination types. switch dv.Kind() { - case reflect.Ptr: + case reflect.Pointer: if src == nil { dv.SetZero() return nil diff --git a/mockconn/rows.go b/mockconn/rows.go index 8698fe5..653cf37 100644 --- a/mockconn/rows.go +++ b/mockconn/rows.go @@ -20,7 +20,7 @@ func NewRowsFromStructs(rowStructs any, columnNamer sqldb.StructFieldMapper) *Ro if t.Kind() != reflect.Array && t.Kind() != reflect.Slice { panic("rowStructs must be array or slice of structs, but is " + t.String()) } - if t.Elem().Kind() != reflect.Struct && !(t.Elem().Kind() == reflect.Ptr && t.Elem().Elem().Kind() == reflect.Struct) { + if t.Elem().Kind() != reflect.Struct && !(t.Elem().Kind() == reflect.Pointer && t.Elem().Elem().Kind() == reflect.Struct) { panic("rowStructs element type must be struct or struct pointer, but is " + t.Elem().String()) } diff --git a/mockconn/transaction.go b/mockconn/transaction.go index 5611e23..c2dd01b 100644 --- a/mockconn/transaction.go +++ b/mockconn/transaction.go @@ -60,7 +60,7 @@ func (conn transaction) Rollback() error { return nil } -func (conn transaction) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { +func (conn transaction) ListenChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { return sqldb.ErrWithinTransaction } diff --git a/mysqlconn/connection.go b/mysqlconn/connection.go index 46be636..0f924eb 100644 --- a/mysqlconn/connection.go +++ b/mysqlconn/connection.go @@ -23,7 +23,16 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { if err != nil { return nil, err } - return impl.Connection(ctx, db, config, validateColumnName, argFmt), nil + return impl.NewGenericConnection( + ctx, + db, + config, + sqldb.UnsupportedListener(), + sqldb.DefaultStructFieldMapping, + validateColumnName, + nil, // No driver.ValueConverter necessary because MySQL doesn't suport arrays + argFmt, + ), nil } // MustNew creates a new sqldb.Connection using the passed sqldb.Config diff --git a/mysqlconn/go.mod b/mysqlconn/go.mod new file mode 100644 index 0000000..a677d05 --- /dev/null +++ b/mysqlconn/go.mod @@ -0,0 +1,8 @@ +module github.com/domonda/go-sqldb/mysqlconn + +go 1.21 + +require ( + github.com/domonda/go-sqldb v0.0.0-20230306182246-f05b99238fd7 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect +) diff --git a/mysqlconn/go.sum b/mysqlconn/go.sum new file mode 100644 index 0000000..05bfa62 --- /dev/null +++ b/mysqlconn/go.sum @@ -0,0 +1,4 @@ +github.com/domonda/go-sqldb v0.0.0-20230306182246-f05b99238fd7 h1:p37LtaROmulrFh9Ofp4QNDLF6rJE5BS/EDpcBQRnJag= +github.com/domonda/go-sqldb v0.0.0-20230306182246-f05b99238fd7/go.mod h1:2jzc1XhGjWXhB49ex/zPkRGPU/3oxIIVCD/N8DxsyCk= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= diff --git a/notificationconnection.go b/notificationconnection.go new file mode 100644 index 0000000..6b5f296 --- /dev/null +++ b/notificationconnection.go @@ -0,0 +1,56 @@ +package sqldb + +import "context" + +type ( + // OnNotifyFunc is a callback type passed to Connection.ListenChannel + OnNotifyFunc func(channel, payload string) + + // OnUnlistenFunc is a callback type passed to Connection.ListenChannel + OnUnlistenFunc func(channel string) +) + +type NotificationConnection interface { + Connection + + // ListenChannel will call onNotify for every channel notification + // and onUnlisten if the channel gets unlistened + // or the listener connection gets closed for some reason. + // + // It is valid to call ListenChannel multiple times for the same channel + // to register multiple callbacks. + // + // It is valid to pass nil for onNotify or onUnlisten to not get those callbacks. + // + // Note that the callbacks are called in sequence from a single go routine, + // so callbacks should offload long running or potentially blocking code to other go routines. + // + // Panics from callbacks will be recovered and logged. + // + // TODO calling the returned cancel function + // will cancel a particular watch. + ListenChannel(ctx context.Context, channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) (cancel func() error, err error) + + // UnlistenChannel will stop listening on the channel. + // + // If the passed onNotify callback function is not nil, + // then only this callback will be unsubscribed but other + // callback might still be active. + // If nil is passed for onNotify, then all callbacks + // will be unsubscribed. + // + // An error is returned, when the channel was not listened to + // or the listener connection is closed + // or the passed onNotify callback was not subscribed with ListenChannel. + UnlistenChannel(ctx context.Context, channel string, onNotify OnNotifyFunc) error + + // IsListeningChannel returns if a channel is listened to. + IsListeningChannel(ctx context.Context, channel string) bool + + // ListeningChannels returns all listened channel names. + ListeningChannels(ctx context.Context) ([]string, error) + + // NotifyChannel notifies a channel with the optional payload. + // If the payload is an empty string, then it won't be added to the notification. + NotifyChannel(ctx context.Context, channel, payload string) error +} diff --git a/nullable.go b/nullable.go index 5c6427d..77f7819 100644 --- a/nullable.go +++ b/nullable.go @@ -45,7 +45,7 @@ func IsNull(val any) bool { } switch v := reflect.ValueOf(val); v.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Map: + case reflect.Pointer, reflect.Slice, reflect.Map: return v.IsNil() } diff --git a/nullconnection.go b/nullconnection.go new file mode 100644 index 0000000..e4c4f33 --- /dev/null +++ b/nullconnection.go @@ -0,0 +1,173 @@ +package sqldb + +import ( + "context" + "database/sql" + "reflect" + "time" +) + +func NullConnection(queryFormatter QueryFormatter, structFieldMapper StructFieldMapper) FullyFeaturedConnection { + return &nullConnection{ + queryFormatter: queryFormatter, + structFieldMapper: structFieldMapper, + } +} + +type nullConnection struct { + queryFormatter QueryFormatter + structFieldMapper StructFieldMapper + + maxParameters int + txNo uint64 + txOpts *sql.TxOptions +} + +func (c *nullConnection) Err() error { + return nil +} + +func (c *nullConnection) String() string { + return "NullConnection" +} + +func (c *nullConnection) DatabaseKind() string { + return "NullConnection" +} + +func (c *nullConnection) StringLiteral(s string) string { + if c.queryFormatter == nil { + return defaultQueryFormatter{}.StringLiteral(s) + } + return c.queryFormatter.StringLiteral(s) +} + +func (c *nullConnection) ArrayLiteral(array any) (string, error) { + if c.queryFormatter == nil { + return defaultQueryFormatter{}.ArrayLiteral(array) + } + return c.queryFormatter.ArrayLiteral(array) +} + +func (c *nullConnection) ParameterPlaceholder(index int) string { + if c.queryFormatter == nil { + return defaultQueryFormatter{}.ParameterPlaceholder(index) + } + return c.queryFormatter.ParameterPlaceholder(index) +} + +func (c *nullConnection) ValidateColumnName(name string) error { + if c.queryFormatter == nil { + return nil + } + return c.queryFormatter.ValidateColumnName(name) +} + +func (c *nullConnection) MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) { + if c.structFieldMapper == nil { + DefaultStructFieldMapping.MapStructField(field) + } + return c.structFieldMapper.MapStructField(field) +} + +func (c *nullConnection) MaxParameters() int { + return c.maxParameters +} + +func (c *nullConnection) DBStats() sql.DBStats { + return sql.DBStats{} +} + +func (c *nullConnection) Config() *Config { + return &Config{Driver: "NullConnection"} +} + +func (c *nullConnection) IsTransaction() bool { + return c.txNo > 0 +} + +func (c *nullConnection) Ping(ctx context.Context, timeout time.Duration) error { + return ctx.Err() +} + +func (c *nullConnection) Exec(ctx context.Context, query string, args ...any) error { + return ctx.Err() +} + +func (c *nullConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nullRows{}, nil +} + +func (c *nullConnection) DefaultIsolationLevel() sql.IsolationLevel { + return sql.LevelDefault +} + +func (c *nullConnection) TxNumber() uint64 { + return c.txNo +} + +func (c *nullConnection) TxOptions() (*sql.TxOptions, bool) { + return c.txOpts, c.IsTransaction() +} + +func (c *nullConnection) Begin(ctx context.Context, opts *sql.TxOptions, no uint64) (TxConnection, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + if c.IsTransaction() { + return nil, ErrWithinTransaction + } + tx := c + tx.txNo = no + tx.txOpts = opts + return tx, nil +} + +func (c *nullConnection) Commit() error { + if !c.IsTransaction() { + return ErrNotWithinTransaction + } + return nil +} + +func (c *nullConnection) Rollback() error { + if !c.IsTransaction() { + return ErrNotWithinTransaction + } + return nil +} + +func (c *nullConnection) ListenChannel(ctx context.Context, channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) (cancel func() error, err error) { + return func() error { return nil }, ctx.Err() +} + +func (c *nullConnection) UnlistenChannel(ctx context.Context, channel string, onNotify OnNotifyFunc) error { + return ctx.Err() +} + +func (c *nullConnection) IsListeningChannel(ctx context.Context, channel string) bool { + return false +} + +func (c *nullConnection) ListeningChannels(ctx context.Context) ([]string, error) { + return nil, nil +} + +func (c *nullConnection) NotifyChannel(ctx context.Context, channel, payload string) error { + return ctx.Err() +} + +func (c *nullConnection) Close() error { + return nil +} + +type nullRows struct{} + +func (r nullRows) Columns() ([]string, error) { return nil, nil } +func (r nullRows) Scan(...any) error { return nil } +func (r nullRows) Close() error { return nil } +func (r nullRows) Next() bool { return false } +func (r nullRows) Err() error { return nil } diff --git a/pqconn/connection.go b/pqconn/connection.go index dada5d9..fcd27e4 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -3,8 +3,12 @@ package pqconn import ( "context" "database/sql" + "database/sql/driver" "fmt" - "time" + "regexp" + "strings" + + "github.com/lib/pq" "github.com/domonda/go-sqldb" "github.com/domonda/go-sqldb/impl" @@ -12,6 +16,29 @@ import ( const argFmt = "$%d" +var columnNameRegex = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]{0,58}$`) + +const maxParameters = 65534 + +func validateColumnName(name string) error { + if !columnNameRegex.MatchString(name) { + return fmt.Errorf("invalid Postgres column name: %q", name) + } + return nil +} + +type valueConverter struct{} + +func (valueConverter) ConvertValue(v any) (driver.Value, error) { + if valuer, ok := v.(driver.Valuer); ok { + return valuer.Value() + } + if sqldb.IsSliceOrArray(v) { + return pq.Array(v).Value() + } + return v, nil +} + // New creates a new sqldb.Connection using the passed sqldb.Config // and github.com/lib/pq as driver implementation. // The connection is pinged with the passed context @@ -26,12 +53,16 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { if err != nil { return nil, err } - return &connection{ - ctx: ctx, - db: db, - config: config, - structFieldNamer: sqldb.DefaultStructFieldMapping, - }, nil + return impl.NewGenericConnection( + ctx, + db, + config, + Listener, + sqldb.DefaultStructFieldMapping, + validateColumnName, + valueConverter{}, + argFmt, + ), nil } // MustNew creates a new sqldb.Connection using the passed sqldb.Config @@ -47,175 +78,35 @@ func MustNew(ctx context.Context, config *sqldb.Config) sqldb.Connection { return conn } -type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldNamer sqldb.StructFieldMapper -} - -func (conn *connection) clone() *connection { - c := *conn - return &c -} - -func (conn *connection) Context() context.Context { return conn.ctx } - -func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.ctx { - return conn - } - c := conn.clone() - c.ctx = ctx - return c -} - -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - -func (conn *connection) Ping(timeout time.Duration) error { - ctx := conn.ctx - if timeout > 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - return conn.db.PingContext(ctx) -} - -func (conn *connection) Stats() sql.DBStats { - return conn.db.Stats() -} - -func (conn *connection) Config() *sqldb.Config { - return conn.config -} - -func (conn *connection) ValidateColumnName(name string) error { - return validateColumnName(name) -} - -func (conn *connection) Now() (time.Time, error) { - return impl.Now(conn) -} - -func (conn *connection) Exec(query string, args ...any) error { - _, err := conn.db.ExecContext(conn.ctx, query, args...) - return wrapError(err, query, argFmt, args) -} - -func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return WrapKnownErrors(impl.Insert(conn, table, argFmt, columValues)) -} - -func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - inserted, err = impl.InsertUnique(conn, table, argFmt, values, onConflict) - return inserted, WrapKnownErrors(err) -} - -func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, argFmt, values, returning) -} - -func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return WrapKnownErrors(impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns)) -} - -func (conn *connection) InsertStructs(table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error { - return WrapKnownErrors(impl.InsertStructs(conn, table, rowStructs, ignoreColumns...)) -} - -func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - // TODO more error wrapping - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = wrapError(err, query, argFmt, args) - return sqldb.RowScannerWithError(err) +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. + // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` } - return impl.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) -} - -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = wrapError(err, query, argFmt, args) - return sqldb.RowsScannerWithError(err) - } - return impl.NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, argFmt, args) -} - -func (conn *connection) IsTransaction() bool { - return false -} - -func (conn *connection) TransactionNo() uint64 { - return 0 -} - -func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (conn *connection) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { - tx, err := conn.db.BeginTx(conn.ctx, opts) - if err != nil { - return nil, err - } - return newTransaction(conn, tx, opts, no), nil -} - -func (conn *connection) Commit() error { - return sqldb.ErrNotWithinTransaction -} - -func (conn *connection) Rollback() error { - return sqldb.ErrNotWithinTransaction -} - -func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return conn.getOrCreateListener().listenOnChannel(channel, onNotify, onUnlisten) -} - -func (conn *connection) UnlistenChannel(channel string) (err error) { - return conn.getListenerOrNil().unlistenChannel(channel) -} - -func (conn *connection) IsListeningOnChannel(channel string) bool { - return conn.getListenerOrNil().isListeningOnChannel(channel) -} - -func (conn *connection) Close() error { - conn.getListenerOrNil().close() - return conn.db.Close() + return literal } diff --git a/pqconn/go.mod b/pqconn/go.mod new file mode 100644 index 0000000..e70444a --- /dev/null +++ b/pqconn/go.mod @@ -0,0 +1,3 @@ +module github.com/domonda/go-sqldb/pqconn + +go 1.21 diff --git a/pqconn/listener.go b/pqconn/listener.go index 377bb49..c7bc61c 100644 --- a/pqconn/listener.go +++ b/pqconn/listener.go @@ -10,6 +10,66 @@ import ( "github.com/domonda/go-sqldb" ) +var ( + // ListenerEventLogger will log all subscribed channel listener events if not nil + ListenerEventLogger sqldb.Logger + + // Listener for all Postgres notifications + Listener = listenerImpl{} +) + +// type Listener interface { +// // ListenOnChannel will call onNotify for every channel notification +// // and onUnlisten if the channel gets unlistened +// // or the listener connection gets closed for some reason. +// // It is valid to pass nil for onNotify or onUnlisten to not get those callbacks. +// // Note that the callbacks are called in sequence from a single go routine, +// // so callbacks should offload long running or potentially blocking code to other go routines. +// // Panics from callbacks will be recovered and logged. +// ListenOnChannel(conn Connection, channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error + +// // UnlistenChannel will stop listening on the channel. +// // An error is returned, when the channel was not listened to +// // or the listener connection is closed. +// UnlistenChannel(conn Connection, channel string) error + +// // IsListeningOnChannel returns if a channel is listened to. +// IsListeningOnChannel(conn Connection, channel string) bool + +// // Close the listener. +// Close(conn Connection) error +// } + +type listenerImpl struct{} + +func (listenerImpl) ListenChannel(conn sqldb.Connection, channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) error { + if conn.IsTransaction() { + return sqldb.ErrWithinTransaction + } + return getOrCreateListener(conn.Config().ConnectURL()).listenOnChannel(channel, onNotify, onUnlisten) +} + +func (listenerImpl) UnlistenChannel(conn sqldb.Connection, channel string) error { + if conn.IsTransaction() { + return sqldb.ErrWithinTransaction + } + return getListenerOrNil(conn.Config().ConnectURL()).unlistenChannel(channel) +} + +func (listenerImpl) IsListeningChannel(conn sqldb.Connection, channel string) bool { + if conn.IsTransaction() { + return false + } + return getListenerOrNil(conn.Config().ConnectURL()).isListeningOnChannel(channel) +} + +func (listenerImpl) Close(conn sqldb.Connection) error { + if !conn.IsTransaction() { + getListenerOrNil(conn.Config().ConnectURL()).close() + } + return nil +} + var ( globalListeners = make(map[string]*listener) globalListenersMtx sync.RWMutex @@ -29,9 +89,7 @@ type listener struct { unlistenCallbacks map[string][]sqldb.OnUnlistenFunc } -func (conn *connection) getOrCreateListener() *listener { - connURL := conn.config.ConnectURL() - +func getOrCreateListener(connURL string) *listener { globalListenersMtx.Lock() defer globalListenersMtx.Unlock() @@ -58,9 +116,7 @@ func (conn *connection) getOrCreateListener() *listener { return l } -func (conn *connection) getListenerOrNil() *listener { - connURL := conn.config.ConnectURL() - +func getListenerOrNil(connURL string) *listener { globalListenersMtx.RLock() l := globalListeners[connURL] globalListenersMtx.RUnlock() @@ -102,6 +158,20 @@ func (l *listener) notify(notification *pq.Notification) { } } +func recoverAndLogListenerPanic(operation, channel string) { + p := recover() + switch { + case p == nil: + return + + case ListenerEventLogger != nil: + ListenerEventLogger.Printf("%s on channel %q paniced with: %+v", operation, channel, p) + + case sqldb.ErrLogger != nil: + sqldb.ErrLogger.Printf("%s on channel %q paniced with: %+v", operation, channel, p) + } +} + func (l *listener) safeNotifyCallback(callback sqldb.OnNotifyFunc, channel, payload string) { defer recoverAndLogListenerPanic("notify", channel) diff --git a/pqconn/logger.go b/pqconn/logger.go deleted file mode 100644 index f9bbdf2..0000000 --- a/pqconn/logger.go +++ /dev/null @@ -1,20 +0,0 @@ -package pqconn - -import "github.com/domonda/go-sqldb" - -// ListenerEventLogger will log all subscribed channel listener events if not nil -var ListenerEventLogger sqldb.Logger - -func recoverAndLogListenerPanic(operation, channel string) { - p := recover() - switch { - case p == nil: - return - - case ListenerEventLogger != nil: - ListenerEventLogger.Printf("%s on channel %q paniced with: %+v", operation, channel, p) - - case sqldb.ErrLogger != nil: - sqldb.ErrLogger.Printf("%s on channel %q paniced with: %+v", operation, channel, p) - } -} diff --git a/pqconn/postgres.go b/pqconn/postgres.go deleted file mode 100644 index f3d8431..0000000 --- a/pqconn/postgres.go +++ /dev/null @@ -1,15 +0,0 @@ -package pqconn - -import ( - "fmt" - "regexp" -) - -var columnNameRegex = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]{0,58}$`) - -func validateColumnName(name string) error { - if !columnNameRegex.MatchString(name) { - return fmt.Errorf("invalid Postgres column name: %q", name) - } - return nil -} diff --git a/pqconn/tests/docker-compose.yaml b/pqconn/tests/docker-compose.yaml new file mode 100644 index 0000000..a905639 --- /dev/null +++ b/pqconn/tests/docker-compose.yaml @@ -0,0 +1,14 @@ +services: + + db: + image: postgres:15.2 + restart: always + environment: + POSTGRES_PASSWORD: go-sqldb + # POSTGRES_INITDB_ARGS: + + # adminer: + # image: adminer + # restart: always + # ports: + # - 8080:8080 \ No newline at end of file diff --git a/pqconn/tests/go.mod b/pqconn/tests/go.mod new file mode 100644 index 0000000..5534b1b --- /dev/null +++ b/pqconn/tests/go.mod @@ -0,0 +1,37 @@ +module github.com/domonda/go-sqldb/pqconn/tests + +go 1.20 + +require ( + github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect + github.com/Microsoft/go-winio v0.6.0 // indirect + github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect + github.com/cenkalti/backoff/v4 v4.2.0 // indirect + github.com/containerd/continuity v0.3.0 // indirect + github.com/docker/cli v23.0.1+incompatible // indirect + github.com/docker/docker v23.0.1+incompatible // indirect + github.com/docker/go-connections v0.4.0 // indirect + github.com/docker/go-units v0.5.0 // indirect + github.com/domonda/go-sqldb v0.0.0-20230306182246-f05b99238fd7 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/imdario/mergo v0.3.13 // indirect + github.com/lib/pq v1.10.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/moby/term v0.0.0-20221205130635-1aeaba878587 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.0.2 // indirect + github.com/opencontainers/runc v1.1.4 // indirect + github.com/ory/dockertest/v3 v3.9.1 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.9.0 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect + golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb // indirect + golang.org/x/mod v0.9.0 // indirect + golang.org/x/net v0.8.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/tools v0.7.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/pqconn/tests/go.sum b/pqconn/tests/go.sum new file mode 100644 index 0000000..16cd306 --- /dev/null +++ b/pqconn/tests/go.sum @@ -0,0 +1,146 @@ +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg= +github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= +github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= +github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= +github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4= +github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/checkpoint-restore/go-criu/v5 v5.3.0/go.mod h1:E/eQpaFtUKGOOSEBZgmKAcn+zUUwWxqcaKZlF54wK8E= +github.com/cilium/ebpf v0.7.0/go.mod h1:/oI2+1shJiTGAMgl6/RgJr36Eo1jzrRcAWbcXO2usCA= +github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= +github.com/containerd/continuity v0.3.0 h1:nisirsYROK15TAMVukJOUyGJjz4BNQJBVsNvAXZJ/eg= +github.com/containerd/continuity v0.3.0/go.mod h1:wJEAIwKOm/pBZuBd0JmeTvnLquTB1Ag8espWhkykbPM= +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/cyphar/filepath-securejoin v0.2.3/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/cli v23.0.1+incompatible h1:LRyWITpGzl2C9e9uGxzisptnxAn1zfZKXy13Ul2Q5oM= +github.com/docker/cli v23.0.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/docker v23.0.1+incompatible h1:vjgvJZxprTTE1A37nm+CLNAdwu6xZekyoiVlUZEINcY= +github.com/docker/docker v23.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= +github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/domonda/go-sqldb v0.0.0-20230306182246-f05b99238fd7 h1:p37LtaROmulrFh9Ofp4QNDLF6rJE5BS/EDpcBQRnJag= +github.com/domonda/go-sqldb v0.0.0-20230306182246-f05b99238fd7/go.mod h1:2jzc1XhGjWXhB49ex/zPkRGPU/3oxIIVCD/N8DxsyCk= +github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/godbus/dbus/v5 v5.0.6/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk= +github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/moby/sys/mountinfo v0.5.0/go.mod h1:3bMD3Rg+zkqx8MRYPi7Pyb0Ie97QEBmdxbhnCLlSvSU= +github.com/moby/term v0.0.0-20221205130635-1aeaba878587 h1:HfkjXDfhgVaN5rmueG8cL8KKeFNecRCXFhaJ2qZ5SKA= +github.com/moby/term v0.0.0-20221205130635-1aeaba878587/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.0.2 h1:9yCKha/T5XdGtO0q9Q9a6T5NUCsTn/DrBg0D7ufOcFM= +github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= +github.com/opencontainers/runc v1.1.4 h1:nRCz/8sKg6K6jgYAFLDlXzPeITBZJyX28DBVhWD+5dg= +github.com/opencontainers/runc v1.1.4/go.mod h1:1J5XiS+vdZ3wCyZybsuxXZWGrgSr8fFJHLXuG2PsnNg= +github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/opencontainers/selinux v1.10.0/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI= +github.com/ory/dockertest/v3 v3.9.1 h1:v4dkG+dlu76goxMiTT2j8zV7s4oPPEppKT8K8p2f1kY= +github.com/ory/dockertest/v3 v3.9.1/go.mod h1:42Ir9hmvaAPm0Mgibk6mBPi7SFvTXxEcnztDYOJ//uM= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/seccomp/libseccomp-golang v0.9.2-0.20220502022130-f33da4d89646/go.mod h1:JA8cRccbGaA1s33RQf7Y1+q9gHmZX1yB/z9WDN1C6fg= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= +github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= +github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb h1:PaBZQdo+iSDyHT053FjUCgZQ/9uqVwPOcl7KSWhKn6w= +golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= +golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= +golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pqconn/tests/testmain.go b/pqconn/tests/testmain.go new file mode 100644 index 0000000..42623a1 --- /dev/null +++ b/pqconn/tests/testmain.go @@ -0,0 +1,58 @@ +package tests + +import ( + "context" + "log" + "os" + "testing" + "time" + + "github.com/ory/dockertest/v3" + + "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/pqconn" +) + +func TestMain(m *testing.M) { + // uses a sensible default on windows (tcp/http) and linux/osx (socket) + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not construct pool: %s", err) + } + + // uses pool to try to connect to Docker + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + + // pulls an image, creates a container based on it and runs it + resource, err := pool.Run("postgres", "15.2", []string{"POSTGRES_PASSWORD=go-sqldb"}) + if err != nil { + log.Fatalf("Could not start resource: %s", err) + } + + // exponential backoff-retry, because the application in the container might not be ready to accept connections yet + err = pool.Retry(func() error { + config := &sqldb.Config{ + Driver: "postgres", + } + conn, err := pqconn.New(context.Background(), config) + if err != nil { + return err + } + return conn.Ping(time.Second) + }) + if err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + code := m.Run() + + // You can't defer this because os.Exit doesn't care for defer + if err := pool.Purge(resource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + + os.Exit(code) +} diff --git a/pqconn/transaction.go b/pqconn/transaction.go deleted file mode 100644 index 80353eb..0000000 --- a/pqconn/transaction.go +++ /dev/null @@ -1,179 +0,0 @@ -package pqconn - -import ( - "context" - "database/sql" - "time" - - "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" -) - -type transaction struct { - // The parent non-transaction connection is needed - // for its ctx, Ping(), Stats(), and Config() - parent *connection - tx *sql.Tx - opts *sql.TxOptions - no uint64 - structFieldNamer sqldb.StructFieldMapper -} - -func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions, no uint64) *transaction { - return &transaction{ - parent: parent, - tx: tx, - opts: opts, - no: no, - structFieldNamer: parent.structFieldNamer, - } -} - -func (conn *transaction) clone() *transaction { - c := *conn - return &c -} - -func (conn *transaction) Context() context.Context { return conn.parent.ctx } - -func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.parent.ctx { - return conn - } - parent := conn.parent.clone() - parent.ctx = ctx - return newTransaction(parent, conn.tx, conn.opts, conn.no) -} - -func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - -func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } -func (conn *transaction) Stats() sql.DBStats { return conn.parent.Stats() } -func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() } - -func (conn *transaction) ValidateColumnName(name string) error { - return validateColumnName(name) -} - -func (conn *transaction) Now() (time.Time, error) { - return impl.Now(conn) -} - -func (conn *transaction) Exec(query string, args ...any) error { - _, err := conn.tx.Exec(query, args...) - return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) -} - -func (conn *transaction) Insert(table string, columValues sqldb.Values) error { - return impl.Insert(conn, table, argFmt, columValues) -} - -func (conn *transaction) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return impl.InsertUnique(conn, table, argFmt, values, onConflict) -} - -func (conn *transaction) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, argFmt, values, returning) -} - -func (conn *transaction) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, argFmt, args) -} - -func (conn *transaction) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) InsertStructs(table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStructs(conn, table, rowStructs, ignoreColumns...) -} - -func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) - if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) - return sqldb.RowScannerWithError(err) - } - return impl.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) -} - -func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { - rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) - if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) - return sqldb.RowsScannerWithError(err) - } - return impl.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, argFmt, args) -} - -func (conn *transaction) IsTransaction() bool { - return true -} - -func (conn *transaction) TransactionNo() uint64 { - return conn.no -} - -func (conn *transaction) TransactionOptions() (*sql.TxOptions, bool) { - return conn.opts, true -} - -func (conn *transaction) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { - tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) - if err != nil { - return nil, err - } - return newTransaction(conn.parent, tx, opts, no), nil -} - -func (conn *transaction) Commit() error { - return conn.tx.Commit() -} - -func (conn *transaction) Rollback() error { - return conn.tx.Rollback() -} - -func (conn *transaction) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return sqldb.ErrWithinTransaction -} - -func (conn *transaction) UnlistenChannel(channel string) (err error) { - return sqldb.ErrWithinTransaction -} - -func (conn *transaction) IsListeningOnChannel(channel string) bool { - return false -} - -func (conn *transaction) Close() error { - return conn.Rollback() -} diff --git a/queries.go b/queries.go new file mode 100644 index 0000000..60d2995 --- /dev/null +++ b/queries.go @@ -0,0 +1,315 @@ +package sqldb + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "time" +) + +// Now returns the result of the SQL NOW() +// function for the current connection. +// Useful for getting the timestamp of a +// SQL transaction for use in Go code. +func Now(ctx context.Context) (time.Time, error) { + return QueryValue[time.Time](ctx, `SELECT NOW()`) +} + +// Exec executes a query with optional args. +func Exec(ctx context.Context, query string, args ...any) error { + conn := ContextConnection(ctx) + err := conn.Exec(ctx, query, args...) + if err != nil { + return WrapErrorWithQuery(err, query, args, conn) + } + return nil +} + +// QueryRow queries a single row and returns a RowScanner for the results. +func QueryRow(ctx context.Context, query string, args ...any) RowScanner { + conn := ContextConnection(ctx) + rows, err := conn.Query(ctx, query, args...) + if err != nil { + rows = RowsWithError(err) + } + return NewRowScanner(rows, query, args, conn) +} + +// QueryValue queries a single value of type T. +func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, err error) { + err = QueryRow(ctx, query, args...).Scan(&value) + if err != nil { + var zero T + return zero, err + } + return value, nil +} + +// QueryValueOr queries a single value of type T +// or returns the passed defaultValue in case of sql.ErrNoRows. +func QueryValueOr[T any](ctx context.Context, defaultValue T, query string, args ...any) (value T, err error) { + err = QueryRow(ctx, query, args...).Scan(&value) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return defaultValue, nil + } + var zero T + return zero, err + } + return value, err +} + +// QueryRowStruct queries a row and scans it as struct. +func QueryRowStruct[S any](ctx context.Context, query string, args ...any) (row *S, err error) { + err = QueryRow(ctx, query, args...).ScanStruct(&row) + if err != nil { + return nil, err + } + return row, nil +} + +// QueryRowStructOrNil queries a row and scans it as struct +// or returns nil in case of sql.ErrNoRows. +func QueryRowStructOrNil[S any](ctx context.Context, query string, args ...any) (row *S, err error) { + err = QueryRow(ctx, query, args...).ScanStruct(&row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return row, nil +} + +// GetRowStruct uses the passed pkValue+pkValues to query a table row +// and scan it into a struct of type S that must have tagged fields +// with primary key flags to identify the primary key column names +// for the passed pkValue+pkValues and a table name. +func GetRowStruct[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { + // Using explicit first pkValue value + // to not be able to compile without any value + pkValues = append([]any{pkValue}, pkValues...) + t := reflect.TypeOf(row).Elem() + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected struct template type instead of %s", t) + } + conn := ContextConnection(ctx) + table, pkColumns, err := pkColumnsOfStruct(conn, t) + if err != nil { + return nil, err + } + if len(pkColumns) != len(pkValues) { + return nil, fmt.Errorf("got %d primary key values, but struct %s has %d primary key fields", len(pkValues), t, len(pkColumns)) + } + var query strings.Builder + fmt.Fprintf(&query, `SELECT * FROM %s WHERE "%s" = $1`, table, pkColumns[0]) //#nosec G104 + for i := 1; i < len(pkColumns); i++ { + fmt.Fprintf(&query, ` AND "%s" = $%d`, pkColumns[i], i+1) //#nosec G104 + } + return QueryRowStruct[S](ctx, query.String(), pkValues...) +} + +// GetRowStructOrNil uses the passed pkValue+pkValues to query a table row +// and scan it into a struct of type S that must have tagged fields +// with primary key flags to identify the primary key column names +// for the passed pkValue+pkValues and a table name. +// Returns nil as row and error if no row could be found with the +// passed pkValue+pkValues. +func GetRowStructOrNil[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { + row, err = GetRowStruct[S](ctx, pkValue, pkValues...) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return row, nil +} + +// QueryRowsAsSlice scans one value per row into one slice element of rowVals. +// dest must be a pointer to a slice with a row value compatible element type. TODO +// +// In case of a cancelled context the rows scanned before the cancellation +// will be returned together with the context error. +func QueryRowsAsSlice[T any](ctx context.Context, query string, args ...any) (rows []T, err error) { + conn := ContextConnection(ctx) + defer WrapResultErrorWithQuery(&err, query, args, conn) + + srcRows, err := conn.Query(ctx, query, args...) + if err != nil { + return nil, err + } + + var elem T + scanningStructs := isStructRowType(reflect.TypeOf(elem)) + for srcRows.Next() { + if ctx.Err() != nil { + return rows, ctx.Err() + } + if scanningStructs { + err = ScanStruct(srcRows, &elem, conn) + } else { + err = srcRows.Scan(&elem) + } + if err != nil { + return rows, err + } + rows = append(rows, elem) + } + return rows, srcRows.Err() +} + +func isStructRowType(t reflect.Type) bool { + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + return false + } + if t.Implements(typeOfSQLScanner) { + return false + } + if reflect.PointerTo(t).Implements(typeOfSQLScanner) { + return false + } + return true +} + +// QueryWithRowCallback will call the passed callback function with scanned values +// or a struct for every row. +// If the callback function has a single struct or struct pointer argument, +// then RowScanner.ScanStruct will be used per row, +// else RowScanner.Scan will be used for all arguments of the callback. +// If the function has a context.Context as first argument, +// then the context of the query call will be passed on. +// The callback can have no result or a single error result value. +// If a non nil error is returned from the callback, then this error +// is returned immediately by this function without scanning further rows. +// In case of zero rows, no error will be returned. +func QueryWithRowCallback[F any](ctx context.Context, callback F, query string, args ...any) (err error) { + conn := ContextConnection(ctx) + defer WrapResultErrorWithQuery(&err, query, args, conn) + + funcVal := reflect.ValueOf(callback) + funcType := funcVal.Type() + if funcType.Kind() != reflect.Func { + return fmt.Errorf("expected callback function, got %s", funcType) + } + if funcType.IsVariadic() { + return fmt.Errorf("callback function must not be varidic: %s", funcType) + } + if funcType.NumIn() == 0 || (funcType.NumIn() == 1 && funcType.In(0) == typeOfContext) { + return fmt.Errorf("callback function has no arguments: %s", funcType) + } + firstArgIndex := 0 + hasCtxArg := false + if funcType.In(0) == typeOfContext { + hasCtxArg = true + firstArgIndex = 1 + } + hasStructArg := false + for i := firstArgIndex; i < funcType.NumIn(); i++ { + t := funcType.In(i) + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + switch t.Kind() { + case reflect.Struct: + if t.Implements(typeOfSQLScanner) || reflect.PointerTo(t).Implements(typeOfSQLScanner) { + continue + } + if hasStructArg { + return fmt.Errorf("callback function must not have further argument after struct: %s", funcType) + } + hasStructArg = true + case reflect.Chan, reflect.Func: + return fmt.Errorf("callback function has invalid argument type: %s", funcType.In(i)) + } + } + if funcType.NumOut() == 1 && funcType.Out(0) != typeOfError { + return fmt.Errorf("callback function result must be of type error: %s", funcType) + } + if funcType.NumOut() > 1 { + return fmt.Errorf("callback function can only have one result value: %s", funcType) + } + + rows, err := conn.Query(ctx, query, args...) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + if ctx.Err() != nil { + return ctx.Err() + } + + // First scan row + scannedValPtrs := make([]any, funcType.NumIn()-firstArgIndex) + for i := range scannedValPtrs { + scannedValPtrs[i] = reflect.New(funcType.In(firstArgIndex + i)).Interface() + } + if hasStructArg { + err = ScanStruct(rows, scannedValPtrs[0], conn) + } else { + err = rows.Scan(scannedValPtrs...) + } + if err != nil { + return err + } + + // Then do callback via reflection + args := make([]reflect.Value, funcType.NumIn()) + if hasCtxArg { + args[0] = reflect.ValueOf(ctx) + } + for i := firstArgIndex; i < len(args); i++ { + args[i] = reflect.ValueOf(scannedValPtrs[i-firstArgIndex]).Elem() + } + result := funcVal.Call(args) + if len(result) > 0 && !result[0].IsNil() { + return result[0].Interface().(error) + } + } + return rows.Err() +} + +// QueryStrings returns the queried row values as strings. +// Byte slices will be interpreted as strings, +// nil (SQL NULL) will be converted to an empty string, +// all other types are converted with fmt.Sprint. +// The first row is a header with the column names. +func QueryStrings(ctx context.Context, query string, args ...any) (rows [][]string, err error) { + conn := ContextConnection(ctx) + defer WrapResultErrorWithQuery(&err, query, args, conn) + + srcRows, err := conn.Query(ctx, query, args...) + if err != nil { + return nil, err + } + + cols, err := srcRows.Columns() + if err != nil { + return nil, err + } + rows = [][]string{cols} + stringScannablePtrs := make([]any, len(cols)) + for srcRows.Next() { + if ctx.Err() != nil { + return rows, ctx.Err() + } + row := make([]string, len(cols)) + for i := range stringScannablePtrs { + stringScannablePtrs[i] = (*StringScannable)(&row[i]) + } + err := srcRows.Scan(stringScannablePtrs...) + if err != nil { + return rows, err + } + rows = append(rows, row) + } + return rows, srcRows.Err() +} diff --git a/row.go b/row.go new file mode 100644 index 0000000..20dce44 --- /dev/null +++ b/row.go @@ -0,0 +1,165 @@ +package sqldb + +import ( + "database/sql" +) + +// Row is an interface with the methods of sql.Rows +// that are needed for ScanStruct. +// Allows mocking for tests without an SQL driver. +type Row interface { + // Columns returns the column names. + Columns() ([]string, error) + + // Scan copies the columns in the current row into the values pointed + // at by dest. The number of values in dest must be the same as the + // number of columns in Rows. + Scan(dest ...any) error +} + +// RowScanner scans the values from a single row. +type RowScanner interface { + // Columns returns the column names. + Columns() ([]string, error) + + // Scan values of a row into dest variables, which must be passed as pointers. + Scan(dest ...any) error + + // ScanStruct scans values of a row into a dest struct which must be passed as pointer. + ScanStruct(dest any) error + + // ScanValues returns the values of a row exactly how they are + // passed from the database driver to an sql.Scanner. + // Byte slices will be copied. + ScanValues() ([]any, error) + + // ScanStrings scans the values of a row as strings. + // Byte slices will be interpreted as strings, + // nil (SQL NULL) will be converted to an empty string, + // all other types are converted with fmt.Sprint(src). + ScanStrings() ([]string, error) +} + +// rowScanner implements RowScanner for Rows +type rowScanner struct { + row Rows + query string // for error wrapping + args []any // for error wrapping + conn Connection // for error wrapping +} + +func NewRowScanner(row Rows, query string, args []any, conn Connection) RowScanner { + return &rowScanner{row: row, query: query, args: args, conn: conn} +} + +func (s *rowScanner) Columns() (columns []string, err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + return s.row.Columns() +} + +func (s *rowScanner) Scan(dest ...any) (err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + if s.row.Err() != nil { + return s.row.Err() + } + if !s.row.Next() { + if s.row.Err() != nil { + return s.row.Err() + } + return sql.ErrNoRows + } + + return s.row.Scan(dest...) +} + +func (s *rowScanner) ScanStruct(dest any) (err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + if s.row.Err() != nil { + return s.row.Err() + } + if !s.row.Next() { + if s.row.Err() != nil { + return s.row.Err() + } + return sql.ErrNoRows + } + + return ScanStruct(s.row, dest, s.conn) +} + +func (s *rowScanner) ScanValues() (vals []any, err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + if s.row.Err() != nil { + return nil, s.row.Err() + } + if !s.row.Next() { + if s.row.Err() != nil { + return nil, s.row.Err() + } + return nil, sql.ErrNoRows + } + + return ScanValues(s.row) +} + +func (s *rowScanner) ScanStrings() (strs []string, err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + if s.row.Err() != nil { + return nil, s.row.Err() + } + if !s.row.Next() { + if s.row.Err() != nil { + return nil, s.row.Err() + } + return nil, sql.ErrNoRows + } + + return ScanStrings(s.row) +} + +// currentRowScanner implements RowScanner for Row +type currentRowScanner struct { + row Row + query string // for error wrapping + args []any // for error wrapping + conn Connection // for error wrapping +} + +func NewCurrentRowScanner(row Row, query string, args []any, conn Connection) RowScanner { + return ¤tRowScanner{row: row, query: query, args: args, conn: conn} +} + +func (s *currentRowScanner) Columns() (columns []string, err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + return s.row.Columns() +} + +func (s *currentRowScanner) Scan(dest ...any) (err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + return s.row.Scan(dest...) +} + +func (s *currentRowScanner) ScanStruct(dest any) (err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + return ScanStruct(s.row, dest, s.conn) +} + +func (s *currentRowScanner) ScanValues() (vals []any, err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + return ScanValues(s.row) +} + +func (s *currentRowScanner) ScanStrings() (strs []string, err error) { + defer WrapResultErrorWithQuery(&err, s.query, s.args, s.conn) + + return ScanStrings(s.row) +} diff --git a/rows.go b/rows.go new file mode 100644 index 0000000..f98cd6c --- /dev/null +++ b/rows.go @@ -0,0 +1,45 @@ +package sqldb + +// Rows is an interface with the methods of sql.Rows. +// Allows mocking for tests without an SQL driver. +type Rows interface { + // Columns returns the column names. + Columns() ([]string, error) + + // Scan copies the columns in the current row into the values pointed + // at by dest. The number of values in dest must be the same as the + // number of columns in Rows. + Scan(dest ...any) error + + // Close closes the Rows, preventing further enumeration. If Next is called + // and returns false and there are no further result sets, + // the Rows are closed automatically and it will suffice to check the + // result of Err. Close is idempotent and does not affect the result of Err. + Close() error + + // Next prepares the next result row for reading with the Scan method. It + // returns true on success, or false if there is no next result row or an error + // happened while preparing it. Err should be consulted to distinguish between + // the two cases. + // + // Every call to Scan, even the first one, must be preceded by a call to Next. + Next() bool + + // Err returns the error, if any, that was encountered during iteration. + // Err may be called after an explicit or implicit Close. + Err() error +} + +func RowsWithError(err error) Rows { + return rowsWithError{err} +} + +type rowsWithError struct { + err error +} + +func (r rowsWithError) Columns() ([]string, error) { return nil, r.err } +func (r rowsWithError) Scan(dest ...any) error { return r.err } +func (r rowsWithError) Close() error { return nil } +func (r rowsWithError) Next() bool { return false } +func (r rowsWithError) Err() error { return r.err } diff --git a/rowscanner.go b/rowscanner.go deleted file mode 100644 index 0507364..0000000 --- a/rowscanner.go +++ /dev/null @@ -1,24 +0,0 @@ -package sqldb - -// RowScanner scans the values from a single row. -type RowScanner interface { - // Scan values of a row into dest variables, which must be passed as pointers. - Scan(dest ...any) error - - // ScanStruct scans values of a row into a dest struct which must be passed as pointer. - ScanStruct(dest any) error - - // ScanValues returns the values of a row exactly how they are - // passed from the database driver to an sql.Scanner. - // Byte slices will be copied. - ScanValues() ([]any, error) - - // ScanStrings scans the values of a row as strings. - // Byte slices will be interpreted as strings, - // nil (SQL NULL) will be converted to an empty string, - // all other types are converted with fmt.Sprint(src). - ScanStrings() ([]string, error) - - // Columns returns the column names. - Columns() ([]string, error) -} diff --git a/rowsscanner.go b/rowsscanner.go deleted file mode 100644 index c318deb..0000000 --- a/rowsscanner.go +++ /dev/null @@ -1,45 +0,0 @@ -package sqldb - -// RowsScanner scans the values from multiple rows. -type RowsScanner interface { - // ScanSlice scans one value per row into one slice element of dest. - // dest must be a pointer to a slice with a row value compatible element type. - // In case of zero rows, dest will be set to nil and no error will be returned. - // In case of an error, dest will not be modified. - // It is an error to query more than one column. - ScanSlice(dest any) error - - // ScanStructSlice scans every row into the struct fields of dest slice elements. - // dest must be a pointer to a slice of structs or struct pointers. - // In case of zero rows, dest will be set to nil and no error will be returned. - // In case of an error, dest will not be modified. - // Every mapped struct field must have a corresponding column in the query results. - ScanStructSlice(dest any) error - - // ScanAllRowsAsStrings scans the values of all rows as strings. - // Byte slices will be interpreted as strings, - // nil (SQL NULL) will be converted to an empty string, - // all other types are converted with fmt.Sprint. - // If true is passed for headerRow, then a row - // with the column names will be prepended. - ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) - - // Columns returns the column names. - Columns() ([]string, error) - - // ForEachRow will call the passed callback with a RowScanner for every row. - // In case of zero rows, no error will be returned. - ForEachRow(callback func(RowScanner) error) error - - // ForEachRowCall will call the passed callback with scanned values or a struct for every row. - // If the callback function has a single struct or struct pointer argument, - // then RowScanner.ScanStruct will be used per row, - // else RowScanner.Scan will be used for all arguments of the callback. - // If the function has a context.Context as first argument, - // then the context of the query call will be passed on. - // The callback can have no result or a single error result value. - // If a non nil error is returned from the callback, then this error - // is returned immediately by this function without scanning further rows. - // In case of zero rows, no error will be returned. - ForEachRowCall(callback any) error -} diff --git a/rowvalues.go b/rowvalues.go new file mode 100644 index 0000000..49d6ebf --- /dev/null +++ b/rowvalues.go @@ -0,0 +1,11 @@ +package sqldb + +type RowValues interface { + Columns() []string + RowValues() ([]any, error) +} + +type RowPointers interface { + Columns() []string + RowPointers() ([]any, error) +} diff --git a/scan.go b/scan.go index 3f958b0..12dd5d7 100644 --- a/scan.go +++ b/scan.go @@ -9,6 +9,102 @@ import ( "time" ) +// ScanValues returns the values of a row exactly how they are +// passed from the database driver to an sql.Scanner. +// Byte slices will be copied. +func ScanValues(src Row) ([]any, error) { + cols, err := src.Columns() + if err != nil { + return nil, err + } + var ( + anys = make([]AnyValue, len(cols)) + result = make([]any, len(cols)) + ) + // result elements hold pointer to AnyValue for scanning + for i := range result { + result[i] = &anys[i] + } + err = src.Scan(result...) + if err != nil { + return nil, err + } + // don't return pointers to AnyValue + // but what internal value has been scanned + for i := range result { + result[i] = anys[i].Val + } + return result, nil +} + +// ScanStrings scans the values of a row as strings. +// Byte slices will be interpreted as strings, +// nil (SQL NULL) will be converted to an empty string, +// all other types are converted with fmt.Sprint. +func ScanStrings(src Row) ([]string, error) { + cols, err := src.Columns() + if err != nil { + return nil, err + } + var ( + result = make([]string, len(cols)) + resultPtrs = make([]any, len(cols)) + ) + for i := range resultPtrs { + resultPtrs[i] = (*StringScannable)(&result[i]) + } + err = src.Scan(resultPtrs...) + if err != nil { + return nil, err + } + return result, nil +} + +func ScanStruct(srcRow Row, destStruct any, mapper StructFieldMapper) error { + v := reflect.ValueOf(destStruct) + for v.Kind() == reflect.Pointer && !v.IsNil() { + v = v.Elem() + } + + var ( + setDestStructPtr = false + destStructPtr reflect.Value + newStructPtr reflect.Value + ) + if v.Kind() == reflect.Pointer && v.IsNil() && v.CanSet() { + // Got a nil pointer that we can set with a newly allocated struct + setDestStructPtr = true + destStructPtr = v + newStructPtr = reflect.New(v.Type().Elem()) + // Continue with the newly allocated struct + v = newStructPtr.Elem() + } + if v.Kind() != reflect.Struct { + return fmt.Errorf("ScanStruct: expected struct but got %T", destStruct) + } + + columns, err := srcRow.Columns() + if err != nil { + return err + } + + fieldPointers, err := MapStructFieldPointersForColumns(mapper, v, columns) + if err != nil { + return fmt.Errorf("ScanStruct: %w", err) + } + + err = srcRow.Scan(fieldPointers...) + if err != nil { + return err + } + + if setDestStructPtr { + destStructPtr.Set(newStructPtr) + } + + return nil +} + // ScanDriverValue scans a driver.Value into destPtr. func ScanDriverValue(destPtr any, value driver.Value) error { if destPtr == nil { @@ -20,7 +116,7 @@ func ScanDriverValue(destPtr any, value driver.Value) error { } dest := reflect.ValueOf(destPtr) - if dest.Kind() != reflect.Ptr { + if dest.Kind() != reflect.Pointer { return fmt.Errorf("can't scan non-pointer %s", dest.Type()) } dest = dest.Elem() @@ -100,7 +196,7 @@ func ScanDriverValue(destPtr any, value driver.Value) error { return nil } switch dest.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Map: + case reflect.Pointer, reflect.Slice, reflect.Map: dest.SetZero() return nil } diff --git a/serializedtransaction.go b/serializedtransaction.go new file mode 100644 index 0000000..afeadf2 --- /dev/null +++ b/serializedtransaction.go @@ -0,0 +1,117 @@ +package sqldb + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "sync/atomic" +) + +var ( + // Number of retries used for a SerializedTransaction + // before it fails + SerializedTransactionRetries = 10 + + serializedTransactionCtxKey int + + savepointCount atomic.Uint64 +) + +// SerializedTransaction executes txFunc "serially" within a database transaction that is passed in to txFunc via the context. +// Use db.ContextConnection(ctx) to get the transaction connection within txFunc. +// Transaction returns all errors from txFunc or transaction commit errors happening after txFunc. +// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx Connection +// and no parentConn.Begin, Commit, or Rollback calls will occour within this Transaction call. +// Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. +// Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. +// +// Serialized transactions are typically necessary when an insert depends on a previous select within +// the transaction, but that pre-insert select can't lock the table like it's possible with SELECT FOR UPDATE. +// During transaction execution, the isolation level "Serializable" is set. This does not mean +// that the transaction will be run in series. On the contrary, it actually means that Postgres will +// track read/write dependencies and will report an error in case other concurrent transactions +// have altered the results of the statements within this transaction. If no serialisation is possible, +// raw Postgres error will be: +// ``` +// ERROR: could not serialize access due to read/write dependencies among transactions +// HINT: The transaction might succeed if retried. +// ``` +// or +// ``` +// ERROR: could not serialize access due to concurrent update +// HINT: The transaction might succeed if retried. +// ``` +// In this case, retry the whole transaction (as Postgres hints). This works simply +// because if you run the transaction for the second (or Nth) time, the queries will +// yield different results therefore altering the end result. +// +// SerializedTransaction calls can be nested, in which case nested calls just execute the +// txFunc within the parent's serialized transaction. +// It's not valid to nest a SerializedTransaction within a normal Transaction function +// because in this case serialization retries can't be delegated up to the +// partent transaction that doesn't know anything about serialization. +// +// Because of the retryable nature, please be careful with the size of the transaction and the retry cost. +func SerializedTransaction(ctx context.Context, txFunc func(context.Context) error) error { + // Pass nested serialized transactions through + if ContextConnection(ctx).IsTransaction() { + if ctx.Value(&serializedTransactionCtxKey) == nil { + return errors.New("SerializedTransaction called from within a non-serialized transaction") + } + return txFunc(ctx) + } + + // Add value to context to check for nested serialized transactions + ctx = context.WithValue(ctx, &serializedTransactionCtxKey, struct{}{}) + + opts := sql.TxOptions{Isolation: sql.LevelSerializable} + for i := 0; i < SerializedTransactionRetries; i++ { + err := TransactionOpts(ctx, &opts, txFunc) + if err == nil || !strings.Contains(err.Error(), "could not serialize access") { + return err // nil or err + } + } + + return errors.New("SerializedTransaction retried too many times") +} + +// TransactionSavepoint executes txFunc within a database transaction or uses savepoints for rollback. +// If the passed context already has a database transaction connection, +// then a savepoint with a random name is created before the execution of txFunc. +// If txFunc returns an error, then the transaction is rolled back to the savepoint +// but the transaction from the context is not rolled back. +// If the passed context does not have a database transaction connection, +// then Transaction(ctx, txFunc) is called without savepoints. +// Use db.ContextConnection(ctx) to get the transaction connection within txFunc. +// TransactionSavepoint returns all errors from txFunc, transaction, savepoint, and rollback errors. +// Panics from txFunc are not recovered to rollback to the savepoint, +// they should behandled by the parent Transaction function. +func TransactionSavepoint(ctx context.Context, txFunc func(context.Context) error) error { + conn := ContextConnection(ctx) + if !conn.IsTransaction() { + // If not already in a transaction, then execute txFunc + // within a as transaction instead of using savepoints: + return Transaction(ctx, txFunc) + } + + savepoint := fmt.Sprintf("SP%d", savepointCount.Add(1)) + + err := conn.Exec(ctx, "SAVEPOINT "+savepoint) + if err != nil { + return err + } + + err = txFunc(ctx) + if err != nil { + e := conn.Exec(ctx, "ROLLBACK TO "+savepoint) + if e != nil && !errors.Is(e, sql.ErrTxDone) { + // Double error situation, wrap err with e so it doesn't get lost + err = fmt.Errorf("TransactionSavepoint error (%s) from rollback after error: %w", e, err) + } + return err + } + + return conn.Exec(ctx, "RELEASE SAVEPOINT "+savepoint) +} diff --git a/serializedtransaction_test.go b/serializedtransaction_test.go new file mode 100644 index 0000000..ab60d21 --- /dev/null +++ b/serializedtransaction_test.go @@ -0,0 +1,109 @@ +package sqldb + +// func TestSerializedTransaction(t *testing.T) { +// globalConn = mockconn.New(context.Background(), os.Stdout, nil) + +// expectSerialized := func(ctx context.Context) error { +// if !ContextConnection(ctx).IsTransaction() { +// panic("not in transaction") +// } +// if ctx.Value(&serializedTransactionCtxKey) == nil { +// panic("no SerializedTransaction") +// } +// return nil +// } + +// expectSerializedWithError := func(ctx context.Context) error { +// if !ContextConnection(ctx).IsTransaction() { +// panic("not in transaction") +// } +// if ctx.Value(&serializedTransactionCtxKey) == nil { +// panic("no SerializedTransaction") +// } +// return errors.New("expected error") +// } + +// nestedSerializedTransaction := func(ctx context.Context) error { +// return SerializedTransaction(ctx, expectSerialized) +// } + +// okNestedTransaction := func(ctx context.Context) error { +// return Transaction(ctx, nestedSerializedTransaction) +// } + +// type args struct { +// ctx context.Context +// txFunc func(context.Context) error +// } +// tests := []struct { +// name string +// args args +// wantErr bool +// }{ +// {name: "flat call", args: args{ctx: context.Background(), txFunc: expectSerialized}, wantErr: false}, +// {name: "expect error", args: args{ctx: context.Background(), txFunc: expectSerializedWithError}, wantErr: true}, +// {name: "nested call", args: args{ctx: context.Background(), txFunc: nestedSerializedTransaction}, wantErr: false}, +// {name: "nested tx call", args: args{ctx: context.Background(), txFunc: okNestedTransaction}, wantErr: false}, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// if err := SerializedTransaction(tt.args.ctx, tt.args.txFunc); (err != nil) != tt.wantErr { +// t.Errorf("SerializedTransaction() error = %v, wantErr %v", err, tt.wantErr) +// } +// }) +// } +// } + +// func TestTransaction(t *testing.T) { +// globalConn = mockconn.New(context.Background(), os.Stdout, nil) + +// expectNonSerialized := func(ctx context.Context) error { +// if !ContextConnection(ctx).IsTransaction() { +// panic("not in transaction") +// } +// if ctx.Value(&serializedTransactionCtxKey) != nil { +// panic("SerializedTransaction") +// } +// return nil +// } + +// expectNonSerializedWithError := func(ctx context.Context) error { +// if !ContextConnection(ctx).IsTransaction() { +// panic("not in transaction") +// } +// if ctx.Value(&serializedTransactionCtxKey) != nil { +// panic("SerializedTransaction") +// } +// return errors.New("expected error") +// } + +// nestedTransaction := func(ctx context.Context) error { +// return Transaction(ctx, expectNonSerialized) +// } + +// nestedSerializedTransaction := func(ctx context.Context) error { +// return SerializedTransaction(ctx, nestedTransaction) +// } + +// type args struct { +// ctx context.Context +// txFunc func(context.Context) error +// } +// tests := []struct { +// name string +// args args +// wantErr bool +// }{ +// {name: "flat call", args: args{ctx: context.Background(), txFunc: expectNonSerialized}, wantErr: false}, +// {name: "expected error", args: args{ctx: context.Background(), txFunc: expectNonSerializedWithError}, wantErr: true}, +// {name: "nested call", args: args{ctx: context.Background(), txFunc: nestedTransaction}, wantErr: false}, +// {name: "nested serialized", args: args{ctx: context.Background(), txFunc: nestedSerializedTransaction}, wantErr: true}, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// if err := Transaction(tt.args.ctx, tt.args.txFunc); (err != nil) != tt.wantErr { +// t.Errorf("Transaction() error = %v, wantErr %v", err, tt.wantErr) +// } +// }) +// } +// } diff --git a/structfieldmapping.go b/structfieldmapping.go index 533c972..e648a55 100644 --- a/structfieldmapping.go +++ b/structfieldmapping.go @@ -1,9 +1,11 @@ package sqldb import ( + "errors" "fmt" "reflect" "strings" + "sync" "github.com/domonda/go-types/strutil" ) @@ -38,8 +40,8 @@ type StructFieldMapper interface { // MapStructField returns the column name for a reflected struct field // and flags for special column properies. // If false is returned for use then the field is not mapped. - // An empty name and true for use indicates an embedded struct - // field whose fields should be recursively mapped. + // An empty string for column and true for use indicates an + // embedded struct field whose fields should be mapped recursively. MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) } @@ -85,14 +87,16 @@ type TaggedStructFieldMapping struct { func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) { if field.Anonymous { + if field.Type == typeOfTableName { + table, _ = field.Tag.Lookup(m.NameTag) + return table, "", 0, false + } column, hasTag := field.Tag.Lookup(m.NameTag) if !hasTag { // Embedded struct fields are ok if not tagged with IgnoreName return "", "", 0, true } - if i := strings.IndexByte(column, ','); i != -1 { - column = column[:i] - } + column, _, _ = strings.Cut(column, ",") // Embedded struct fields are ok if not tagged with IgnoreName return "", "", 0, column != m.Ignore } @@ -147,4 +151,276 @@ func IgnoreStructField(string) string { return "" } // before every new upper case character in s. // Whitespace, symbol, and punctuation characters // will be replace by '_'. -var ToSnakeCase = strutil.ToSnakeCase +func ToSnakeCase(s string) string { + return strutil.ToSnakeCase(s) +} + +type MappedStructField struct { + Field reflect.StructField + Table string + Column string + Flags FieldFlag +} + +type MappedStruct struct { + Type reflect.Type + Table string + Fields []MappedStructField + Columns []string + ColumnFields map[string]*MappedStructField +} + +func (m *MappedStruct) StructFieldValues(structVal reflect.Value) (values []any, err error) { + if structVal.Type() != m.Type { + return nil, fmt.Errorf("can't return StructFieldValues of type %s for MappedStruct type %s", structVal.Type(), m.Type) + } + values = make([]any, len(m.Fields)) + for i, m := range m.Fields { + values[i] = structVal.FieldByIndex(m.Field.Index).Interface() + } + return values, nil +} + +func (m *MappedStruct) StructFieldPointers(structVal reflect.Value) (values []any, err error) { + if !structVal.CanAddr() { + return nil, errors.New("struct can't be addressed") + } + if structVal.Type() != m.Type { + return nil, fmt.Errorf("can't return StructFieldPointers of type %s for MappedStruct type %s", structVal.Type(), m.Type) + } + values = make([]any, len(m.Fields)) + for i, m := range m.Fields { + values[i] = structVal.FieldByIndex(m.Field.Index).Addr().Interface() + } + return values, nil +} + +var ( + mappedStructTypeCache = make(map[StructFieldMapper]map[reflect.Type]*MappedStruct) + mappedStructTypeCacheMtx sync.Mutex +) + +func mapStructType(mapper StructFieldMapper, structType reflect.Type, mapped *MappedStruct) error { + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + table, column, flags, use := mapper.MapStructField(field) + + if table != "" { + if mapped.Table != "" && table != mapped.Table { + return fmt.Errorf("conflicting tables %s and %s found in struct %s", mapped.Table, table, structType) + } + mapped.Table = table + } + + if !use { + continue + } + + if column == "" { + // Embedded struct field + t := field.Type + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + err := mapStructType(mapper, t, mapped) + if err != nil { + return err + } + continue + } + + if _, exists := mapped.ColumnFields[column]; exists { + return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, field.Name, structType) + } + + mapped.Fields = append(mapped.Fields, MappedStructField{ + Field: field, + Table: table, + Column: column, + Flags: flags, + }) + mapped.Columns = append(mapped.Columns, column) + mapped.ColumnFields[column] = &mapped.Fields[len(mapped.Fields)-1] + } + return nil +} + +func MapStructType(mapper StructFieldMapper, structType reflect.Type) (*MappedStruct, error) { + if structType.Kind() == reflect.Pointer { + structType = structType.Elem() + } + if structType.Kind() != reflect.Struct { + return nil, fmt.Errorf("MapStructType called with non struct type %s", structType) + } + + mappedStructTypeCacheMtx.Lock() + defer mappedStructTypeCacheMtx.Unlock() + + mapped := mappedStructTypeCache[mapper][structType] + if mapped != nil { + return mapped, nil + } + + mapped = &MappedStruct{Type: structType} + err := mapStructType(mapper, structType, mapped) + if err != nil { + return nil, err + } + if mappedStructTypeCache[mapper] == nil { + mappedStructTypeCache[mapper] = make(map[reflect.Type]*MappedStruct) + } + mappedStructTypeCache[mapper][structType] = mapped + return mapped, nil +} + +func MapStruct(mapper StructFieldMapper, s any) (mapped *MappedStruct, structVal reflect.Value, err error) { + structVal = reflect.ValueOf(s) + for structVal.Kind() == reflect.Pointer && !structVal.IsNil() { + structVal = structVal.Elem() + } + if structVal.Kind() != reflect.Struct { + return nil, reflect.Value{}, fmt.Errorf("expected struct but got %T", s) + } + mapped, err = MapStructType(mapper, structVal.Type()) + return mapped, structVal, err +} + +func MapStructFieldValues(mapper StructFieldMapper, s any) (columns []string, values []any, table string, err error) { + mapped, structVal, err := MapStruct(mapper, s) + if err != nil { + return nil, nil, "", err + } + values, err = mapped.StructFieldValues(structVal) + if err != nil { + return nil, nil, "", err + } + return mapped.Columns, values, mapped.Table, nil +} + +func MapStructFieldPointers(mapper StructFieldMapper, s any) (columns []string, pointers []any, table string, err error) { + mapped, structVal, err := MapStruct(mapper, s) + if err != nil { + return nil, nil, "", err + } + pointers, err = mapped.StructFieldPointers(structVal) + if err != nil { + return nil, nil, "", err + } + return mapped.Columns, pointers, mapped.Table, nil +} + +func MapStructFieldPointersForColumns(mapper StructFieldMapper, s any, columns []string) (pointers []any, err error) { + mapped, structVal, err := MapStruct(mapper, s) + if err != nil { + return nil, err + } + if !structVal.CanAddr() { + return nil, errors.New("struct can't be addressed") + } + // if len(mapped.Fields) > len(columns) { + // // TODO optional error handling + // } + pointers = make([]any, len(columns)) + for i, column := range columns { + m, ok := mapped.ColumnFields[column] + if !ok { + // TODO optional error handling + pointers[i] = new(AnyValue) + continue + } + pointers[i] = structVal.FieldByIndex(m.Field.Index).Addr().Interface() + } + return pointers, nil +} + +func pkColumnsOfStruct(mapper StructFieldMapper, t reflect.Type) (table string, columns []string, err error) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldTable, column, flags, ok := mapper.MapStructField(field) + if !ok { + continue + } + if fieldTable != "" && fieldTable != table { + if table != "" { + return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) + } + table = fieldTable + } + + if column == "" { + fieldTable, columnsEmbed, err := pkColumnsOfStruct(mapper, field.Type) + if err != nil { + return "", nil, err + } + if fieldTable != "" && fieldTable != table { + if table != "" { + return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) + } + table = fieldTable + } + columns = append(columns, columnsEmbed...) + } else if flags.PrimaryKey() { + // if err = conn.ValidateColumnName(column); err != nil { + // return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) + // } + columns = append(columns, column) + } + } + return table, columns, nil +} + +// func MapStructFieldPointers(mapper StructFieldMapper, strct any) (colFieldPtrs map[string]any, table string, err error) { +// v := reflect.ValueOf(strct) +// for v.Kind() == reflect.Pointer && !v.IsNil() { +// v = v.Elem() +// } +// if v.Kind() != reflect.Struct { +// return nil, "", fmt.Errorf("expected struct but got %T", strct) +// } +// if !v.CanAddr() { +// return nil, "", errors.New("struct can't be addressed") +// } + +// mapped, err := getOrCreateMappedStruct(mapper, v.Type()) +// if err != nil { +// return nil, "", err +// } + +// colFieldPtrs = make(map[string]any, len(mapped.Fields)) +// for column, mapped := range mapped.ColumnFields { +// field, err := v.FieldByIndexErr(mapped.Field.Index) +// if err != nil { +// return nil, "", err +// } +// colFieldPtrs[column] = field.Addr().Interface() +// } +// return colFieldPtrs, mapped.Table, nil +// } + +// func MapStructFieldColumnPointers(mapper StructFieldMapper, structVal any, columns []string) (ptrs []any, table string, colsWithoutField, fieldsWithoutCol []string, err error) { +// v := reflect.ValueOf(structVal) +// for v.Kind() == reflect.Pointer && !v.IsNil() { +// v = v.Elem() +// } +// if v.Kind() != reflect.Struct { +// return nil, "", fmt.Errorf("expected struct but got %T", structVal) +// } +// if !v.CanAddr() { +// return nil, "", errors.New("struct can't be addressed") +// } + +// mapped, err := MapStructType(mapper, v.Type()) +// if err != nil { +// return nil, "", err +// } + +// colFieldPtrs = make(map[string]any, len(mapped.Fields)) +// for column, mapped := range mapped.ColumnFields { +// field, err := v.FieldByIndexErr(mapped.Field.Index) +// if err != nil { +// return nil, "", err +// } +// colFieldPtrs[column] = field.Addr().Interface() +// } +// return colFieldPtrs, mapped.Table, nil +// } diff --git a/structfieldmapping_test.go b/structfieldmapping_test.go index 4fcece5..545ed47 100644 --- a/structfieldmapping_test.go +++ b/structfieldmapping_test.go @@ -37,6 +37,8 @@ func TestTaggedStructFieldMapping_StructFieldName(t *testing.T) { } type AnonymousEmbedded struct{} var s struct { + TableName `db:"public.my_table"` // Field(0) + Index int `db:"index,pk=public.my_table"` // Field(0) IndexB int `db:"index_b,pk"` // Field(1) Str string `db:"named_str"` // Field(2) @@ -56,33 +58,34 @@ func TestTaggedStructFieldMapping_StructFieldName(t *testing.T) { wantTable string wantColumn string wantFlags FieldFlag - wantOk bool + wantUse bool }{ - {name: "index", structField: st.Field(0), wantTable: "public.my_table", wantColumn: "index", wantFlags: FieldFlagPrimaryKey, wantOk: true}, - {name: "index_b", structField: st.Field(1), wantTable: "", wantColumn: "index_b", wantFlags: FieldFlagPrimaryKey, wantOk: true}, - {name: "named_str", structField: st.Field(2), wantColumn: "named_str", wantFlags: 0, wantOk: true}, - {name: "read_only", structField: st.Field(3), wantColumn: "read_only", wantFlags: FieldFlagReadOnly, wantOk: true}, - {name: "untagged_field", structField: st.Field(4), wantColumn: "untagged_field", wantFlags: 0, wantOk: true}, - {name: "ignore", structField: st.Field(5), wantColumn: "", wantFlags: 0, wantOk: false}, - {name: "pk_read_only", structField: st.Field(6), wantColumn: "pk_read_only", wantFlags: FieldFlagPrimaryKey | FieldFlagReadOnly, wantOk: true}, - {name: "no_flag", structField: st.Field(7), wantColumn: "no_flag", wantFlags: 0, wantOk: true}, - {name: "malformed_flags", structField: st.Field(8), wantColumn: "malformed_flags", wantFlags: FieldFlagReadOnly, wantOk: true}, - {name: "Embedded", structField: st.Field(9), wantColumn: "", wantFlags: 0, wantOk: true}, + {name: "TableName", structField: st.Field(0), wantTable: "public.my_table", wantColumn: "", wantFlags: 0, wantUse: false}, + {name: "index", structField: st.Field(1), wantTable: "public.my_table", wantColumn: "index", wantFlags: FieldFlagPrimaryKey, wantUse: true}, + {name: "index_b", structField: st.Field(2), wantTable: "", wantColumn: "index_b", wantFlags: FieldFlagPrimaryKey, wantUse: true}, + {name: "named_str", structField: st.Field(3), wantColumn: "named_str", wantFlags: 0, wantUse: true}, + {name: "read_only", structField: st.Field(4), wantColumn: "read_only", wantFlags: FieldFlagReadOnly, wantUse: true}, + {name: "untagged_field", structField: st.Field(5), wantColumn: "untagged_field", wantFlags: 0, wantUse: true}, + {name: "ignore", structField: st.Field(6), wantColumn: "", wantFlags: 0, wantUse: false}, + {name: "pk_read_only", structField: st.Field(7), wantColumn: "pk_read_only", wantFlags: FieldFlagPrimaryKey | FieldFlagReadOnly, wantUse: true}, + {name: "no_flag", structField: st.Field(8), wantColumn: "no_flag", wantFlags: 0, wantUse: true}, + {name: "malformed_flags", structField: st.Field(9), wantColumn: "malformed_flags", wantFlags: FieldFlagReadOnly, wantUse: true}, + {name: "AnonymousEmbedded", structField: st.Field(10), wantColumn: "", wantFlags: 0, wantUse: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotTable, gotColumn, gotFlags, gotOk := naming.MapStructField(tt.structField) + gotTable, gotColumn, gotFlags, gotUse := naming.MapStructField(tt.structField) if gotTable != tt.wantTable { - t.Errorf("TaggedStructFieldMapping.MapStructField(%q) gotTable = %q, want %q", tt.structField.Name, gotTable, tt.wantTable) + t.Errorf("TaggedStructFieldMapping.MapStructField(%#v) gotTable = %#v, want %#v", tt.structField.Name, gotTable, tt.wantTable) } if gotColumn != tt.wantColumn { - t.Errorf("TaggedStructFieldMapping.MapStructField(%q) gotColumn = %q, want %q", tt.structField.Name, gotColumn, tt.wantColumn) + t.Errorf("TaggedStructFieldMapping.MapStructField(%#v) gotColumn = %#v, want %#v", tt.structField.Name, gotColumn, tt.wantColumn) } if gotFlags != tt.wantFlags { - t.Errorf("TaggedStructFieldMapping.MapStructField(%q) gotFlags = %v, want %v", tt.structField.Name, gotFlags, tt.wantFlags) + t.Errorf("TaggedStructFieldMapping.MapStructField(%#v) gotFlags =%#v, want %#v", tt.structField.Name, gotFlags, tt.wantFlags) } - if gotOk != tt.wantOk { - t.Errorf("TaggedStructFieldMapping.MapStructField(%q) gotOk = %v, want %v", tt.structField.Name, gotOk, tt.wantOk) + if gotUse != tt.wantUse { + t.Errorf("TaggedStructFieldMapping.MapStructField(%#v) gotOk = %#v, want %#v", tt.structField.Name, gotUse, tt.wantUse) } }) } diff --git a/tablename.go b/tablename.go new file mode 100644 index 0000000..bdcf9de --- /dev/null +++ b/tablename.go @@ -0,0 +1,9 @@ +package sqldb + +type TableName struct{} + +func (TableName) TableNameMarker() {} + +type RowWithTableName interface { + TableNameMarker() +} diff --git a/transaction.go b/transaction.go index 9cae3e3..80360a1 100644 --- a/transaction.go +++ b/transaction.go @@ -1,52 +1,117 @@ package sqldb import ( + "context" "database/sql" "errors" "fmt" "sync/atomic" ) -var txCounter atomic.Uint64 +var txCount atomic.Uint64 -// NextTransactionNo returns the next globally unique number +// NextTxNumber returns the next globally unique number // for a new transaction in a threadsafe way. // -// Use Connection.TransactionNo() to get the number +// Use TxConnection.TxNumber() to get the number // from a transaction connection. -func NextTransactionNo() uint64 { - return txCounter.Add(1) +func NextTxNumber() uint64 { + return txCount.Add(1) } -// Transaction executes txFunc within a database transaction that is passed in to txFunc as tx Connection. -// Transaction returns all errors from txFunc or transaction commit errors happening after txFunc. +// ToTxConnection returns the passed Connection +// as TxConnection if implemented +// or else an ErrorConnection with an error +// that wraps errors.ErrUnsupported. +func ToTxConnection(conn Connection) TxConnection { + tx, err := AsTxConnection(conn) + if err != nil { + return ErrorConnection(err) + } + return tx +} + +func AsTxConnection(conn Connection) (TxConnection, error) { + if tx, ok := conn.(TxConnection); ok { + return tx, nil + } + return nil, fmt.Errorf("%w: %s does not implement TxConnection", errors.ErrUnsupported, conn) +} + +// TxConnection is a connection that supports transactions. +// +// This does not mean that every TxConnection represents +// a separate connection for an active transaction, +// only if it was returned for a new transaction by +// the Begin method. +type TxConnection interface { + Connection + + DefaultIsolationLevel() sql.IsolationLevel + + // TxNumber returns the globally unique number of the transaction + // or zero if the connection is not a transaction. + // Implementations should use the package function NextTxNumber + // to aquire a new number in a threadsafe way. + TxNumber() uint64 + + // TxOptions returns the sql.TxOptions of the + // current transaction and true as second result value, + // or false if the connection is not a transaction. + TxOptions() (*sql.TxOptions, bool) + + // Begin a new transaction. + // If the connection is already a transaction then a brand + // new transaction will begin on the parent's connection. + // The passed no will be returnd from the transaction's + // Connection.TxNumber method. + // Implementations should use the package function NextTxNumber + // to aquire a new number in a threadsafe way. + Begin(ctx context.Context, opts *sql.TxOptions, no uint64) (TxConnection, error) + + // Commit the current transaction. + // Returns ErrNotWithinTransaction if the connection + // is not within a transaction. + Commit() error + + // Rollback the current transaction. + // Returns ErrNotWithinTransaction if the connection + // is not within a transaction. + Rollback() error +} + +// TransactionOpts executes txFunc within a database transaction with sql.TxOptions that is passed in to txFunc via the context. +// Use db.ContextConnection(ctx) to get the transaction connection within txFunc. +// TransactionOpts returns all errors from txFunc or transaction commit errors happening after txFunc. // If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx Connection -// and no parentConn.Begin, Commit, or Rollback calls will occour within this Transaction call. -// An error is returned, if the requested transaction options passed via opts -// are stricter than the options of the parent transaction. +// and no parentConn.Begin, Commit, or Rollback calls will occour within this TransactionOpts call. // Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. -func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { - if parentOpts, parentIsTx := parentConn.TransactionOptions(); parentIsTx { - err = CheckTxOptionsCompatibility(parentOpts, opts, parentConn.Config().DefaultIsolationLevel) +func TransactionOpts(ctx context.Context, opts *sql.TxOptions, txFunc func(context.Context) error) (err error) { + // Don't shadow err result var! + txConnection, e := AsTxConnection(ContextConnection(ctx)) + if e != nil { + return e + } + + if parentOpts, isTransation := txConnection.TxOptions(); isTransation { + // txConn is already a transaction connection + // so don't begin a new transaction, + // just execute txFunc within the current transaction + // if the TxOptions are compatible + err = CheckTxOptionsCompatibility(parentOpts, opts, txConnection.DefaultIsolationLevel()) if err != nil { return err } - return txFunc(parentConn) + return txFunc(ContextWithConnection(ctx, txConnection)) } - return IsolatedTransaction(parentConn, opts, txFunc) -} -// IsolatedTransaction executes txFunc within a database transaction that is passed in to txFunc as tx Connection. -// IsolatedTransaction returns all errors from txFunc or transaction commit errors happening after txFunc. -// If parentConn is already a transaction, a brand new transaction will begin on the parent's connection. -// Errors and panics from txFunc will rollback the transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. -func IsolatedTransaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { - txNo := NextTransactionNo() - tx, e := parentConn.Begin(opts, txNo) + // Execute txFunc within new transaction + txNumber := NextTxNumber() + // Don't shadow err result var! + tx, e := txConnection.Begin(ctx, opts, txNumber) if e != nil { - return fmt.Errorf("Transaction %d Begin error: %w", txNo, e) + return fmt.Errorf("Transaction %d Begin error: %w", txNumber, e) } defer func() { @@ -55,7 +120,7 @@ func IsolatedTransaction(parentConn Connection, opts *sql.TxOptions, txFunc func e := tx.Rollback() if e != nil && !errors.Is(e, sql.ErrTxDone) { // Double error situation, log e so it doesn't get lost - ErrLogger.Printf("Transaction %d error (%s) from rollback after panic: %+v", txNo, e, r) + ErrLogger.Printf("Transaction %d error (%s) from rollback after panic: %+v", txNumber, e, r) } panic(r) // re-throw panic after Rollback } @@ -65,7 +130,7 @@ func IsolatedTransaction(parentConn Connection, opts *sql.TxOptions, txFunc func e := tx.Rollback() if e != nil && !errors.Is(e, sql.ErrTxDone) { // Double error situation, wrap err with e so it doesn't get lost - err = fmt.Errorf("Transaction %d error (%s) from rollback after error: %w", txNo, e, err) + err = fmt.Errorf("Transaction %d error (%s) from rollback after error: %w", txNumber, e, err) } return } @@ -73,11 +138,71 @@ func IsolatedTransaction(parentConn Connection, opts *sql.TxOptions, txFunc func e := tx.Commit() if e != nil { // Set Commit error as function return value - err = fmt.Errorf("Transaction %d Commit error: %w", txNo, e) + err = fmt.Errorf("Transaction %d Commit error: %w", txNumber, e) } }() - return txFunc(tx) + return txFunc(ContextWithConnection(ctx, tx)) +} + +func Transaction(ctx context.Context, txFunc func(context.Context) error) (err error) { + return TransactionOpts(ctx, nil, txFunc) +} + +// TransactionReadOnly executes txFunc within a read-only database transaction that is passed in to txFunc via the context. +// Use db.ContextConnection(ctx) to get the transaction connection within txFunc. +// TransactionReadOnly returns all errors from txFunc or transaction commit errors happening after txFunc. +// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx Connection +// and no parentConn.Begin, Commit, or Rollback calls will occour within this TransactionReadOnly call. +// Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. +// Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. +func TransactionReadOnly(ctx context.Context, txFunc func(context.Context) error) error { + return TransactionOpts(ctx, &sql.TxOptions{ReadOnly: true}, txFunc) +} + +// DebugNoTransaction executes nonTxFunc without a database transaction. +// Useful to temporarely replace Transaction to debug the same code without using a transaction. +func DebugNoTransaction(ctx context.Context, nonTxFunc func(context.Context) error) error { + return nonTxFunc(ctx) +} + +// DebugNoTransactionOpts executes nonTxFunc without a database transaction. +// Useful to temporarely replace TransactionOpts to debug the same code without using a transaction. +func DebugNoTransactionOpts(ctx context.Context, opts *sql.TxOptions, nonTxFunc func(context.Context) error) error { + return nonTxFunc(ctx) +} + +// IsTransaction indicates if the connection from the context +// (or the global connection if the context has none) +// is a transaction. +func IsTransaction(ctx context.Context) bool { + return ContextConnection(ctx).IsTransaction() +} + +// ValidateWithinTransaction returns ErrNotWithinTransaction +// if the database connection from the context is not a transaction. +func ValidateWithinTransaction(ctx context.Context) error { + conn := ContextConnection(ctx) + if err := conn.Err(); err != nil { + return err + } + if !conn.IsTransaction() { + return ErrNotWithinTransaction + } + return nil +} + +// ValidateNotWithinTransaction returns ErrWithinTransaction +// if the database connection from the context is a transaction. +func ValidateNotWithinTransaction(ctx context.Context) error { + conn := ContextConnection(ctx) + if err := conn.Err(); err != nil { + return err + } + if conn.IsTransaction() { + return ErrWithinTransaction + } + return nil } // CheckTxOptionsCompatibility returns an error diff --git a/types.go b/types.go new file mode 100644 index 0000000..b0c455e --- /dev/null +++ b/types.go @@ -0,0 +1,19 @@ +package sqldb + +import ( + "context" + "database/sql" + "database/sql/driver" + "reflect" +) + +var ( + typeOfByte = reflect.TypeOf(byte(0)) + typeOfError = reflect.TypeOf((*error)(nil)).Elem() + typeOfByteSlice = reflect.TypeOf((*[]byte)(nil)).Elem() + typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem() + typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + typeOfDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + typeOfTableName = reflect.TypeOf(TableName{}) + // typeOfTime = reflect.TypeOf(time.Time{}) +) diff --git a/values.go b/values.go index f11e8bd..39876cc 100644 --- a/values.go +++ b/values.go @@ -1,23 +1,77 @@ package sqldb -import "sort" +import ( + "database/sql/driver" + "errors" + "reflect" + "sort" -// Values is a map from column names to values + "golang.org/x/exp/maps" +) + +// Values implements RowValues +var _ RowValues = Values(nil) + +// Values is a map from column names to values. +// It implements RowValues. type Values map[string]any -// Sorted returns the names and values from the Values map -// as separated slices sorted by name. -func (v Values) Sorted() (names []string, values []any) { - names = make([]string, 0, len(v)) - for name := range v { - names = append(names, name) +// Sorted returns the columns and values from the Values map +// as separated slices sorted by column. +func (v Values) Sorted() (colums []string, values []any) { + colums = v.Columns() + values = make([]any, len(v)) + for i, col := range colums { + values[i] = v[col] } - sort.Strings(names) + return colums, values +} - values = make([]any, len(v)) - for i, name := range names { - values[i] = v[name] +func (v Values) Columns() []string { + cols := maps.Keys(v) + sort.Strings(cols) + return cols +} + +func (v Values) RowValues() ([]any, error) { + values := make([]any, len(v)) + for i, col := range v.Columns() { + values[i] = v[col] } + return values, nil +} + +func mapKeysAndValues(v reflect.Value) (keys []string, values []any) { + k := v.MapKeys() + sort.Slice(k, func(i, j int) bool { + return k[i].String() < k[j].String() + }) - return names, values + keys = make([]string, len(k)) + for i, key := range k { + keys[i] = key.String() + } + + values = make([]any, len(k)) + for i, key := range k { + values[i] = v.MapIndex(key).Interface() + } + + return keys, values +} + +func convertValuesInPlace(values []any, converter driver.ValueConverter) error { + if converter == nil { + return nil + } + var err error + for i, value := range values { + v, e := converter.ConvertValue(value) + if e != nil { + err = errors.Join(err, e) + continue + } + values[i] = v + } + return err }