diff --git a/columnfilter.go b/columnfilter.go index a70ea89..cc829f3 100644 --- a/columnfilter.go +++ b/columnfilter.go @@ -2,20 +2,22 @@ package sqldb import ( "reflect" + + "github.com/domonda/go-sqldb/reflection" ) type ColumnFilter interface { - IgnoreColumn(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool + IgnoreColumn(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool } -type ColumnFilterFunc func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool +type ColumnFilterFunc func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool -func (f ColumnFilterFunc) IgnoreColumn(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +func (f ColumnFilterFunc) IgnoreColumn(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return f(name, flags, fieldType, fieldValue) } func IgnoreColumns(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, ignore := range names { if name == ignore { return true @@ -26,7 +28,7 @@ func IgnoreColumns(names ...string) ColumnFilter { } func OnlyColumns(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, include := range names { if name == include { return false @@ -37,7 +39,7 @@ func OnlyColumns(names ...string) ColumnFilter { } func IgnoreStructFields(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, ignore := range names { if fieldType.Name == ignore { return true @@ -48,7 +50,7 @@ func IgnoreStructFields(names ...string) ColumnFilter { } func OnlyStructFields(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { for _, include := range names { if fieldType.Name == include { return false @@ -58,32 +60,40 @@ func OnlyStructFields(names ...string) ColumnFilter { }) } -func IgnoreFlags(ignore FieldFlag) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +func IgnoreFlags(ignore reflection.StructFieldFlags) ColumnFilter { + return ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags&ignore != 0 }) } -var IgnoreDefault ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags.Default() +var IgnoreHasDefault ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return flags.HasDefault() }) -var IgnorePrimaryKey ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnorePrimaryKey ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags.PrimaryKey() }) -var IgnoreReadOnly ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreReadOnly ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return flags.ReadOnly() }) -var IgnoreNull ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreNull ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return IsNull(fieldValue) }) -var IgnoreNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { +var IgnoreNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { return IsNullOrZero(fieldValue) }) -var IgnoreNullOrZeroDefault ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags.Default() && IsNullOrZero(fieldValue) +var IgnoreHasDefaultNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return flags.HasDefault() && IsNullOrZero(fieldValue) }) + +type noColumnFilter struct{} + +func (noColumnFilter) IgnoreColumn(name string, flags reflection.StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { + return false +} + +var AllColumns noColumnFilter diff --git a/config.go b/config.go index 978225f..8694dce 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package sqldb import ( "context" "database/sql" + "errors" "fmt" "net/url" "time" @@ -11,20 +12,44 @@ import ( // Config for a connection. // For tips see https://www.alexedwards.net/blog/configuring-sqldb type Config struct { - Driver string `json:"driver"` - Host string `json:"host"` - Port uint16 `json:"port,omitempty"` - User string `json:"user,omitempty"` - Password string `json:"password,omitempty"` - Database string `json:"database"` - Extra map[string]string `json:"misc,omitempty"` - MaxOpenConns int `json:"maxOpenConns,omitempty"` - MaxIdleConns int `json:"maxIdleConns,omitempty"` - ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"` + Driver string `json:"driver"` + Host string `json:"host"` + Port uint16 `json:"port,omitempty"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + Database string `json:"database"` + Extra map[string]string `json:"misc,omitempty"` + MaxOpenConns int `json:"maxOpenConns,omitempty"` + MaxIdleConns int `json:"maxIdleConns,omitempty"` + ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"` + + // ValidateColumnName returns an error + // if the passed name is not valid for a + // column of the connection's database. + ValidateColumnName func(name string) error `json:"-"` + + // ParamPlaceholder returns a parameter value placeholder + // for the parameter with the passed zero based index + // specific to the database type of the connection. + ParamPlaceholderFormatter `json:"-"` + DefaultIsolationLevel sql.IsolationLevel `json:"-"` - Err error `json:"-"` + + // Err will be returned from Connection.Err() + Err error `json:"-"` } +// func (c *DBConnection) ValidateColumnName(name string) error { +// if name == "" { +// return errors.New("empty column name") +// } +// return nil +// } + +// func (c *DBConnection) ParamPlaceholder(index int) string { +// return fmt.Sprintf(":%d", index+1) +// } + // Validate returns Config.Err if it is not nil // or an error if the Config does not have // a Driver, Host, or Database. @@ -32,19 +57,25 @@ func (c *Config) Validate() error { if c.Err != nil { return c.Err } + if c.ValidateColumnName == nil { + return errors.New("missing sqldb.Config.ValidateColumnName") + } + if c.ParamPlaceholderFormatter == nil { + return errors.New("missing sqldb.Config.ParamPlaceholderFormatter") + } if c.Driver == "" { - return fmt.Errorf("missing sqldb.Config.Driver") + return errors.New("missing sqldb.Config.Driver") } if c.Host == "" { - return fmt.Errorf("missing sqldb.Config.Host") + return errors.New("missing sqldb.Config.Host") } if c.Database == "" { - return fmt.Errorf("missing sqldb.Config.Database") + return errors.New("missing sqldb.Config.Database") } return nil } -// ConnectURL for connecting to a database +// ConnectURL returns a connection URL for the Config func (c *Config) ConnectURL() string { extra := make(url.Values) for key, val := range c.Extra { diff --git a/connection.go b/connection.go index eff08b7..e2bfd26 100644 --- a/connection.go +++ b/connection.go @@ -16,119 +16,44 @@ type ( // Connection represents a database connection or transaction type Connection interface { - // Context that all connection operations use. - // See also WithContext. - Context() context.Context - - // WithContext returns a connection that uses the passed - // context for its operations. - WithContext(ctx context.Context) Connection - - // WithStructFieldMapper returns a copy of the connection - // that will use the passed StructFieldMapper. - WithStructFieldMapper(StructFieldMapper) Connection + // Config returns the configuration used + // to create this connection. + Config() *Config - // StructFieldMapper used by methods of this Connection. - StructFieldMapper() StructFieldMapper + // Stats returns the sql.DBStats of this connection. + Stats() sql.DBStats // Ping returns an error if the database // does not answer on this connection // with an optional timeout. // The passed timeout has to be greater zero // to be considered. - Ping(timeout time.Duration) error - - // Stats returns the sql.DBStats of this connection. - Stats() sql.DBStats - - // Config returns the configuration used - // to create this connection. - Config() *Config - - // ValidateColumnName returns an error - // if the passed name is not valid for a - // column of the connection's database. - ValidateColumnName(name string) error + Ping(ctx context.Context, timeout time.Duration) error - // Now returns the result of the SQL now() - // function for the current connection. - // Useful for getting the timestamp of a - // SQL transaction for use in Go code. - Now() (time.Time, error) + // Err returns any current error of the connection + Err() error // Exec executes a query with optional args. - Exec(query string, args ...any) error - - // Insert a new row into table using the values. - Insert(table string, values Values) error - - // InsertUnique inserts a new row into table using the passed values - // or does nothing if the onConflict statement applies. - // Returns if a row was inserted. - InsertUnique(table string, values Values, onConflict string) (inserted bool, err error) - - // InsertReturning inserts a new row into table using values - // and returns values from the inserted row listed in returning. - InsertReturning(table string, values Values, returning string) RowScanner - - // InsertStruct inserts a new row into table using the connection's - // StructFieldMapper to map struct fields to column names. - // Optional ColumnFilter can be passed to ignore mapped columns. - InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // InsertUniqueStruct inserts a new row into table using the connection's - // StructFieldMapper to map struct fields to column names. - // Optional ColumnFilter can be passed to ignore mapped columns. - // Does nothing if the onConflict statement applies - // and returns if a row was inserted. - InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...ColumnFilter) (inserted bool, err error) - - // Update table rows(s) with values using the where statement with passed in args starting at $1. - Update(table string, values Values, where string, args ...any) error - - // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 - // and returning a single row with the columns specified in returning argument. - UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner - - // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 - // and returning multiple rows with the columns specified in returning argument. - UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner - - // UpdateStruct updates a row in a table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // UpsertStruct upserts a row to table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - // If inserting conflicts on the primary key column(s), then an update is performed. - UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error + Exec(ctx context.Context, query string, args ...any) error // QueryRow queries a single row and returns a RowScanner for the results. - QueryRow(query string, args ...any) RowScanner + QueryRow(ctx context.Context, query string, args ...any) Row // QueryRows queries multiple rows and returns a RowsScanner for the results. - QueryRows(query string, args ...any) RowsScanner + QueryRows(ctx context.Context, query string, args ...any) Rows // IsTransaction returns if the connection is a transaction IsTransaction() bool - // TransactionOptions returns the sql.TxOptions of the - // current transaction and true as second result value, - // or false if the connection is not a transaction. - TransactionOptions() (*sql.TxOptions, bool) + // TxOptions returns the sql.TxOptions of the + // current transaction which can be nil for the default options. + // Use IsTransaction to check if the connection is a transaction. + TxOptions() *sql.TxOptions // Begin a new transaction. // If the connection is already a transaction, a brand // new transaction will begin on the parent's connection. - Begin(opts *sql.TxOptions) (Connection, error) + Begin(ctx context.Context, opts *sql.TxOptions) (Connection, error) // Commit the current transaction. // Returns ErrNotWithinTransaction if the connection diff --git a/db/config.go b/db/config.go index 3e63375..793ef87 100644 --- a/db/config.go +++ b/db/config.go @@ -1,23 +1,27 @@ package db import ( - "context" "errors" "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) var ( // Number of retries used for a SerializedTransaction // before it fails SerializedTransactionRetries = 10 + + // DefaultStructFieldMapping provides the default StructFieldTagNaming + // using "db" as NameTag and IgnoreStructField as UntaggedNameFunc. + // Implements StructFieldMapper. + DefaultStructFieldMapping = reflection.NewTaggedStructFieldMapping() ) var ( - conn = sqldb.ConnectionWithError( - context.Background(), + globalConn = sqldb.ConnectionWithError( errors.New("database connection not initialized"), ) - connCtxKey int + globalConnCtxKey int serializedTransactionCtxKey int ) diff --git a/db/conn.go b/db/conn.go index 241cfb3..770db11 100644 --- a/db/conn.go +++ b/db/conn.go @@ -14,7 +14,7 @@ func SetConn(c sqldb.Connection) { if c == nil { panic("must not set nil sqldb.Connection") } - conn = c + globalConn = c } // Conn returns a non nil sqldb.Connection from ctx @@ -22,30 +22,30 @@ func SetConn(c sqldb.Connection) { // The returned connection will use the passed context. // See sqldb.Connection.WithContext func Conn(ctx context.Context) sqldb.Connection { - return ConnDefault(ctx, conn) -} - -// ConnDefault returns a non nil sqldb.Connection from ctx -// or the passed defaultConn. -// The returned connection will use the passed context. -// See sqldb.Connection.WithContext -func ConnDefault(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connection { - c, _ := ctx.Value(&connCtxKey).(sqldb.Connection) - if c == nil { - c = defaultConn - } - if c.Context() == ctx { + if c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection); c != nil { return c } - return c.WithContext(ctx) + return globalConn } +// // ConnDefault returns a non nil sqldb.Connection from ctx +// // or the passed defaultConn. +// // The returned connection will use the passed context. +// // See sqldb.Connection.WithContext +// func ConnDefault(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connection { +// c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection) +// if c == nil { +// return defaultConn +// } +// return c +// } + // ContextWithConn returns a new context with the passed sqldb.Connection // added as value so it can be retrieved again using Conn(ctx). // Passing a nil connection causes Conn(ctx) // to return the global connection set with SetConn. func ContextWithConn(ctx context.Context, conn sqldb.Connection) context.Context { - return context.WithValue(ctx, &connCtxKey, conn) + return context.WithValue(ctx, &globalConnCtxKey, conn) } // ContextWithoutCancel returns a new context that inherits diff --git a/impl/insert.go b/db/insert.go similarity index 50% rename from impl/insert.go rename to db/insert.go index 8254b4f..a097382 100644 --- a/impl/insert.go +++ b/db/insert.go @@ -1,33 +1,41 @@ -package impl +package db import ( + "context" "fmt" "reflect" "strings" - sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) +type Values = sqldb.Values + // Insert a new row into table using the values. -func Insert(conn sqldb.Connection, table, argFmt string, values sqldb.Values) error { +func Insert(ctx context.Context, table string, values Values) error { if len(values) == 0 { return fmt.Errorf("Insert into table %s: no values", table) } + conn := Conn(ctx) + argFmt := conn.Config().ParamPlaceholderFormatter names, vals := values.Sorted() b := strings.Builder{} writeInsertQuery(&b, table, argFmt, names) query := b.String() - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + err := conn.Exec(ctx, query, vals...) + if err != nil { + return sqldb.WrapErrorWithQuery(err, query, vals, argFmt) + } + return nil } // InsertUnique inserts a new row into table using the passed values // or does nothing if the onConflict statement applies. // Returns if a row was inserted. -func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Values, onConflict string) (inserted bool, err error) { +func InsertUnique(ctx context.Context, table string, values Values, onConflict string) (inserted bool, err error) { if len(values) == 0 { return false, fmt.Errorf("InsertUnique into table %s: no values", table) } @@ -36,69 +44,57 @@ func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Valu onConflict = onConflict[1 : len(onConflict)-1] } + conn := Conn(ctx) + argFmt := conn.Config().ParamPlaceholderFormatter names, vals := values.Sorted() var query strings.Builder writeInsertQuery(&query, table, argFmt, names) fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - err = conn.QueryRow(query.String(), vals...).Scan(&inserted) - - err = sqldb.ReplaceErrNoRows(err, nil) - err = WrapNonNilErrorWithQuery(err, query.String(), argFmt, vals) + err = conn.QueryRow(ctx, query.String(), vals...).Scan(&inserted) + if err != nil { + return false, sqldb.WrapErrorWithQuery(err, query.String(), vals, argFmt) + } return inserted, err } // InsertReturning inserts a new row into table using values // and returns values from the inserted row listed in returning. -func InsertReturning(conn sqldb.Connection, table, argFmt string, values sqldb.Values, returning string) sqldb.RowScanner { +func InsertReturning(ctx context.Context, table string, values Values, returning string) sqldb.Row { if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) + return sqldb.RowWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) } + conn := Conn(ctx) + argFmt := conn.Config().ParamPlaceholderFormatter names, vals := values.Sorted() var query strings.Builder writeInsertQuery(&query, table, argFmt, names) query.WriteString(" RETURNING ") query.WriteString(returning) - return conn.QueryRow(query.String(), vals...) -} - -func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) { - fmt.Fprintf(w, `INSERT INTO %s(`, table) - for i, name := range names { - if i > 0 { - w.WriteByte(',') - } - w.WriteByte('"') - w.WriteString(name) - w.WriteByte('"') - } - w.WriteString(`) VALUES(`) - for i := range names { - if i > 0 { - w.WriteByte(',') - } - fmt.Fprintf(w, argFmt, i+1) - } - w.WriteByte(')') + return conn.QueryRow(ctx, query.String(), vals...) } // InsertStruct inserts a new row into table using the connection's // StructFieldMapper to map struct fields to column names. // Optional ColumnFilter can be passed to ignore mapped columns. -func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns) +func InsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflection.ColumnFilter) error { + conn := Conn(ctx) + table, columns, vals, err := insertStructValues(rowStruct, DefaultStructFieldMapping, ignoreColumns) if err != nil { return err } + argFmt := conn.Config().ParamPlaceholderFormatter var b strings.Builder writeInsertQuery(&b, table, argFmt, columns) query := b.String() - err = conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + err = conn.Exec(ctx, query, vals...) + if err != nil { + return sqldb.WrapErrorWithQuery(err, query, vals, argFmt) + } + return nil } // InsertUniqueStruct inserts a new row into table using the connection's @@ -106,8 +102,9 @@ func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld // Optional ColumnFilter can be passed to ignore mapped columns. // Does nothing if the onConflict statement applies // and returns if a row was inserted. -func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onConflict string, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) (inserted bool, err error) { - columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns) +func InsertUniqueStruct(ctx context.Context, rowStruct any, onConflict string, ignoreColumns ...reflection.ColumnFilter) (inserted bool, err error) { + conn := Conn(ctx) + table, columns, vals, err := insertStructValues(rowStruct, DefaultStructFieldMapping, ignoreColumns) if err != nil { return false, err } @@ -116,29 +113,51 @@ func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onCo onConflict = onConflict[1 : len(onConflict)-1] } + argFmt := conn.Config().ParamPlaceholderFormatter var b strings.Builder writeInsertQuery(&b, table, argFmt, columns) fmt.Fprintf(&b, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) query := b.String() - err = conn.QueryRow(query, vals...).Scan(&inserted) - err = sqldb.ReplaceErrNoRows(err, nil) + err = conn.QueryRow(ctx, query, vals...).Scan(&inserted) + if err != nil { + return false, sqldb.WrapErrorWithQuery(err, query, vals, argFmt) + } + return inserted, nil +} - return inserted, WrapNonNilErrorWithQuery(err, query, argFmt, vals) +func writeInsertQuery(w *strings.Builder, table string, argFmt sqldb.ParamPlaceholderFormatter, names []string) { + fmt.Fprintf(w, `INSERT INTO %s(`, table) + for i, name := range names { + if i > 0 { + w.WriteByte(',') + } + w.WriteByte('"') + w.WriteString(name) + w.WriteByte('"') + } + w.WriteString(`) VALUES(`) + for i := range names { + if i > 0 { + w.WriteByte(',') + } + w.WriteString(argFmt.ParamPlaceholder(i)) + } + w.WriteByte(')') } -func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { +func insertStructValues(rowStruct any, mapper reflection.StructFieldMapper, ignoreColumns []reflection.ColumnFilter) (table string, columns []string, vals []any, err error) { v := reflect.ValueOf(rowStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() } switch { case v.Kind() == reflect.Ptr && v.IsNil(): - return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table) + return "", nil, nil, fmt.Errorf("can't insert nil") case v.Kind() != reflect.Struct: - return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) + return "", nil, nil, fmt.Errorf("expected struct but got %T", rowStruct) } - columns, _, vals = ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - return columns, vals, nil + table, columns, _, vals, err = reflection.ReflectStructValues(v, mapper, append(ignoreColumns, sqldb.IgnoreReadOnly)) + return table, columns, vals, err } diff --git a/db/query.go b/db/query.go new file mode 100644 index 0000000..d7b45fd --- /dev/null +++ b/db/query.go @@ -0,0 +1,159 @@ +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "time" + + "github.com/domonda/go-sqldb" +) + +// Now returns the result of the SQL now() +// function for the current connection. +// Useful for getting the timestamp of a +// SQL transaction for use in Go code. +func Now(ctx context.Context) (time.Time, error) { + var now time.Time + err := Conn(ctx).QueryRow(ctx, `SELECT now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil +} + +// Exec executes a query with optional args. +func Exec(ctx context.Context, query string, args ...any) error { + return Conn(ctx).Exec(ctx, query, args...) +} + +// QueryRow queries a single row and returns a Row for the results. +func QueryRow(ctx context.Context, query string, args ...any) sqldb.Row { + return Conn(ctx).QueryRow(ctx, query, args...) +} + +// QueryRows queries multiple rows and returns a Rows for the results. +func QueryRows(ctx context.Context, query string, args ...any) sqldb.Rows { + return Conn(ctx).QueryRows(ctx, query, args...) +} + +// QueryValue queries a single value of type T. +func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, err error) { + err = Conn(ctx).QueryRow(ctx, query, args...).Scan(&value) + if err != nil { + var zero T + return zero, err + } + return value, nil +} + +// QueryValueOrDefault queries a single value of type T +// or returns the default zero value of T in case of sql.ErrNoRows. +func QueryValueOrDefault[T any](ctx context.Context, query string, args ...any) (value T, err error) { + err = Conn(ctx).QueryRow(ctx, query, args...).Scan(&value) + if err != nil { + var zero T + if errors.Is(err, sql.ErrNoRows) { + return zero, nil + } + return zero, err + } + return value, err +} + +// QueryStruct uses the passed pkValue+pkValues to query a table row +// and scan it into a struct of type S that must have tagged fields +// with primary key flags to identify the primary key column names +// for the passed pkValue+pkValues and a table name. +func QueryStruct[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { + // Using explicit first pkValue value + // to not be able to compile without any value + pkValues = append([]any{pkValue}, pkValues...) + t := reflect.TypeOf(row).Elem() + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected struct template type instead of %s", t) + } + conn := Conn(ctx) + table, pkColumns, err := pkColumnsOfStruct(conn, t) + if err != nil { + return nil, err + } + if len(pkColumns) != len(pkValues) { + return nil, fmt.Errorf("got %d primary key values, but struct %s has %d primary key fields", len(pkValues), t, len(pkColumns)) + } + query := fmt.Sprintf(`SELECT * FROM %s WHERE "%s" = $1`, table, pkColumns[0]) + for i := 1; i < len(pkColumns); i++ { + query += fmt.Sprintf(` AND "%s" = $%d`, pkColumns[i], i+1) + } + err = conn.QueryRow(ctx, query, pkValues...).ScanStruct(&row) + if err != nil { + return nil, err + } + return row, nil +} + +// QueryStructOrNil uses the passed pkValue+pkValues to query a table row +// and scan it into a struct of type S that must have tagged fields +// with primary key flags to identify the primary key column names +// for the passed pkValue+pkValues and a table name. +// Returns nil as row and error if no row could be found with the +// passed pkValue+pkValues. +func QueryStructOrNil[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { + row, err = QueryStruct[S](ctx, pkValue, pkValues...) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return row, nil +} + +func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) { + mapper := conn.StructFieldMapper() + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldTable, column, flags, ok := mapper.MapStructField(field) + if !ok { + continue + } + if fieldTable != "" && fieldTable != table { + if table != "" { + return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) + } + table = fieldTable + } + + if column == "" { + fieldTable, columnsEmbed, err := pkColumnsOfStruct(conn, field.Type) + if err != nil { + return "", nil, err + } + if fieldTable != "" && fieldTable != table { + if table != "" { + return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) + } + table = fieldTable + } + columns = append(columns, columnsEmbed...) + } else if flags.PrimaryKey() { + if err = conn.ValidateColumnName(column); err != nil { + return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) + } + columns = append(columns, column) + } + } + return table, columns, nil +} + +// QueryStructSlice returns queried rows as slice of the generic type S +// which must be a struct or a pointer to a struct. +func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []S, err error) { + err = Conn(ctx).QueryRows(ctx, query, args...).ScanStructSlice(&rows) + if err != nil { + return nil, err + } + return rows, nil +} diff --git a/db/querystruct.go b/db/querystruct.go deleted file mode 100644 index 1c20342..0000000 --- a/db/querystruct.go +++ /dev/null @@ -1,89 +0,0 @@ -package db - -import ( - "context" - "errors" - "fmt" - "reflect" - - "github.com/domonda/go-sqldb" -) - -// QueryStruct uses the passed pkValues to query a table row -// and scan it into a struct of type S that must have tagged fields -// with primary key flags to identify the primary key column names -// for the passed pkValues and a table name. -func QueryStruct[S any](ctx context.Context, pkValues ...any) (row *S, err error) { - if len(pkValues) == 0 { - return nil, errors.New("missing primary key values") - } - t := reflect.TypeOf(row).Elem() - if t.Kind() != reflect.Struct { - return nil, fmt.Errorf("expected struct template type instead of %s", t) - } - conn := Conn(ctx) - table, pkColumns, err := pkColumnsOfStruct(conn, t) - if err != nil { - return nil, err - } - if len(pkColumns) != len(pkValues) { - return nil, fmt.Errorf("got %d primary key values, but struct %s has %d primary key fields", len(pkValues), t, len(pkColumns)) - } - query := fmt.Sprintf(`SELECT * FROM %s WHERE "%s" = $1`, table, pkColumns[0]) - for i := 1; i < len(pkColumns); i++ { - query += fmt.Sprintf(` AND "%s" = $%d`, pkColumns[i], i+1) - } - err = conn.QueryRow(query, pkValues...).ScanStruct(&row) - if err != nil { - return nil, err - } - return row, nil -} - -// QueryStructOrNil uses the passed pkValues to query a table row -// and scan it into a struct of type S that must have tagged fields -// with primary key flags to identify the primary key column names -// for the passed pkValues and a table name. -// Returns nil as row and error if no row could be found with the -// passed pkValues. -func QueryStructOrNil[S any](ctx context.Context, pkValues ...any) (row *S, err error) { - row, err = QueryStruct[S](ctx, pkValues...) - return row, ReplaceErrNoRows(err, nil) -} - -func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) { - mapper := conn.StructFieldMapper() - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - fieldTable, column, flags, ok := mapper.MapStructField(field) - if !ok { - continue - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } - - if column == "" { - fieldTable, columnsEmbed, err := pkColumnsOfStruct(conn, field.Type) - if err != nil { - return "", nil, err - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } - columns = append(columns, columnsEmbed...) - } else if flags.PrimaryKey() { - if err = conn.ValidateColumnName(column); err != nil { - return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) - } - columns = append(columns, column) - } - } - return table, columns, nil -} diff --git a/db/transaction.go b/db/transaction.go index 58ea7fd..fb85419 100644 --- a/db/transaction.go +++ b/db/transaction.go @@ -51,7 +51,7 @@ func DebugNoTransaction(ctx context.Context, nonTxFunc func(context.Context) err // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. func IsolatedTransaction(ctx context.Context, txFunc func(context.Context) error) error { return sqldb.IsolatedTransaction(Conn(ctx), nil, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &connCtxKey, tx)) + return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } @@ -64,7 +64,7 @@ func IsolatedTransaction(ctx context.Context, txFunc func(context.Context) error // Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. func Transaction(ctx context.Context, txFunc func(context.Context) error) error { return sqldb.Transaction(Conn(ctx), nil, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &connCtxKey, tx)) + return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } @@ -135,7 +135,7 @@ func SerializedTransaction(ctx context.Context, txFunc func(context.Context) err // Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. func TransactionOpts(ctx context.Context, opts *sql.TxOptions, txFunc func(context.Context) error) error { return sqldb.Transaction(Conn(ctx), opts, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &connCtxKey, tx)) + return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } @@ -149,7 +149,7 @@ func TransactionOpts(ctx context.Context, opts *sql.TxOptions, txFunc func(conte func TransactionReadOnly(ctx context.Context, txFunc func(context.Context) error) error { opts := sql.TxOptions{ReadOnly: true} return sqldb.Transaction(Conn(ctx), &opts, func(tx sqldb.Connection) error { - return txFunc(context.WithValue(ctx, &connCtxKey, tx)) + return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } diff --git a/db/transaction_test.go b/db/transaction_test.go index 16adfda..a6e16ac 100644 --- a/db/transaction_test.go +++ b/db/transaction_test.go @@ -10,7 +10,7 @@ import ( ) func TestSerializedTransaction(t *testing.T) { - conn = mockconn.New(context.Background(), os.Stdout, nil) + globalConn = mockconn.New(context.Background(), os.Stdout, nil) expectSerialized := func(ctx context.Context) error { if !Conn(ctx).IsTransaction() { @@ -64,7 +64,7 @@ func TestSerializedTransaction(t *testing.T) { } func TestTransaction(t *testing.T) { - conn = mockconn.New(context.Background(), os.Stdout, nil) + globalConn = mockconn.New(context.Background(), os.Stdout, nil) expectNonSerialized := func(ctx context.Context) error { if !Conn(ctx).IsTransaction() { diff --git a/impl/update.go b/db/update.go similarity index 56% rename from impl/update.go rename to db/update.go index 2d5976f..1cdd1f5 100644 --- a/impl/update.go +++ b/db/update.go @@ -1,50 +1,56 @@ -package impl +package db import ( + "context" + "errors" "fmt" "reflect" "strings" sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" "golang.org/x/exp/slices" ) // Update table rows(s) with values using the where statement with passed in args starting at $1. -func Update(conn sqldb.Connection, table string, values sqldb.Values, where, argFmt string, args []any) error { +func Update(ctx context.Context, table string, values sqldb.Values, where string, args ...any) error { if len(values) == 0 { return fmt.Errorf("Update table %s: no values passed", table) } - query, vals := buildUpdateQuery(table, values, where, args) + conn := Conn(ctx) + query, vals := buildUpdateQuery(table, values, where, conn, args) err := conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return sqldb.WrapErrorWithQuery(err, query, conn, vals) } // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 // and returning a single row with the columns specified in returning argument. -func UpdateReturningRow(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { +func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.Row { if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) + return sqldb.RowWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) } - query, vals := buildUpdateQuery(table, values, where, args) + conn := Conn(ctx) + query, vals := buildUpdateQuery(table, values, where, conn, args) query += " RETURNING " + returning return conn.QueryRow(query, vals...) } // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 // and returning multiple rows with the columns specified in returning argument. -func UpdateReturningRows(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { +func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.Rows { if len(values) == 0 { - return sqldb.RowsScannerWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) + return sqldb.RowsWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) } - query, vals := buildUpdateQuery(table, values, where, args) + conn := Conn(ctx) + query, vals := buildUpdateQuery(table, values, where, conn, args) query += " RETURNING " + returning return conn.QueryRows(query, vals...) } -func buildUpdateQuery(table string, values sqldb.Values, where string, args []any) (string, []any) { +func buildUpdateQuery(table string, values sqldb.Values, where string, argFmt sqldb.ParamPlaceholderFormatter, args []any) (string, []any) { names, vals := values.Sorted() var query strings.Builder @@ -53,7 +59,7 @@ func buildUpdateQuery(table string, values sqldb.Values, where string, args []an if i > 0 { query.WriteByte(',') } - fmt.Fprintf(&query, `"%s"=$%d`, names[i], 1+len(args)+i) + fmt.Fprintf(&query, `"%s"=%s`, names[i], argFmt.ParamPlaceholder(len(args)+i)) } fmt.Fprintf(&query, ` WHERE %s`, where) @@ -65,19 +71,20 @@ func buildUpdateQuery(table string, values sqldb.Values, where string, args []an // Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. // If restrictToColumns are provided, then only struct fields with a `db` tag // matching any of the passed column names will be used. -func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return fmt.Errorf("UpdateStruct of table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return fmt.Errorf("UpdateStruct of table %s: expected struct but got %T", table, rowStruct) +func UpdateStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflection.ColumnFilter) error { + v, err := derefStruct(rowStruct) + if err != nil { + return err } - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + conn := Conn(ctx) + table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, DefaultStructFieldMapping, append(ignoreColumns, sqldb.IgnoreReadOnly)) + if err != nil { + return err + } + if table == "" { + return fmt.Errorf("UpdateStruct: %s has no table name", v.Type()) + } if len(pkCols) == 0 { return fmt.Errorf("UpdateStruct of table %s: %s has no mapped primary key field", table, v.Type()) } @@ -107,7 +114,21 @@ func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld query := b.String() - err := conn.Exec(query, vals...) + err = conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return sqldb.WrapErrorWithQuery(err, query, conn, vals) +} + +func derefStruct(rowStruct any) (reflect.Value, error) { + v := reflect.ValueOf(rowStruct) + for v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + switch { + case v.Kind() == reflect.Ptr && v.IsNil(): + return reflect.Value{}, errors.New("can't use nil pointer") + case v.Kind() != reflect.Struct: + return reflect.Value{}, fmt.Errorf("expected struct but got %T", rowStruct) + } + return v, nil } diff --git a/db/upsert.go b/db/upsert.go new file mode 100644 index 0000000..5cc38ce --- /dev/null +++ b/db/upsert.go @@ -0,0 +1,66 @@ +package db + +import ( + "context" + "fmt" + "strings" + + "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" + "golang.org/x/exp/slices" +) + +// UpsertStruct upserts a row to table using the exported fields +// of rowStruct which have a `db` tag that is not "-". +// If restrictToColumns are provided, then only struct fields with a `db` tag +// matching any of the passed column names will be used. +// The struct must have at least one field with a `db` tag value having a ",pk" suffix +// to mark primary key column(s). +// If inserting conflicts on the primary key column(s), then an update is performed. +func UpsertStruct(ctx context.Context, rowStruct any, ignoreColumns ...reflection.ColumnFilter) error { + v, err := derefStruct(rowStruct) + if err != nil { + return err + } + + conn := Conn(ctx) + table, columns, pkCols, vals, err := reflection.ReflectStructValues(v, DefaultStructFieldMapping, append(ignoreColumns, sqldb.IgnoreReadOnly)) + if err != nil { + return err + } + if table == "" { + return fmt.Errorf("UpsertStruct: %s has no table name", v.Type()) + } + if len(pkCols) == 0 { + return fmt.Errorf("UpsertStruct: %s has no mapped primary key field", v.Type()) + } + + var b strings.Builder + writeInsertQuery(&b, table, conn, columns) + b.WriteString(` ON CONFLICT(`) + for i, pkCol := range pkCols { + if i > 0 { + b.WriteByte(',') + } + fmt.Fprintf(&b, `"%s"`, columns[pkCol]) + } + + b.WriteString(`) DO UPDATE SET `) + first := true + for i := range columns { + if slices.Contains(pkCols, i) { + continue + } + if first { + first = false + } else { + b.WriteByte(',') + } + fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) + } + query := b.String() + + err = conn.Exec(query, vals...) + + return sqldb.WrapErrorWithQuery(err, query, conn, vals) +} diff --git a/dbconnection.go b/dbconnection.go new file mode 100644 index 0000000..b6836fc --- /dev/null +++ b/dbconnection.go @@ -0,0 +1,101 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" +) + +type DBConnection struct { + Conf *Config + DB *sql.DB +} + +func (c *DBConnection) Config() *Config { + return c.Conf +} + +func (c *DBConnection) Stats() sql.DBStats { + return c.DB.Stats() +} + +func (c *DBConnection) Ping(ctx context.Context, timeout time.Duration) error { + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + return c.DB.PingContext(ctx) +} + +func (c *DBConnection) Err() error { + return c.Conf.Err +} + +func (c *DBConnection) Exec(ctx context.Context, query string, args ...any) error { + _, err := c.DB.ExecContext(ctx, query, args...) + if err != nil { + return WrapErrorWithQuery(err, query, args, c.Conf.ParamPlaceholderFormatter) + } + return nil +} + +func (c *DBConnection) QueryRow(ctx context.Context, query string, args ...any) Row { + rows, err := c.DB.QueryContext(ctx, query, args...) + if err != nil { + return RowWithError(WrapErrorWithQuery(err, query, args, c.Conf.ParamPlaceholderFormatter)) + } + return NewRow(ctx, rows, c, query, args) +} + +func (c *DBConnection) QueryRows(ctx context.Context, query string, args ...any) Rows { + rows, err := c.DB.QueryContext(ctx, query, args...) + if err != nil { + return RowsWithError(WrapErrorWithQuery(err, query, args, c.Conf.ParamPlaceholderFormatter)) + } + return NewRows(ctx, rows, c, query, args) +} + +func (c *DBConnection) IsTransaction() bool { + return false +} + +func (c *DBConnection) TxOptions() *sql.TxOptions { + return nil +} + +func (c *DBConnection) Begin(ctx context.Context, opts *sql.TxOptions) (Connection, error) { + tx, err := c.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &TxConnection{ + Parent: c, + Tx: tx, + Opts: opts, + }, nil +} + +func (c *DBConnection) Commit() error { + return ErrNotWithinTransaction +} + +func (c *DBConnection) Rollback() error { + return ErrNotWithinTransaction +} + +func (c *DBConnection) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return ErrNotSupported +} + +func (c *DBConnection) UnlistenChannel(channel string) error { + return ErrNotSupported +} + +func (c *DBConnection) IsListeningOnChannel(channel string) bool { + return false +} + +func (c *DBConnection) Close() error { + return c.DB.Close() +} diff --git a/errconnection.go b/errconnection.go new file mode 100644 index 0000000..e94d10c --- /dev/null +++ b/errconnection.go @@ -0,0 +1,88 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" +) + +// ConnectionWithError returns a dummy Connection +// where all methods return the passed error. +func ConnectionWithError(err error) Connection { + if err == nil { + panic("ConnectionWithError needs an error") + } + return errConn{err} +} + +type errConn struct { + err error +} + +func (e errConn) Config() *Config { + return &Config{Err: e.err} +} + +func (e errConn) Stats() sql.DBStats { + return sql.DBStats{} +} + +func (e errConn) Ping(context.Context, time.Duration) error { + return e.err +} + +func (e errConn) Err() error { + return e.err +} + +func (e errConn) Exec(ctx context.Context, query string, args ...any) error { + return e.err +} + +func (e errConn) QueryRow(ctx context.Context, query string, args ...any) Row { + return RowWithError(e.err) +} + +func (e errConn) QueryRows(ctx context.Context, query string, args ...any) Rows { + return RowsWithError(e.err) +} + +func (e errConn) IsTransaction() bool { + return false +} + +func (ce errConn) TxOptions() *sql.TxOptions { + return nil +} + +func (e errConn) Begin(ctx context.Context, opts *sql.TxOptions) (Connection, error) { + return nil, e.err +} + +func (e errConn) Commit() error { + return e.err +} + +func (e errConn) Rollback() error { + return e.err +} + +func (e errConn) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { + return e.err +} + +func (e errConn) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return e.err +} + +func (e errConn) UnlistenChannel(channel string) error { + return e.err +} + +func (e errConn) IsListeningOnChannel(channel string) bool { + return false +} + +func (e errConn) Close() error { + return e.err +} diff --git a/errors.go b/errors.go index 124d0f3..09bee08 100644 --- a/errors.go +++ b/errors.go @@ -1,17 +1,22 @@ package sqldb import ( - "context" "database/sql" "errors" - "time" + "fmt" ) -var ( - _ Connection = connectionWithError{} - _ RowScanner = rowScannerWithError{} - _ RowsScanner = rowsScannerWithError{} -) +func combineTwoErrors(prim, sec error) error { + switch { + case prim != nil && sec != nil: + return fmt.Errorf("%w\n%s", prim, sec) + case prim != nil: + return prim + case sec != nil: + return sec + } + return nil +} // ReplaceErrNoRows returns the passed replacement error // if errors.Is(err, sql.ErrNoRows), @@ -58,212 +63,29 @@ const ( ErrNotSupported sentinelError = "not supported" ) -// ConnectionWithError - -// ConnectionWithError returns a dummy Connection -// where all methods return the passed error. -func ConnectionWithError(ctx context.Context, err error) Connection { +// WrapErrorWithQuery wraps non nil errors with a formatted query +// if the error was not already wrapped with a query. +// If the passed error is nil, then nil will be returned. +func WrapErrorWithQuery(err error, query string, args []any, paramFmt ParamPlaceholderFormatter) error { if err == nil { - panic("ConnectionWithError needs an error") + return nil } - return connectionWithError{ctx, err} -} - -type connectionWithError struct { - ctx context.Context - err error -} - -func (e connectionWithError) Context() context.Context { return e.ctx } - -func (e connectionWithError) WithContext(ctx context.Context) Connection { - return connectionWithError{ctx: ctx, err: e.err} -} - -func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection { - return e -} - -func (e connectionWithError) StructFieldMapper() StructFieldMapper { - return DefaultStructFieldMapping -} - -func (e connectionWithError) Ping(time.Duration) error { - return e.err -} - -func (e connectionWithError) Stats() sql.DBStats { - return sql.DBStats{} -} - -func (e connectionWithError) Config() *Config { - return &Config{Err: e.err} -} - -func (e connectionWithError) ValidateColumnName(name string) error { - return e.err -} - -func (e connectionWithError) Now() (time.Time, error) { - return time.Time{}, e.err -} - -func (e connectionWithError) Exec(query string, args ...any) error { - return e.err -} - -func (e connectionWithError) Insert(table string, values Values) error { - return e.err -} - -func (e connectionWithError) InsertUnique(table string, values Values, onConflict string) (inserted bool, err error) { - return false, e.err -} - -func (e connectionWithError) InsertReturning(table string, values Values, returning string) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) InsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...ColumnFilter) (inserted bool, err error) { - return false, e.err -} - -func (e connectionWithError) Update(table string, values Values, where string, args ...any) error { - return e.err -} - -func (e connectionWithError) UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) IsTransaction() bool { - return false -} - -func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (e connectionWithError) Begin(opts *sql.TxOptions) (Connection, error) { - return nil, e.err -} - -func (e connectionWithError) Commit() error { - return e.err -} - -func (e connectionWithError) Rollback() error { - return e.err -} - -func (e connectionWithError) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { - return e.err -} - -func (e connectionWithError) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { - return e.err -} - -func (e connectionWithError) UnlistenChannel(channel string) error { - return e.err -} - -func (e connectionWithError) IsListeningOnChannel(channel string) bool { - return false -} - -func (e connectionWithError) Close() error { - return e.err -} - -// RowScannerWithError - -// RowScannerWithError returns a dummy RowScanner -// where all methods return the passed error. -func RowScannerWithError(err error) RowScanner { - return rowScannerWithError{err} -} - -type rowScannerWithError struct { - err error -} - -func (e rowScannerWithError) Scan(dest ...any) error { - return e.err -} - -func (e rowScannerWithError) ScanStruct(dest any) error { - return e.err -} - -func (e rowScannerWithError) ScanValues() ([]any, error) { - return nil, e.err -} - -func (e rowScannerWithError) ScanStrings() ([]string, error) { - return nil, e.err -} - -func (e rowScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -// RowsScannerWithError - -// RowsScannerWithError returns a dummy RowsScanner -// where all methods return the passed error. -func RowsScannerWithError(err error) RowsScanner { - return rowsScannerWithError{err} -} - -type rowsScannerWithError struct { - err error -} - -func (e rowsScannerWithError) ScanSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) ScanStructSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) Columns() ([]string, error) { - return nil, e.err + var wrapped errWithQuery + if errors.As(err, &wrapped) { + return err + } + return errWithQuery{err, query, args, paramFmt} } -func (e rowsScannerWithError) ScanAllRowsAsStrings(headerRow bool) ([][]string, error) { - return nil, e.err +type errWithQuery struct { + err error + query string + args []any + paramFmt ParamPlaceholderFormatter } -func (e rowsScannerWithError) ForEachRow(callback func(RowScanner) error) error { - return e.err -} +func (e errWithQuery) Unwrap() error { return e.err } -func (e rowsScannerWithError) ForEachRowCall(callback any) error { - return e.err +func (e errWithQuery) Error() string { + return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.paramFmt, e.args...)) } diff --git a/impl/errors_test.go b/errors_test.go similarity index 84% rename from impl/errors_test.go rename to errors_test.go index 6e544e5..a32b2c2 100644 --- a/impl/errors_test.go +++ b/errors_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql" @@ -11,7 +11,7 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { type args struct { err error query string - argFmt string + argFmt ParamPlaceholderFormatter args []any } tests := []struct { @@ -25,7 +25,7 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { args: args{ err: sql.ErrNoRows, query: `SELECT * FROM table WHERE b = $2 and a = $1`, - argFmt: "$%d", + argFmt: NewParamPlaceholderFormatter("$%d", 1), args: []any{1, "2"}, }, wantError: fmt.Sprintf("%s from query: %s", sql.ErrNoRows, `SELECT * FROM table WHERE b = '2' and a = 1`), @@ -33,7 +33,7 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := WrapNonNilErrorWithQuery(tt.args.err, tt.args.query, tt.args.argFmt, tt.args.args) + err := WrapErrorWithQuery(tt.args.err, tt.args.query, tt.args.argFmt, tt.args.args) if tt.wantError == "" && err != nil || tt.wantError != "" && (err == nil || err.Error() != tt.wantError) { t.Errorf("WrapNonNilErrorWithQuery() error = %v, wantErr %v", err, tt.wantError) } diff --git a/examples/user_demo/user_demo.go b/examples/user_demo/user_demo.go index dacf487..a329ab1 100644 --- a/examples/user_demo/user_demo.go +++ b/examples/user_demo/user_demo.go @@ -10,13 +10,14 @@ import ( "github.com/domonda/go-sqldb" "github.com/domonda/go-sqldb/db" "github.com/domonda/go-sqldb/pqconn" + "github.com/domonda/go-sqldb/reflection" "github.com/domonda/go-types/email" "github.com/domonda/go-types/nullable" "github.com/domonda/go-types/uu" ) type User struct { - ID uu.ID `db:"id,pk,default"` + ID uu.ID `db:"id,pk=public.user,default"` Email email.NullableAddress `db:"email"` Title nullable.NonEmptyString `db:"title"` @@ -45,10 +46,10 @@ func main() { panic(err) } - conn = conn.WithStructFieldMapper(&sqldb.TaggedStructFieldMapping{ + conn = conn.WithStructFieldMapper(&reflection.TaggedStructFieldMapping{ NameTag: "col", Ignore: "ignore", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, }) var users []User @@ -64,7 +65,7 @@ func main() { } err = conn.QueryRows(`select name, email from public.user`).ForEachRow( - func(row sqldb.RowScanner) error { + func(row sqldb.Row) error { var name, email string err := row.Scan(&name, &email) if err != nil { @@ -87,18 +88,20 @@ func main() { panic(err) } + ctx := context.Background() + newUser := &User{ /* ... */ } - err = conn.InsertStruct("public.user", newUser) + err = db.InsertStruct(ctx, newUser) if err != nil { panic(err) } - err = conn.InsertStruct("public.user", newUser, sqldb.IgnoreNullOrZeroDefault) + err = db.InsertStruct(ctx, newUser, sqldb.IgnoreNullOrZeroDefault) if err != nil { panic(err) } - err = conn.Insert("public.user", sqldb.Values{ + err = db.Insert(ctx, "public.user", sqldb.Values{ "name": "Erik Unger", "email": "erik@domonda.com", }) @@ -106,7 +109,7 @@ func main() { panic(err) } - err = conn.UpsertStruct("public.user", newUser, sqldb.IgnoreColumns("created_at")) + err = db.UpsertStruct(ctx, newUser, sqldb.IgnoreColumns("created_at")) if err != nil { panic(err) } diff --git a/impl/format.go b/format.go similarity index 96% rename from impl/format.go rename to format.go index 5f78ba3..22a080e 100644 --- a/impl/format.go +++ b/format.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql/driver" @@ -77,9 +77,9 @@ func FormatValue(val any) (string, error) { return fmt.Sprint(val), nil } -func FormatQuery(query, argFmt string, args ...any) string { +func FormatQuery(query string, argFmt ParamPlaceholderFormatter, args ...any) string { for i := len(args) - 1; i >= 0; i-- { - placeholder := fmt.Sprintf(argFmt, i+1) + placeholder := argFmt.ParamPlaceholder(i) value, err := FormatValue(args[i]) if err != nil { value = "FORMATERROR:" + err.Error() diff --git a/impl/format_test.go b/format_test.go similarity index 87% rename from impl/format_test.go rename to format_test.go index 8d6ab7c..64b81a2 100644 --- a/impl/format_test.go +++ b/format_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql/driver" @@ -81,12 +81,12 @@ WHERE tests := []struct { name string query string - argFmt string + argFmt ParamPlaceholderFormatter args []any want string }{ - {name: "query1", query: query1, argFmt: "$%d", args: []any{createdAt, true, `Erik's Test`}, want: query1formatted}, - {name: "query2", query: query2, argFmt: "$%d", args: []any{"", 2, "3"}, want: query2formatted}, + {name: "query1", query: query1, argFmt: NewParamPlaceholderFormatter("$%d", 1), args: []any{createdAt, true, `Erik's Test`}, want: query1formatted}, + {name: "query2", query: query2, argFmt: NewParamPlaceholderFormatter("$%d", 1), args: []any{"", 2, "3"}, want: query2formatted}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/go.mod b/go.mod index ef2493e..cd0a778 100644 --- a/go.mod +++ b/go.mod @@ -3,21 +3,38 @@ module github.com/domonda/go-sqldb go 1.18 require ( - github.com/domonda/go-errs v0.0.0-20220527085304-63cf6ad85d71 - github.com/domonda/go-types v0.0.0-20220603104906-eadd2cf77191 + github.com/domonda/go-errs v0.0.0-20220622113709-bc43209ba645 + github.com/domonda/go-types v0.0.0-20220614092523-688aad7f8c57 github.com/go-sql-driver/mysql v1.6.0 github.com/lib/pq v1.10.6 - github.com/stretchr/testify v1.7.1 - golang.org/x/exp v0.0.0-20220602145555-4a0574d9293f + github.com/stretchr/testify v1.7.5 + golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d + modernc.org/sqlite v1.17.3 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/domonda/go-pretty v0.0.0-20220317123925-dd9e6bef129a // indirect + github.com/google/uuid v1.3.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kr/pretty v0.1.0 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/ungerik/go-reflection v0.0.0-20220113085621-6c5fc1f2694a // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/sys v0.0.0-20220624220833-87e55d714810 // indirect + golang.org/x/tools v0.1.11 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + lukechampine.com/uint128 v1.2.0 // indirect + modernc.org/cc/v3 v3.36.0 // indirect + modernc.org/ccgo/v3 v3.16.6 // indirect + modernc.org/libc v1.16.11 // indirect + modernc.org/mathutil v1.4.1 // indirect + modernc.org/memory v1.1.1 // indirect + modernc.org/opt v0.1.3 // indirect + modernc.org/strutil v1.1.2 // indirect + modernc.org/token v1.0.0 // indirect ) diff --git a/go.sum b/go.sum index babe5fc..b76e511 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,24 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/domonda/go-errs v0.0.0-20220527085304-63cf6ad85d71 h1:WRag+fUJENLRM8N/wp6gf/0i1aEkLY9prNgoFQsWeso= -github.com/domonda/go-errs v0.0.0-20220527085304-63cf6ad85d71/go.mod h1:suiFfPp8l6I+OOaKgPK/bfX7Ci9ZtFRgPh5VNE0HPao= +github.com/domonda/go-errs v0.0.0-20220622113709-bc43209ba645 h1:hCCfGvOsbejnNPUdqD/wtE/t4pRjEn8/706tRqxUmck= +github.com/domonda/go-errs v0.0.0-20220622113709-bc43209ba645/go.mod h1:WvIoE59Dfs0hhB2GYSlwowlBr2WWGXf/F74bg6HWUpQ= github.com/domonda/go-pretty v0.0.0-20220317123925-dd9e6bef129a h1:6/Is0KGl5Ot3E8ZLAgAFWYiSRdU+3t3jL38+5yIlCV4= github.com/domonda/go-pretty v0.0.0-20220317123925-dd9e6bef129a/go.mod h1:3QkM8UJdyJMeKZiIo7hYzSkQBpRS3k0gOHw4ysyEIB4= -github.com/domonda/go-types v0.0.0-20220603104906-eadd2cf77191 h1:NcOIFS41zSztJog+aPw48HV8oVhRQPV0B6M6CshwFqc= -github.com/domonda/go-types v0.0.0-20220603104906-eadd2cf77191/go.mod h1:qZTRjdjIXo3g+8PUhfpkKbMPGsLVTuF3H7/AX5CzNeQ= +github.com/domonda/go-types v0.0.0-20220614092523-688aad7f8c57 h1:ivIpyltPSRPx1CdqqcXUi+hEp3SyFt6RR6B19pwpYOY= +github.com/domonda/go-types v0.0.0-20220614092523-688aad7f8c57/go.mod h1:jqmELFrQI8hv+uaTNjxht99Wn+14jbUoSmwkbnxaA/g= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -18,18 +26,97 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-sqlite3 v1.14.12 h1:TJ1bhYJPV44phC+IMu1u2K/i5RriLTPe+yc68XDJ1Z0= +github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.5 h1:s5PTfem8p8EbKQOctVV53k6jCJt3UX4IEJzwh+C324Q= +github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/ungerik/go-reflection v0.0.0-20220113085621-6c5fc1f2694a h1:9vfYtqoyrPw08TbSLxkSXEflp6iXa3RL86Qjs+DrVas= github.com/ungerik/go-reflection v0.0.0-20220113085621-6c5fc1f2694a/go.mod h1:6Hnd2/4g3Tpt6TjvxHx8wXOZziwApVxRdIGkr7vNpXs= -golang.org/x/exp v0.0.0-20220602145555-4a0574d9293f h1:KK6mxegmt5hGJRcAnEDjSNLxIRhZxDcgwMbcO/lMCRM= -golang.org/x/exp v0.0.0-20220602145555-4a0574d9293f/go.mod h1:yh0Ynu2b5ZUe3MQfp2nM0ecK7wsgouWTDN0FNeJuIys= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d h1:vtUKgx8dahOomfFzLREU8nSv25YHnTgLBn4rDnWZdU0= +golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220624220833-87e55d714810 h1:rHZQSjJdAI4Xf5Qzeh2bBc5YJIkPFVM6oDtMFYmgws0= +golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= +golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.36.0 h1:0kmRkTmqNidmu3c7BNDSdVHCxXCkWLmWmCIVX4LUboo= +modernc.org/cc/v3 v3.36.0/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI= +modernc.org/ccgo/v3 v3.0.0-20220428102840-41399a37e894/go.mod h1:eI31LL8EwEBKPpNpA4bU1/i+sKOwOrQy8D87zWUcRZc= +modernc.org/ccgo/v3 v3.0.0-20220430103911-bc99d88307be/go.mod h1:bwdAnOoaIt8Ax9YdWGjxWsdkPcZyRPHqrOvJxaKAKGw= +modernc.org/ccgo/v3 v3.16.4/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= +modernc.org/ccgo/v3 v3.16.6 h1:3l18poV+iUemQ98O3X5OMr97LOqlzis+ytivU4NqGhA= +modernc.org/ccgo/v3 v3.16.6/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= +modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= +modernc.org/ccorpus v1.11.6/go.mod h1:2gEUTrWqdpH2pXsmTM1ZkjeSrUWDpjMu2T6m29L/ErQ= +modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= +modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= +modernc.org/libc v0.0.0-20220428101251-2d5f3daf273b/go.mod h1:p7Mg4+koNjc8jkqwcoFBJx7tXkpj00G77X7A72jXPXA= +modernc.org/libc v1.16.0/go.mod h1:N4LD6DBE9cf+Dzf9buBlzVJndKr/iJHG97vGLHYnb5A= +modernc.org/libc v1.16.1/go.mod h1:JjJE0eu4yeK7tab2n4S1w8tlWd9MxXLRzheaRnAKymU= +modernc.org/libc v1.16.7/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= +modernc.org/libc v1.16.11 h1:rR2BPB5e9zUm9gYqDgR0hUxcSmjgtmAL79lRObBLfPU= +modernc.org/libc v1.16.11/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= +modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/mathutil v1.4.1 h1:ij3fYGe8zBF4Vu+g0oT7mB06r8sqGWKuJu1yXeR4by8= +modernc.org/mathutil v1.4.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.1.1 h1:bDOL0DIDLQv7bWhP3gMvIrnoFw+Eo6F7a2QK9HPDiFU= +modernc.org/memory v1.1.1/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw= +modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.17.3 h1:iE+coC5g17LtByDYDWKpR6m2Z9022YrSh3bumwOnIrI= +modernc.org/sqlite v1.17.3/go.mod h1:10hPVYar9C0kfXuTWGz8s0XtB8uAGymUy51ZzStYe3k= +modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= +modernc.org/strutil v1.1.2 h1:iFBDH6j1Z0bN/Q9udJnnFoFpENA4252qe/7/5woE5MI= +modernc.org/strutil v1.1.2/go.mod h1:OYajnUAcI/MX+XD/Wx7v1bbdvcQSvxgtb0gC+u3d3eg= +modernc.org/tcl v1.13.1 h1:npxzTwFTZYM8ghWicVIX1cRWzj7Nd8i6AqqX2p+IYao= +modernc.org/tcl v1.13.1/go.mod h1:XOLfOwzhkljL4itZkK6T72ckMgvj0BDsnKNdZVUOecw= +modernc.org/token v1.0.0 h1:a0jaWiNMDhDUtqOj09wvjWWAqd3q7WpBulmL9H2egsk= +modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +modernc.org/z v1.5.1 h1:RTNHdsrOpeoSeOF4FbzTo8gBYByaJ5xT7NgZ9ZqRiJM= +modernc.org/z v1.5.1/go.mod h1:eWFB510QWW5Th9YGZT81s+LwvaAs3Q2yr4sP0rmLkv8= diff --git a/impl/connection.go b/impl/connection.go deleted file mode 100644 index e50f3b2..0000000 --- a/impl/connection.go +++ /dev/null @@ -1,189 +0,0 @@ -package impl - -import ( - "context" - "database/sql" - "fmt" - "time" - - "github.com/domonda/go-sqldb" -) - -// Connection returns a generic sqldb.Connection implementation -// for an existing sql.DB connection. -// argFmt is the format string for argument placeholders like "?" or "$%d" -// that will be replaced error messages to format a complete query. -func Connection(ctx context.Context, db *sql.DB, config *sqldb.Config, validateColumnName func(string) error, argFmt string) sqldb.Connection { - return &connection{ - ctx: ctx, - db: db, - config: config, - structFieldNamer: sqldb.DefaultStructFieldMapping, - argFmt: argFmt, - validateColumnName: validateColumnName, - } -} - -type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldNamer sqldb.StructFieldMapper - argFmt string - validateColumnName func(string) error -} - -func (conn *connection) clone() *connection { - c := *conn - return &c -} - -func (conn *connection) Context() context.Context { return conn.ctx } - -func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.ctx { - return conn - } - c := conn.clone() - c.ctx = ctx - return c -} - -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - -func (conn *connection) Ping(timeout time.Duration) error { - ctx := conn.ctx - if timeout > 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - return conn.db.PingContext(ctx) -} - -func (conn *connection) Stats() sql.DBStats { - return conn.db.Stats() -} - -func (conn *connection) Config() *sqldb.Config { - return conn.config -} - -func (conn *connection) ValidateColumnName(name string) error { - return conn.validateColumnName(name) -} - -func (conn *connection) Now() (time.Time, error) { - return Now(conn) -} - -func (conn *connection) Exec(query string, args ...any) error { - _, err := conn.db.ExecContext(conn.ctx, query, args...) - return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) -} - -func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return Insert(conn, table, conn.argFmt, columValues) -} - -func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return InsertUnique(conn, table, conn.argFmt, values, onConflict) -} - -func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return InsertReturning(conn, table, conn.argFmt, values, returning) -} - -func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowScannerWithError(err) - } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.argFmt, args) -} - -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowsScannerWithError(err) - } - return NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, conn.argFmt, args) -} - -func (conn *connection) IsTransaction() bool { - return false -} - -func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (conn *connection) Begin(opts *sql.TxOptions) (sqldb.Connection, error) { - tx, err := conn.db.BeginTx(conn.ctx, opts) - if err != nil { - return nil, err - } - return newTransaction(conn, tx, opts), nil -} - -func (conn *connection) Commit() error { - return sqldb.ErrNotWithinTransaction -} - -func (conn *connection) Rollback() error { - return sqldb.ErrNotWithinTransaction -} - -func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) -} - -func (conn *connection) UnlistenChannel(channel string) (err error) { - return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) -} - -func (conn *connection) IsListeningOnChannel(channel string) bool { - return false -} - -func (conn *connection) Close() error { - return conn.db.Close() -} diff --git a/impl/errors.go b/impl/errors.go deleted file mode 100644 index b6028a5..0000000 --- a/impl/errors.go +++ /dev/null @@ -1,42 +0,0 @@ -package impl - -import ( - "errors" - "fmt" -) - -// WrapNonNilErrorWithQuery wraps non nil errors with a formatted query -// if the error was not already wrapped with a query. -// If the passed error is nil, then nil will be returned. -func WrapNonNilErrorWithQuery(err error, query, argFmt string, args []any) error { - var wrapped errWithQuery - if err == nil || errors.As(err, &wrapped) { - return err - } - return errWithQuery{err, query, argFmt, args} -} - -type errWithQuery struct { - err error - query string - argFmt string - args []any -} - -func (e errWithQuery) Unwrap() error { return e.err } - -func (e errWithQuery) Error() string { - return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.argFmt, e.args...)) -} - -func combineErrors(prim, sec error) error { - switch { - case prim != nil && sec != nil: - return fmt.Errorf("%w\n%s", prim, sec) - case prim != nil: - return prim - case sec != nil: - return sec - } - return nil -} diff --git a/impl/now.go b/impl/now.go deleted file mode 100644 index 4a1bd2f..0000000 --- a/impl/now.go +++ /dev/null @@ -1,15 +0,0 @@ -package impl - -import ( - "time" - - "github.com/domonda/go-sqldb" -) - -func Now(conn sqldb.Connection) (now time.Time, err error) { - err = conn.QueryRow(`select now()`).Scan(&now) - if err != nil { - return time.Time{}, err - } - return now, nil -} diff --git a/impl/reflectstruct.go b/impl/reflectstruct.go deleted file mode 100644 index c92c230..0000000 --- a/impl/reflectstruct.go +++ /dev/null @@ -1,117 +0,0 @@ -package impl - -import ( - "errors" - "fmt" - "reflect" - "strings" - - "golang.org/x/exp/slices" - - "github.com/domonda/go-sqldb" -) - -func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { - for i := 0; i < structVal.NumField(); i++ { - fieldType := structVal.Type().Field(i) - _, column, flags, use := namer.MapStructField(fieldType) - if !use { - continue - } - fieldValue := structVal.Field(i) - - if column == "" { - // Embedded struct field - columnsEmbed, pkColsEmbed, valuesEmbed := ReflectStructValues(fieldValue, namer, ignoreColumns) - for _, pkCol := range pkColsEmbed { - pkCols = append(pkCols, pkCol+len(columns)) - } - columns = append(columns, columnsEmbed...) - values = append(values, valuesEmbed...) - continue - } - - if ignoreColumn(ignoreColumns, column, flags, fieldType, fieldValue) { - continue - } - - if flags.PrimaryKey() { - pkCols = append(pkCols, len(columns)) - } - columns = append(columns, column) - values = append(values, fieldValue.Interface()) - } - return columns, pkCols, values -} - -func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { - if len(columns) == 0 { - return nil, errors.New("no columns") - } - pointers = make([]any, len(columns)) - err = reflectStructColumnPointers(structVal, namer, columns, pointers) - if err != nil { - return nil, err - } - for _, ptr := range pointers { - if ptr != nil { - continue - } - nilCols := new(strings.Builder) - for i, ptr := range pointers { - if ptr != nil { - continue - } - if nilCols.Len() > 0 { - nilCols.WriteString(", ") - } - fmt.Fprintf(nilCols, "column=%s, index=%d", columns[i], i) - } - return nil, fmt.Errorf("columns have no mapped struct fields in %s: %s", structVal.Type(), nilCols) - } - return pointers, nil -} - -func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string, pointers []any) error { - var ( - structType = structVal.Type() - ) - for i := 0; i < structType.NumField(); i++ { - fieldType := structType.Field(i) - _, column, _, use := namer.MapStructField(fieldType) - if !use { - continue - } - fieldValue := structVal.Field(i) - - if column == "" { - // Embedded struct field - err := reflectStructColumnPointers(fieldValue, namer, columns, pointers) - if err != nil { - return err - } - continue - } - - colIndex := slices.Index(columns, column) - if colIndex == -1 { - continue - } - - if pointers[colIndex] != nil { - return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, fieldType.Name, structType) - } - - pointers[colIndex] = fieldValue.Addr().Interface() - } - return nil -} - -func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - for _, filter := range filters { - if filter.IgnoreColumn(name, flags, fieldType, fieldValue) { - return true - } - } - return false -} diff --git a/impl/rowscanner.go b/impl/rowscanner.go deleted file mode 100644 index 53956b7..0000000 --- a/impl/rowscanner.go +++ /dev/null @@ -1,128 +0,0 @@ -package impl - -import ( - "database/sql" - - sqldb "github.com/domonda/go-sqldb" -) - -var ( - _ sqldb.RowScanner = &RowScanner{} - _ sqldb.RowScanner = CurrentRowScanner{} - _ sqldb.RowScanner = SingleRowScanner{} -) - -// RowScanner implements sqldb.RowScanner for a sql.Row -type RowScanner struct { - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowScanner(rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowScanner { - return &RowScanner{rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowScanner) Scan(dest ...any) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return s.rows.Scan(dest...) -} - -func (s *RowScanner) ScanStruct(dest any) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return ScanStruct(s.rows, dest, s.structFieldNamer) -} - -func (s *RowScanner) ScanValues() ([]any, error) { - return ScanValues(s.rows) -} - -func (s *RowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.rows) -} - -func (s *RowScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -// CurrentRowScanner calls Rows.Scan without Rows.Next and Rows.Close -type CurrentRowScanner struct { - Rows Rows - StructFieldMapper sqldb.StructFieldMapper -} - -func (s CurrentRowScanner) Scan(dest ...any) error { - return s.Rows.Scan(dest...) -} - -func (s CurrentRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Rows, dest, s.StructFieldMapper) -} - -func (s CurrentRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Rows) -} - -func (s CurrentRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Rows) -} - -func (s CurrentRowScanner) Columns() ([]string, error) { - return s.Rows.Columns() -} - -// SingleRowScanner always uses the same Row -type SingleRowScanner struct { - Row Row - StructFieldMapper sqldb.StructFieldMapper -} - -func (s SingleRowScanner) Scan(dest ...any) error { - return s.Row.Scan(dest...) -} - -func (s SingleRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Row, dest, s.StructFieldMapper) -} - -func (s SingleRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Row) -} - -func (s SingleRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Row) -} - -func (s SingleRowScanner) Columns() ([]string, error) { - return s.Row.Columns() -} diff --git a/impl/rowsscanner.go b/impl/rowsscanner.go deleted file mode 100644 index f833e59..0000000 --- a/impl/rowsscanner.go +++ /dev/null @@ -1,95 +0,0 @@ -package impl - -import ( - "context" - "fmt" - - sqldb "github.com/domonda/go-sqldb" -) - -var _ sqldb.RowsScanner = &RowsScanner{} - -// RowsScanner implements sqldb.RowsScanner with Rows -type RowsScanner struct { - ctx context.Context // ctx is checked for every row and passed through to callbacks - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowsScanner(ctx context.Context, rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowsScanner { - return &RowsScanner{ctx, rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowsScanner) ScanSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, nil) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) ScanStructSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldNamer) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -func (s *RowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { - cols, err := s.rows.Columns() - if err != nil { - return nil, err - } - if headerRow { - rows = [][]string{cols} - } - stringScannablePtrs := make([]any, len(cols)) - err = s.ForEachRow(func(rowScanner sqldb.RowScanner) error { - row := make([]string, len(cols)) - for i := range stringScannablePtrs { - stringScannablePtrs[i] = (*sqldb.StringScannable)(&row[i]) - } - err := rowScanner.Scan(stringScannablePtrs...) - if err != nil { - return err - } - rows = append(rows, row) - return nil - }) - return rows, err -} - -func (s *RowsScanner) ForEachRow(callback func(sqldb.RowScanner) error) (err error) { - defer func() { - err = combineErrors(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - for s.rows.Next() { - if s.ctx.Err() != nil { - return s.ctx.Err() - } - - err := callback(CurrentRowScanner{s.rows, s.structFieldNamer}) - if err != nil { - return err - } - } - return s.rows.Err() -} - -func (s *RowsScanner) ForEachRowCall(callback any) error { - forEachRowFunc, err := ForEachRowCallFunc(s.ctx, callback) - if err != nil { - return err - } - return s.ForEachRow(forEachRowFunc) -} diff --git a/impl/scanstruct.go b/impl/scanstruct.go deleted file mode 100644 index ce2dd56..0000000 --- a/impl/scanstruct.go +++ /dev/null @@ -1,53 +0,0 @@ -package impl - -import ( - "fmt" - "reflect" - - sqldb "github.com/domonda/go-sqldb" -) - -func ScanStruct(srcRow Row, destStruct any, namer sqldb.StructFieldMapper) error { - v := reflect.ValueOf(destStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - - var ( - setDestStructPtr = false - destStructPtr reflect.Value - newStructPtr reflect.Value - ) - if v.Kind() == reflect.Ptr && v.IsNil() && v.CanSet() { - // Got a nil pointer that we can set with a newly allocated struct - setDestStructPtr = true - destStructPtr = v - newStructPtr = reflect.New(v.Type().Elem()) - // Continue with the newly allocated struct - v = newStructPtr.Elem() - } - if v.Kind() != reflect.Struct { - return fmt.Errorf("ScanStruct: expected struct but got %T", destStruct) - } - - columns, err := srcRow.Columns() - if err != nil { - return err - } - - fieldPointers, err := ReflectStructColumnPointers(v, namer, columns) - if err != nil { - return fmt.Errorf("ScanStruct: %w", err) - } - - err = srcRow.Scan(fieldPointers...) - if err != nil { - return err - } - - if setDestStructPtr { - destStructPtr.Set(newStructPtr) - } - - return nil -} diff --git a/impl/transaction.go b/impl/transaction.go deleted file mode 100644 index 1363fe9..0000000 --- a/impl/transaction.go +++ /dev/null @@ -1,169 +0,0 @@ -package impl - -import ( - "context" - "database/sql" - "fmt" - "time" - - "github.com/domonda/go-sqldb" -) - -type transaction struct { - // The parent non-transaction connection is needed - // for its ctx, Ping(), Stats(), and Config() - parent *connection - tx *sql.Tx - opts *sql.TxOptions - structFieldNamer sqldb.StructFieldMapper -} - -func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { - return &transaction{ - parent: parent, - tx: tx, - opts: opts, - structFieldNamer: parent.structFieldNamer, - } -} - -func (conn *transaction) clone() *transaction { - c := *conn - return &c -} - -func (conn *transaction) Context() context.Context { return conn.parent.ctx } - -func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.parent.ctx { - return conn - } - parent := conn.parent.clone() - parent.ctx = ctx - return newTransaction(parent, conn.tx, conn.opts) -} - -func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - -func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } -func (conn *transaction) Stats() sql.DBStats { return conn.parent.Stats() } -func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() } - -func (conn *transaction) ValidateColumnName(name string) error { - return conn.parent.validateColumnName(name) -} - -func (conn *transaction) Now() (time.Time, error) { - return Now(conn) -} - -func (conn *transaction) Exec(query string, args ...any) error { - _, err := conn.tx.Exec(query, args...) - return WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) -} - -func (conn *transaction) Insert(table string, columValues sqldb.Values) error { - return Insert(conn, table, conn.parent.argFmt, columValues) -} - -func (conn *transaction) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return InsertUnique(conn, table, conn.parent.argFmt, values, onConflict) -} - -func (conn *transaction) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return InsertReturning(conn, table, conn.parent.argFmt, values, returning) -} - -func (conn *transaction) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.parent.argFmt, args) -} - -func (conn *transaction) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) - return sqldb.RowScannerWithError(err) - } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.parent.argFmt, args) -} - -func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { - rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) - return sqldb.RowsScannerWithError(err) - } - return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, conn.parent.argFmt, args) -} - -func (conn *transaction) IsTransaction() bool { - return true -} - -func (conn *transaction) TransactionOptions() (*sql.TxOptions, bool) { - return conn.opts, true -} - -func (conn *transaction) Begin(opts *sql.TxOptions) (sqldb.Connection, error) { - tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) - if err != nil { - return nil, err - } - return newTransaction(conn.parent, tx, opts), nil -} - -func (conn *transaction) Commit() error { - return conn.tx.Commit() -} - -func (conn *transaction) Rollback() error { - return conn.tx.Rollback() -} - -func (conn *transaction) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) -} - -func (conn *transaction) UnlistenChannel(channel string) (err error) { - return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) -} - -func (conn *transaction) IsListeningOnChannel(channel string) bool { - return false -} - -func (conn *transaction) Close() error { - return conn.Rollback() -} diff --git a/impl/upsert.go b/impl/upsert.go deleted file mode 100644 index 7aa92f9..0000000 --- a/impl/upsert.go +++ /dev/null @@ -1,63 +0,0 @@ -package impl - -import ( - "fmt" - "reflect" - "strings" - - sqldb "github.com/domonda/go-sqldb" - "golang.org/x/exp/slices" -) - -// UpsertStruct upserts a row to table using the exported fields -// of rowStruct which have a `db` tag that is not "-". -// Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. -// If restrictToColumns are provided, then only struct fields with a `db` tag -// matching any of the passed column names will be used. -// If inserting conflicts on pkColumn, then an update of the existing row is performed. -func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return fmt.Errorf("UpsertStruct to table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return fmt.Errorf("UpsertStruct to table %s: expected struct but got %T", table, rowStruct) - } - - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - if len(pkCols) == 0 { - return fmt.Errorf("UpsertStruct of table %s: %s has no mapped primary key field", table, v.Type()) - } - - var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) - b.WriteString(` ON CONFLICT(`) - for i, pkCol := range pkCols { - if i > 0 { - b.WriteByte(',') - } - fmt.Fprintf(&b, `"%s"`, columns[pkCol]) - } - - b.WriteString(`) DO UPDATE SET `) - first := true - for i := range columns { - if slices.Contains(pkCols, i) { - continue - } - if first { - first = false - } else { - b.WriteByte(',') - } - fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) - } - query := b.String() - - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) -} diff --git a/mockconn/connection.go b/mockconn/connection.go index de4457e..f42b4cd 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -8,29 +8,24 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) -var DefaultArgFmt = "$%d" +var DefaultParamPlaceholderFormatter = sqldb.NewParamPlaceholderFormatter("$%d", 1) func New(ctx context.Context, queryWriter io.Writer, rowsProvider RowsProvider) sqldb.Connection { return &connection{ - ctx: ctx, - queryWriter: queryWriter, - listening: newBoolMap(), - rowsProvider: rowsProvider, - structFieldNamer: sqldb.DefaultStructFieldMapping, - argFmt: DefaultArgFmt, + ctx: ctx, + queryWriter: queryWriter, + listening: newBoolMap(), + rowsProvider: rowsProvider, } } type connection struct { - ctx context.Context - queryWriter io.Writer - listening *boolMap - rowsProvider RowsProvider - structFieldNamer sqldb.StructFieldMapper - argFmt string + ctx context.Context + queryWriter io.Writer + listening *boolMap + rowsProvider RowsProvider } func (conn *connection) Context() context.Context { return conn.ctx } @@ -40,34 +35,21 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { return conn } return &connection{ - ctx: ctx, - queryWriter: conn.queryWriter, - listening: conn.listening, - rowsProvider: conn.rowsProvider, - structFieldNamer: conn.structFieldNamer, - argFmt: conn.argFmt, + ctx: ctx, + queryWriter: conn.queryWriter, + listening: conn.listening, + rowsProvider: conn.rowsProvider, } } -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - return &connection{ - ctx: conn.ctx, - queryWriter: conn.queryWriter, - listening: conn.listening, - rowsProvider: conn.rowsProvider, - structFieldNamer: namer, - argFmt: conn.argFmt, - } -} - -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - func (conn *connection) Stats() sql.DBStats { return sql.DBStats{} } +func (conn *connection) Ping(time.Duration) error { + return nil +} + func (conn *connection) Config() *sqldb.Config { return &sqldb.Config{Driver: "mockconn", Host: "localhost", Database: "mock"} } @@ -76,7 +58,11 @@ func (conn *connection) ValidateColumnName(name string) error { return validateColumnName(name) } -func (conn *connection) Ping(time.Duration) error { +func (*connection) ParamPlaceholder(index int) string { + return fmt.Sprintf("$%d", index+1) +} + +func (conn *connection) Err() error { return nil } @@ -91,70 +77,30 @@ func (conn *connection) Exec(query string, args ...any) error { return nil } -func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return impl.Insert(conn, table, conn.argFmt, columValues) -} - -func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return impl.InsertUnique(conn, table, conn.argFmt, values, onConflict) -} - -func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, conn.argFmt, values, returning) -} - -func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, conn.argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { +func (conn *connection) QueryRow(query string, args ...any) sqldb.Row { if conn.ctx.Err() != nil { - return sqldb.RowScannerWithError(conn.ctx.Err()) + return sqldb.RowWithError(conn.ctx.Err()) } if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, query) } if conn.rowsProvider == nil { - return sqldb.RowScannerWithError(nil) + return sqldb.RowWithError(nil) } - return conn.rowsProvider.QueryRow(conn.structFieldNamer, query, args...) + return conn.rowsProvider.QueryRow(query, args...) } -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *connection) QueryRows(query string, args ...any) sqldb.Rows { if conn.ctx.Err() != nil { - return sqldb.RowsScannerWithError(conn.ctx.Err()) + return sqldb.RowsWithError(conn.ctx.Err()) } if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, query) } if conn.rowsProvider == nil { - return sqldb.RowsScannerWithError(nil) + return sqldb.RowsWithError(nil) } - return conn.rowsProvider.QueryRows(conn.structFieldNamer, query, args...) + return conn.rowsProvider.QueryRows(query, args...) } func (conn *connection) IsTransaction() bool { diff --git a/mockconn/connection_test.go b/mockconn/connection_test.go index e99e034..0eb2366 100644 --- a/mockconn/connection_test.go +++ b/mockconn/connection_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/db" + "github.com/domonda/go-sqldb/reflection" "github.com/domonda/go-types/uu" ) @@ -18,7 +20,7 @@ type embed struct { } type testRow struct { - ID uu.ID `db:"id,pk"` + ID uu.ID `db:"id,pk=public.table"` Int int `db:"int"` embed Str string `db:"str"` @@ -32,10 +34,11 @@ type testRow struct { } func TestInsertQuery(t *testing.T) { - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + context.Background() + naming := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} queryOutput := bytes.NewBuffer(nil) rowProvider := NewSingleRowProvider(NewRow(struct{ True bool }{true}, naming)) - conn := New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming)) str := "Hello World!" values := sqldb.Values{ @@ -51,13 +54,13 @@ func TestInsertQuery(t *testing.T) { } expected := `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` - err := conn.Insert("public.table", values) + err := db.Insert(ctx, "public.table", values) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` - inserted, err := conn.InsertUnique("public.table", values, "id") + inserted, err := db.InsertUnique(ctx, "public.table", values, "id") assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) @@ -65,67 +68,67 @@ func TestInsertQuery(t *testing.T) { func TestInsertStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(testRow) expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` - err := conn.InsertStruct("public.table", row) + err := db.InsertStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES($1,$2,$3)` - err = conn.InsertStruct("public.table", row, sqldb.OnlyColumns("id", "untagged_field", "bools")) + err = db.InsertStruct(ctx, row, sqldb.OnlyColumns("id", "untagged_field", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES($1,$2,$3,$4,$5,$6)` - err = conn.InsertStruct("public.table", row, sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) + err = db.InsertStruct(ctx, row, sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestInsertUniqueStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } rowProvider := NewSingleRowProvider(NewRow(struct{ True bool }{true}, naming)) - conn := New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming)) row := new(testRow) expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` - inserted, err := conn.InsertUniqueStruct("public.table", row, "(id)") + inserted, err := db.InsertUniqueStruct(ctx, row, "(id)") assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES($1,$2,$3) ON CONFLICT (id, untagged_field) DO NOTHING RETURNING TRUE` - inserted, err = conn.InsertUniqueStruct("public.table", row, "(id, untagged_field)", sqldb.OnlyColumns("id", "untagged_field", "bools")) + inserted, err = db.InsertUniqueStruct(ctx, row, "(id, untagged_field)", sqldb.OnlyColumns("id", "untagged_field", "bools")) assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES($1,$2,$3,$4,$5,$6) ON CONFLICT (id) DO NOTHING RETURNING TRUE` - inserted, err = conn.InsertUniqueStruct("public.table", row, "(id)", sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) + inserted, err = db.InsertUniqueStruct(ctx, row, "(id)", sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) @@ -133,8 +136,8 @@ func TestInsertUniqueStructQuery(t *testing.T) { func TestUpdateQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + naming := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) str := "Hello World!" values := sqldb.Values{ @@ -149,21 +152,21 @@ func TestUpdateQuery(t *testing.T) { } expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1` - err := conn.Update("public.table", values, "id = $1", 1) + err := db.Update(ctx, "public.table", values, "id = $1", 1) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$3,"bools"=$4,"created_at"=$5,"int"=$6,"nil_ptr"=$7,"str"=$8,"str_ptr"=$9,"untagged_field"=$10 WHERE a = $1 AND b = $2` - err = conn.Update("public.table", values, "a = $1 AND b = $2", 1, 2) + err = db.Update(ctx, "public.table", values, "a = $1 AND b = $2", 1, 2) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestUpdateReturningQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + naming := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) str := "Hello World!" values := sqldb.Values{ @@ -178,72 +181,72 @@ func TestUpdateReturningQuery(t *testing.T) { } expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING *` - err := conn.UpdateReturningRow("public.table", values, "*", "id = $1", 1).Scan() + err := db.UpdateReturningRow(ctx, "public.table", values, "*", "id = $1", 1).Scan() assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING created_at,untagged_field` - err = conn.UpdateReturningRows("public.table", values, "created_at,untagged_field", "id = $1", 1, 2).ScanSlice(nil) + err = db.UpdateReturningRows(ctx, "public.table", values, "created_at,untagged_field", "id = $1", 1, 2).ScanSlice(nil) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestUpdateStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(testRow) expected := `UPDATE public.table SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9 WHERE "id"=$1` - err := conn.UpdateStruct("public.table", row) + err := db.UpdateStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$2,"str"=$3,"created_at"=$4 WHERE "id"=$1` - err = conn.UpdateStruct("public.table", row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) + err = db.UpdateStruct(ctx, row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "int"=$2,"bool"=$3,"str_ptr"=$4,"nil_ptr"=$5,"created_at"=$6 WHERE "id"=$1` - err = conn.UpdateStruct("public.table", row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) + err = db.UpdateStruct(ctx, row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestUpsertStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(testRow) expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + ` ON CONFLICT("id") DO UPDATE SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9` - err := conn.UpsertStruct("public.table", row) + err := db.UpsertStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } type multiPrimaryKeyRow struct { - FirstID string `db:"first_id,pk"` + FirstID string `db:"first_id,pk=public.multi_pk"` SecondID string `db:"second_id,pk"` ThirdID string `db:"third_id,pk"` @@ -252,40 +255,40 @@ type multiPrimaryKeyRow struct { func TestUpsertStructMultiPKQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(multiPrimaryKeyRow) expected := `INSERT INTO public.multi_pk("first_id","second_id","third_id","created_at") VALUES($1,$2,$3,$4) ON CONFLICT("first_id","second_id","third_id") DO UPDATE SET "created_at"=$4` - err := conn.UpsertStruct("public.multi_pk", row) + err := db.UpsertStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestUpdateStructMultiPKQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &reflection.TaggedStructFieldMapping{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", ReadOnly: "readonly", Default: "default", - UntaggedNameFunc: sqldb.ToSnakeCase, + UntaggedNameFunc: reflection.ToSnakeCase, } - conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming)) row := new(multiPrimaryKeyRow) expected := `UPDATE public.multi_pk SET "created_at"=$4 WHERE "first_id"=$1 AND "second_id"=$2 AND "third_id"=$3` - err := conn.UpdateStruct("public.multi_pk", row) + err := db.UpdateStruct(ctx, row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } diff --git a/mockconn/onetimerowsprovider.go b/mockconn/onetimerowsprovider.go index 51ca5ce..059f9ee 100644 --- a/mockconn/onetimerowsprovider.go +++ b/mockconn/onetimerowsprovider.go @@ -7,22 +7,23 @@ import ( "sync" sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) type OneTimeRowsProvider struct { - rowScanners map[string]sqldb.RowScanner - rowsScanners map[string]sqldb.RowsScanner + rowScanners map[string]sqldb.Row + rowsScanners map[string]sqldb.Rows mtx sync.Mutex } func NewOneTimeRowsProvider() *OneTimeRowsProvider { return &OneTimeRowsProvider{ - rowScanners: make(map[string]sqldb.RowScanner), - rowsScanners: make(map[string]sqldb.RowsScanner), + rowScanners: make(map[string]sqldb.Row), + rowsScanners: make(map[string]sqldb.Rows), } } -func (p *OneTimeRowsProvider) AddRowScannerQuery(scanner sqldb.RowScanner, query string, args ...any) { +func (p *OneTimeRowsProvider) AddRowQuery(scanner sqldb.Row, query string, args ...any) { p.mtx.Lock() defer p.mtx.Unlock() @@ -33,7 +34,7 @@ func (p *OneTimeRowsProvider) AddRowScannerQuery(scanner sqldb.RowScanner, query p.rowScanners[key] = scanner } -func (p *OneTimeRowsProvider) AddRowsScannerQuery(scanner sqldb.RowsScanner, query string, args ...any) { +func (p *OneTimeRowsProvider) AddRowsQuery(scanner sqldb.Rows, query string, args ...any) { p.mtx.Lock() defer p.mtx.Unlock() @@ -44,7 +45,7 @@ func (p *OneTimeRowsProvider) AddRowsScannerQuery(scanner sqldb.RowsScanner, que p.rowsScanners[key] = scanner } -func (p *OneTimeRowsProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner { +func (p *OneTimeRowsProvider) QueryRow(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.Row { p.mtx.Lock() defer p.mtx.Unlock() @@ -54,7 +55,7 @@ func (p *OneTimeRowsProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, return scanner } -func (p *OneTimeRowsProvider) QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { +func (p *OneTimeRowsProvider) QueryRows(structFieldMapper reflection.StructFieldMapper, query string, args ...any) sqldb.Rows { p.mtx.Lock() defer p.mtx.Unlock() diff --git a/mockconn/row.go b/mockconn/row.go index 928b47e..054d0fc 100644 --- a/mockconn/row.go +++ b/mockconn/row.go @@ -9,35 +9,35 @@ import ( "strconv" "time" - sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) // Row implements impl.Row with the fields of a struct as column values. type Row struct { - rowStructVal reflect.Value - columnNamer sqldb.StructFieldMapper + rowStructVal reflect.Value + structFieldMapper reflection.StructFieldMapper } -func NewRow(rowStruct any, columnNamer sqldb.StructFieldMapper) *Row { +func NewRow(rowStruct any, mapper reflection.StructFieldMapper) *Row { val := reflect.ValueOf(rowStruct) for val.Kind() == reflect.Ptr { val = val.Elem() } return &Row{ - rowStructVal: val, - columnNamer: columnNamer, + rowStructVal: val, + structFieldMapper: mapper, } } -func (r *Row) StructFieldMapper() sqldb.StructFieldMapper { - return r.columnNamer +func (r *Row) StructFieldMapper() reflection.StructFieldMapper { + return r.structFieldMapper } func (r *Row) Columns() ([]string, error) { columns := make([]string, r.rowStructVal.NumField()) for i := range columns { field := r.rowStructVal.Type().Field(i) - _, columns[i], _, _ = r.columnNamer.MapStructField(field) + _, columns[i], _, _ = r.structFieldMapper.MapStructField(field) } return columns, nil } diff --git a/mockconn/row_test.go b/mockconn/row_test.go index f63ec70..ee0c8bc 100644 --- a/mockconn/row_test.go +++ b/mockconn/row_test.go @@ -3,10 +3,9 @@ package mockconn import ( "testing" + "github.com/domonda/go-sqldb/reflection" "github.com/lib/pq" "github.com/stretchr/testify/assert" - - sqldb "github.com/domonda/go-sqldb" ) func TestRow(t *testing.T) { @@ -21,8 +20,8 @@ func TestRow(t *testing.T) { str := "Hello World!" input := Struct{"myID", 66, -1, &str, nil, pq.BoolArray{true, false}} - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} - row := NewRow(input, naming) + mapping := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} + row := NewRow(input, mapping) cols, err := row.Columns() assert.NoError(t, err) diff --git a/mockconn/rows.go b/mockconn/rows.go index 8698fe5..14da2e0 100644 --- a/mockconn/rows.go +++ b/mockconn/rows.go @@ -4,7 +4,7 @@ import ( "errors" "reflect" - sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/reflection" ) type Rows struct { @@ -14,7 +14,7 @@ type Rows struct { err error } -func NewRowsFromStructs(rowStructs any, columnNamer sqldb.StructFieldMapper) *Rows { +func NewRowsFromStructs(rowStructs any, columnNamer reflection.StructFieldMapper) *Rows { v := reflect.ValueOf(rowStructs) t := v.Type() if t.Kind() != reflect.Array && t.Kind() != reflect.Slice { diff --git a/mockconn/rows_test.go b/mockconn/rows_test.go index 608cbb1..f64ec9f 100644 --- a/mockconn/rows_test.go +++ b/mockconn/rows_test.go @@ -4,10 +4,9 @@ import ( "fmt" "testing" + "github.com/domonda/go-sqldb/reflection" "github.com/lib/pq" "github.com/stretchr/testify/assert" - - sqldb "github.com/domonda/go-sqldb" ) func TestRows(t *testing.T) { @@ -26,8 +25,8 @@ func TestRows(t *testing.T) { input = append(input, &Struct{"myID", i, -1, &str, nil, pq.BoolArray{true, false, i%2 == 0}}) } - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} - rows := NewRowsFromStructs(input, naming) + mapping := &reflection.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: reflection.ToSnakeCase} + rows := NewRowsFromStructs(input, mapping) cols, err := rows.Columns() assert.NoError(t, err) diff --git a/mockconn/rowsprovider.go b/mockconn/rowsprovider.go index 9cf8394..74526dc 100644 --- a/mockconn/rowsprovider.go +++ b/mockconn/rowsprovider.go @@ -5,6 +5,6 @@ import ( ) type RowsProvider interface { - QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner - QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner + QueryRow(query string, args ...any) sqldb.Row + QueryRows(query string, args ...any) sqldb.Rows } diff --git a/mockconn/singlerowprovider.go b/mockconn/singlerowprovider.go index 8c8e39e..f1c9ff1 100644 --- a/mockconn/singlerowprovider.go +++ b/mockconn/singlerowprovider.go @@ -4,26 +4,25 @@ import ( "context" sqldb "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) // NewSingleRowProvider a RowsProvider implementation // with a single row that will be re-used for every query. func NewSingleRowProvider(row *Row) RowsProvider { - return &singleRowProvider{row: row, argFmt: DefaultArgFmt} + return &singleRowProvider{row: row, argFmt: DefaultParamPlaceholderFormatter} } // SingleRowProvider implements RowsProvider with a single Row // that will be re-used for every query. type singleRowProvider struct { row *Row - argFmt string + argFmt sqldb.ParamPlaceholderFormatter } -func (p *singleRowProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner { - return impl.NewRowScanner(impl.RowAsRows(p.row), structFieldNamer, query, p.argFmt, args) +func (p *singleRowProvider) QueryRow(query string, args ...any) sqldb.Row { + return sqldb.NewRow(context.Background(), sqldb.RowAsRows(p.row), query, p.argFmt, args) } -func (p *singleRowProvider) QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { - return impl.NewRowsScanner(context.Background(), NewRows(p.row), structFieldNamer, query, p.argFmt, args) +func (p *singleRowProvider) QueryRows(query string, args ...any) sqldb.Rows { + return sqldb.NewRows(context.Background(), sqldb.NewRows(p.row), query, p.argFmt, args) } diff --git a/mysqlconn/config.go b/mysqlconn/config.go index 3622647..18b48f5 100644 --- a/mysqlconn/config.go +++ b/mysqlconn/config.go @@ -2,8 +2,6 @@ package mysqlconn import "github.com/go-sql-driver/mysql" -const argFmt = "?" - type Config = mysql.Config // NewConfig creates a new Config and sets default values. diff --git a/mysqlconn/connection.go b/mysqlconn/connection.go index 46be636..e78ebc9 100644 --- a/mysqlconn/connection.go +++ b/mysqlconn/connection.go @@ -4,9 +4,9 @@ import ( "context" "database/sql" "fmt" + "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) // New creates a new sqldb.Connection using the passed sqldb.Config @@ -23,7 +23,12 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { if err != nil { return nil, err } - return impl.Connection(ctx, db, config, validateColumnName, argFmt), nil + conn := &connection{ + ctx: ctx, + db: db, + config: config, + } + return conn, nil } // MustNew creates a new sqldb.Connection using the passed sqldb.Config @@ -38,3 +43,126 @@ func MustNew(ctx context.Context, config *sqldb.Config) sqldb.Connection { } return conn } + +type connection struct { + ctx context.Context + db *sql.DB + config *sqldb.Config +} + +func (conn *connection) clone() *connection { + c := *conn + return &c +} + +func (conn *connection) Context() context.Context { return conn.ctx } + +func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { + if ctx == conn.ctx { + return conn + } + c := conn.clone() + c.ctx = ctx + return c +} + +func (conn *connection) Ping(timeout time.Duration) error { + ctx := conn.ctx + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + return conn.db.PingContext(ctx) +} + +func (conn *connection) Stats() sql.DBStats { + return conn.db.Stats() +} + +func (conn *connection) Config() *sqldb.Config { + return conn.config +} + +func (conn *connection) ValidateColumnName(name string) error { + return validateColumnName(name) +} + +func (conn *connection) ParamPlaceholder(index int) string { + return "?" +} + +func (conn *connection) Err() error { + return nil +} + +func (conn *connection) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil +} + +func (conn *connection) Exec(query string, args ...any) error { + _, err := conn.db.ExecContext(conn.ctx, query, args...) + return sqldb.WrapErrorWithQuery(err, query, conn, args) +} + +func (conn *connection) QueryRow(query string, args ...any) sqldb.Row { + rows, err := conn.db.QueryContext(conn.ctx, query, args...) + if err != nil { + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowWithError(err) + } + return sqldb.NewRow(conn.ctx, rows, query, conn, args) +} + +func (conn *connection) QueryRows(query string, args ...any) sqldb.Rows { + rows, err := conn.db.QueryContext(conn.ctx, query, args...) + if err != nil { + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowsWithError(err) + } + return sqldb.NewRows(conn.ctx, rows, query, conn, args) +} + +func (conn *connection) IsTransaction() bool { + return false +} + +func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { + return nil, false +} + +func (conn *connection) Begin(opts *sql.TxOptions) (sqldb.Connection, error) { + tx, err := conn.db.BeginTx(conn.ctx, opts) + if err != nil { + return nil, err + } + return newTransaction(conn, tx, opts), nil +} + +func (conn *connection) Commit() error { + return sqldb.ErrNotWithinTransaction +} + +func (conn *connection) Rollback() error { + return sqldb.ErrNotWithinTransaction +} + +func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { + return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) +} + +func (conn *connection) UnlistenChannel(channel string) (err error) { + return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) +} + +func (conn *connection) IsListeningOnChannel(channel string) bool { + return false +} + +func (conn *connection) Close() error { + return conn.db.Close() +} diff --git a/mysqlconn/transaction.go b/mysqlconn/transaction.go new file mode 100644 index 0000000..281a4d9 --- /dev/null +++ b/mysqlconn/transaction.go @@ -0,0 +1,129 @@ +package mysqlconn + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/domonda/go-sqldb" +) + +type transaction struct { + // The parent non-transaction connection is needed + // for its ctx, Ping(), Stats(), and Config() + parent *connection + tx *sql.Tx + opts *sql.TxOptions +} + +func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { + return &transaction{ + parent: parent, + tx: tx, + opts: opts, + } +} + +func (conn *transaction) clone() *transaction { + c := *conn + return &c +} + +func (conn *transaction) Context() context.Context { return conn.parent.ctx } + +func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { + if ctx == conn.parent.ctx { + return conn + } + parent := conn.parent.clone() + parent.ctx = ctx + return newTransaction(parent, conn.tx, conn.opts) +} + +func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } +func (conn *transaction) Stats() sql.DBStats { return conn.parent.Stats() } +func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() } + +func (conn *transaction) ValidateColumnName(name string) error { + return validateColumnName(name) +} + +func (conn *transaction) ParamPlaceholder(index int) string { + return conn.parent.ParamPlaceholder(index) +} + +func (conn *transaction) Err() error { + return conn.parent.Err() +} + +func (conn *transaction) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil +} + +func (conn *transaction) Exec(query string, args ...any) error { + _, err := conn.tx.Exec(query, args...) + return sqldb.WrapErrorWithQuery(err, query, conn, args) +} + +func (conn *transaction) QueryRow(query string, args ...any) sqldb.Row { + rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) + if err != nil { + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowWithError(err) + } + return sqldb.NewRow(conn.parent.ctx, rows, query, conn, args) +} + +func (conn *transaction) QueryRows(query string, args ...any) sqldb.Rows { + rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) + if err != nil { + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowsWithError(err) + } + return sqldb.NewRows(conn.parent.ctx, rows, query, conn, args) +} + +func (conn *transaction) IsTransaction() bool { + return true +} + +func (conn *transaction) TransactionOptions() (*sql.TxOptions, bool) { + return conn.opts, true +} + +func (conn *transaction) Begin(opts *sql.TxOptions) (sqldb.Connection, error) { + tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) + if err != nil { + return nil, err + } + return newTransaction(conn.parent, tx, opts), nil +} + +func (conn *transaction) Commit() error { + return conn.tx.Commit() +} + +func (conn *transaction) Rollback() error { + return conn.tx.Rollback() +} + +func (conn *transaction) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { + return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) +} + +func (conn *transaction) UnlistenChannel(channel string) (err error) { + return fmt.Errorf("notifications %w", sqldb.ErrNotSupported) +} + +func (conn *transaction) IsListeningOnChannel(channel string) bool { + return false +} + +func (conn *transaction) Close() error { + return conn.Rollback() +} diff --git a/paramplaceholder.go b/paramplaceholder.go new file mode 100644 index 0000000..feecdb6 --- /dev/null +++ b/paramplaceholder.go @@ -0,0 +1,23 @@ +package sqldb + +import "fmt" + +type ParamPlaceholderFormatter interface { + // ParamPlaceholder returns a parameter value placeholder + // for the parameter with the passed zero based index + // specific to the database type of the connection. + ParamPlaceholder(index int) string +} + +func NewParamPlaceholderFormatter(format string, indexOffset int) ParamPlaceholderFormatter { + return ¶mPlaceholderFormatter{format, indexOffset} +} + +type paramPlaceholderFormatter struct { + format string + indexOffset int +} + +func (f *paramPlaceholderFormatter) ParamPlaceholder(index int) string { + return fmt.Sprintf(f.format, index+f.indexOffset) +} diff --git a/pqconn/connection.go b/pqconn/connection.go index 65c01a8..2f610f8 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -7,11 +7,8 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) -const argFmt = "$%d" - // New creates a new sqldb.Connection using the passed sqldb.Config // and github.com/lib/pq as driver implementation. // The connection is pinged with the passed context @@ -27,10 +24,9 @@ func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { return nil, err } return &connection{ - ctx: ctx, - db: db, - config: config, - structFieldNamer: sqldb.DefaultStructFieldMapping, + ctx: ctx, + db: db, + config: config, }, nil } @@ -48,10 +44,9 @@ func MustNew(ctx context.Context, config *sqldb.Config) sqldb.Connection { } type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldNamer sqldb.StructFieldMapper + ctx context.Context + db *sql.DB + config *sqldb.Config } func (conn *connection) clone() *connection { @@ -70,16 +65,6 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { return c } -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - func (conn *connection) Ping(timeout time.Duration) error { ctx := conn.ctx if timeout > 0 { @@ -102,71 +87,43 @@ func (conn *connection) ValidateColumnName(name string) error { return validateColumnName(name) } -func (conn *connection) Now() (time.Time, error) { - return impl.Now(conn) -} - -func (conn *connection) Exec(query string, args ...any) error { - _, err := conn.db.ExecContext(conn.ctx, query, args...) - return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) -} - -func (conn *connection) Insert(table string, columValues sqldb.Values) error { - return impl.Insert(conn, table, argFmt, columValues) -} - -func (conn *connection) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return impl.InsertUnique(conn, table, argFmt, values, onConflict) +func (*connection) ParamPlaceholder(index int) string { + return fmt.Sprintf("$%d", index+1) } -func (conn *connection) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, argFmt, values, returning) +func (conn *connection) Err() error { + return conn.config.Err } -func (conn *connection) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *connection) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) +func (conn *connection) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil } -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) +func (conn *connection) Exec(query string, args ...any) error { + _, err := conn.db.ExecContext(conn.ctx, query, args...) + return sqldb.WrapErrorWithQuery(err, query, conn, args) } -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { +func (conn *connection) QueryRow(query string, args ...any) sqldb.Row { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) - return sqldb.RowScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowWithError(err) } - return impl.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRow(conn.ctx, rows, query, conn, args) } -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *connection) QueryRows(query string, args ...any) sqldb.Rows { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) - return sqldb.RowsScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowsWithError(err) } - return impl.NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRows(conn.ctx, rows, query, conn, args) } func (conn *connection) IsTransaction() bool { diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 4019f5e..dc01192 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -6,24 +6,21 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) type transaction struct { // The parent non-transaction connection is needed // for its ctx, Ping(), Stats(), and Config() - parent *connection - tx *sql.Tx - opts *sql.TxOptions - structFieldNamer sqldb.StructFieldMapper + parent *connection + tx *sql.Tx + opts *sql.TxOptions } func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions) *transaction { return &transaction{ - parent: parent, - tx: tx, - opts: opts, - structFieldNamer: parent.structFieldNamer, + parent: parent, + tx: tx, + opts: opts, } } @@ -43,16 +40,6 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { return newTransaction(parent, conn.tx, conn.opts) } -func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } func (conn *transaction) Stats() sql.DBStats { return conn.parent.Stats() } func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() } @@ -61,71 +48,43 @@ func (conn *transaction) ValidateColumnName(name string) error { return validateColumnName(name) } -func (conn *transaction) Now() (time.Time, error) { - return impl.Now(conn) -} - -func (conn *transaction) Exec(query string, args ...any) error { - _, err := conn.tx.Exec(query, args...) - return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) -} - -func (conn *transaction) Insert(table string, columValues sqldb.Values) error { - return impl.Insert(conn, table, argFmt, columValues) -} - -func (conn *transaction) InsertUnique(table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - return impl.InsertUnique(conn, table, argFmt, values, onConflict) -} - -func (conn *transaction) InsertReturning(table string, values sqldb.Values, returning string) sqldb.RowScanner { - return impl.InsertReturning(conn, table, argFmt, values, returning) -} - -func (conn *transaction) InsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.InsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) InsertUniqueStruct(table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - return impl.InsertUniqueStruct(conn, table, rowStruct, onConflict, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, argFmt, args) -} - -func (conn *transaction) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) +func (conn *transaction) ParamPlaceholder(index int) string { + return conn.parent.ParamPlaceholder(index) } -func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) +func (conn *transaction) Err() error { + return conn.parent.config.Err } -func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) +func (conn *transaction) Now() (now time.Time, err error) { + err = conn.QueryRow(`select now()`).Scan(&now) + if err != nil { + return time.Time{}, err + } + return now, nil } -func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) +func (conn *transaction) Exec(query string, args ...any) error { + _, err := conn.tx.ExecContext(conn.parent.ctx, query, args...) + return sqldb.WrapErrorWithQuery(err, query, conn, args) } -func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { +func (conn *transaction) QueryRow(query string, args ...any) sqldb.Row { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) - return sqldb.RowScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowWithError(err) } - return impl.NewRowScanner(rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRow(conn.parent.ctx, rows, query, conn, args) } -func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *transaction) QueryRows(query string, args ...any) sqldb.Rows { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) - return sqldb.RowsScannerWithError(err) + err = sqldb.WrapErrorWithQuery(err, query, conn, args) + return sqldb.RowsWithError(err) } - return impl.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, argFmt, args) + return sqldb.NewRows(conn.parent.ctx, rows, query, conn, args) } func (conn *transaction) IsTransaction() bool { diff --git a/reflection/columnfilter.go b/reflection/columnfilter.go new file mode 100644 index 0000000..112b271 --- /dev/null +++ b/reflection/columnfilter.go @@ -0,0 +1,9 @@ +package reflection + +import ( + "reflect" +) + +type ColumnFilter interface { + IgnoreColumn(*StructColumn, reflect.Value) bool +} diff --git a/impl/foreachrow.go b/reflection/foreachrow.go similarity index 72% rename from impl/foreachrow.go rename to reflection/foreachrow.go index 2be2b6b..b9448f5 100644 --- a/impl/foreachrow.go +++ b/reflection/foreachrow.go @@ -1,4 +1,4 @@ -package impl +package reflection import ( "context" @@ -6,8 +6,6 @@ import ( "fmt" "reflect" "time" - - sqldb "github.com/domonda/go-sqldb" ) var ( @@ -27,17 +25,17 @@ var ( // If a non nil error is returned from the callback, then this error // is returned immediately by this function without scanning further rows. // In case of zero rows, no error will be returned. -func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScanner) error, err error) { +func ForEachRowCallFunc(ctx context.Context, mapper StructFieldMapper, callback any) (f func(Row) error, err error) { val := reflect.ValueOf(callback) typ := val.Type() if typ.Kind() != reflect.Func { - return nil, fmt.Errorf("ForEachRowCall expected callback function, got %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc expected callback function, got %s", typ) } if typ.IsVariadic() { - return nil, fmt.Errorf("ForEachRowCall callback function must not be varidic: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function must not be varidic: %s", typ) } if typ.NumIn() == 0 || (typ.NumIn() == 1 && typ.In(0) == typeOfContext) { - return nil, fmt.Errorf("ForEachRowCall callback function has no arguments: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function has no arguments: %s", typ) } firstArg := 0 if typ.In(0) == typeOfContext { @@ -58,28 +56,28 @@ func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScan continue } if structArg { - return nil, fmt.Errorf("ForEachRowCall callback function must not have further argument after struct: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function must not have further argument after struct: %s", typ) } structArg = true case reflect.Chan, reflect.Func: - return nil, fmt.Errorf("ForEachRowCall callback function has invalid argument type: %s", typ.In(i)) + return nil, fmt.Errorf("ForEachRowCallFunc callback function has invalid argument type: %s", typ.In(i)) } } if typ.NumOut() > 1 { - return nil, fmt.Errorf("ForEachRowCall callback function can only have one result value: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function can only have one result value: %s", typ) } if typ.NumOut() == 1 && typ.Out(0) != typeOfError { - return nil, fmt.Errorf("ForEachRowCall callback function result must be of type error: %s", typ) + return nil, fmt.Errorf("ForEachRowCallFunc callback function result must be of type error: %s", typ) } - f = func(row sqldb.RowScanner) (err error) { + f = func(row Row) (err error) { // First scan row scannedValPtrs := make([]any, typ.NumIn()-firstArg) for i := range scannedValPtrs { scannedValPtrs[i] = reflect.New(typ.In(firstArg + i)).Interface() } if structArg { - err = row.ScanStruct(scannedValPtrs[0]) + err = ScanStruct(row, scannedValPtrs[0], mapper) } else { err = row.Scan(scannedValPtrs...) } diff --git a/impl/foreachrow_test.go b/reflection/foreachrow_test.go similarity index 97% rename from impl/foreachrow_test.go rename to reflection/foreachrow_test.go index 7509553..395b7ab 100644 --- a/impl/foreachrow_test.go +++ b/reflection/foreachrow_test.go @@ -1,4 +1,4 @@ -package impl +package reflection import ( "testing" diff --git a/reflection/reflectstruct.go b/reflection/reflectstruct.go new file mode 100644 index 0000000..5a2868d --- /dev/null +++ b/reflection/reflectstruct.go @@ -0,0 +1,131 @@ +package reflection + +// import ( +// "errors" +// "fmt" +// "reflect" +// "strings" + +// "golang.org/x/exp/slices" +// ) + +// func ReflectStructValues(structVal reflect.Value, mapper StructFieldMapper, ignoreColumns []ColumnFilter) (table string, columns []string, pkCols []int, values []any, err error) { +// structType := structVal.Type() +// for i := 0; i < structType.NumField(); i++ { +// fieldType := structType.Field(i) +// fieldTable, column, flags, use := mapper.MapStructField(fieldType) +// if !use { +// continue +// } +// fieldValue := structVal.Field(i) + +// if column == "" { +// // Embedded struct field +// fieldTable, columnsEmbed, pkColsEmbed, valuesEmbed, err := ReflectStructValues(fieldValue, mapper, ignoreColumns) +// if err != nil { +// return "", nil, nil, nil, err +// } +// if fieldTable != "" && fieldTable != table { +// if table != "" { +// return "", nil, nil, nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, structType) +// } +// table = fieldTable +// } +// for _, pkCol := range pkColsEmbed { +// pkCols = append(pkCols, pkCol+len(columns)) +// } +// columns = append(columns, columnsEmbed...) +// values = append(values, valuesEmbed...) +// continue +// } + +// if ignoreColumn(ignoreColumns, column, flags, fieldType, fieldValue) { +// continue +// } + +// if fieldTable != "" && fieldTable != table { +// if table != "" { +// return "", nil, nil, nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, structType) +// } +// table = fieldTable +// } +// if flags.PrimaryKey() { +// pkCols = append(pkCols, len(columns)) +// } +// columns = append(columns, column) +// values = append(values, fieldValue.Interface()) +// } +// return table, columns, pkCols, values, nil +// } + +// func ReflectStructColumnPointers(structVal reflect.Value, mapper StructFieldMapper, columns []string) (pointers []any, err error) { +// if len(columns) == 0 { +// return nil, errors.New("no columns") +// } +// pointers = make([]any, len(columns)) +// err = reflectStructColumnPointers(structVal, mapper, columns, pointers) +// if err != nil { +// return nil, err +// } +// for _, ptr := range pointers { +// if ptr != nil { +// continue +// } +// nilCols := new(strings.Builder) +// for i, ptr := range pointers { +// if ptr != nil { +// continue +// } +// if nilCols.Len() > 0 { +// nilCols.WriteString(", ") +// } +// fmt.Fprintf(nilCols, "column=%s, index=%d", columns[i], i) +// } +// return nil, fmt.Errorf("columns have no mapped struct fields in %s: %s", structVal.Type(), nilCols) +// } +// return pointers, nil +// } + +// func reflectStructColumnPointers(structVal reflect.Value, mapper StructFieldMapper, columns []string, pointers []any) error { +// var ( +// structType = structVal.Type() +// ) +// for i := 0; i < structType.NumField(); i++ { +// fieldType := structType.Field(i) +// _, column, _, use := mapper.MapStructField(fieldType) +// if !use { +// continue +// } +// fieldValue := structVal.Field(i) + +// if column == "" { +// // Embedded struct field +// err := reflectStructColumnPointers(fieldValue, mapper, columns, pointers) +// if err != nil { +// return err +// } +// continue +// } + +// colIndex := slices.Index(columns, column) +// if colIndex == -1 { +// continue +// } + +// if pointers[colIndex] != nil { +// return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, fieldType.Name, structType) +// } + +// pointers[colIndex] = fieldValue.Addr().Interface() +// } +// return nil +// } + +// func ignoreColumn(filters []ColumnFilter, name string, flags StructFieldFlags, fieldType reflect.StructField, fieldValue reflect.Value) bool { +// for _, filter := range filters { +// if filter.IgnoreColumn(name, flags, fieldType, fieldValue) { +// return true +// } +// } +// return false +// } diff --git a/impl/row.go b/reflection/row.go similarity index 95% rename from impl/row.go rename to reflection/row.go index a3b34c7..ba394e0 100644 --- a/impl/row.go +++ b/reflection/row.go @@ -1,4 +1,4 @@ -package impl +package reflection // Row is an interface with the methods of sql.Rows // that are needed for ScanStruct. diff --git a/impl/rows.go b/reflection/rows.go similarity index 70% rename from impl/rows.go rename to reflection/rows.go index 84ab136..2133ee3 100644 --- a/impl/rows.go +++ b/reflection/rows.go @@ -1,4 +1,4 @@ -package impl +package reflection // Rows is an interface with the methods of sql.Rows // that are needed for ScanSlice. @@ -24,18 +24,3 @@ type Rows interface { // Err may be called after an explicit or implicit Close. Err() error } - -// RowAsRows implements the methods of Rows for a Row as no-ops. -// Note that Next() always returns true leading to an endless loop -// if used to scan multiple rows. -func RowAsRows(row Row) Rows { - return rowAsRows{Row: row} -} - -type rowAsRows struct { - Row -} - -func (rowAsRows) Close() error { return nil } -func (rowAsRows) Next() bool { return true } -func (rowAsRows) Err() error { return nil } diff --git a/reflection/scan.go b/reflection/scan.go new file mode 100644 index 0000000..d8d5f89 --- /dev/null +++ b/reflection/scan.go @@ -0,0 +1,91 @@ +package reflection + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "time" +) + +func ScanValue(src driver.Value, dest reflect.Value) error { + if dest.Kind() == reflect.Interface { + if src != nil { + dest.Set(reflect.ValueOf(src)) + } else { + dest.Set(reflect.Zero(dest.Type())) + } + return nil + } + + if dest.Addr().Type().Implements(typeOfSQLScanner) { + return dest.Addr().Interface().(sql.Scanner).Scan(src) + } + + switch x := src.(type) { + case int64: + switch dest.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dest.SetInt(x) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dest.SetUint(uint64(x)) + return nil + case reflect.Float32, reflect.Float64: + dest.SetFloat(float64(x)) + return nil + } + + case float64: + switch dest.Kind() { + case reflect.Float32, reflect.Float64: + dest.SetFloat(x) + return nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dest.SetInt(int64(x)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dest.SetUint(uint64(x)) + return nil + } + + case bool: + dest.SetBool(x) + return nil + + case []byte: + switch { + case dest.Kind() == reflect.String: + dest.SetString(string(x)) + return nil + case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8: + dest.Set(reflect.ValueOf(x)) + return nil + } + + case string: + switch { + case dest.Kind() == reflect.String: + dest.SetString(x) + return nil + case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8: + dest.Set(reflect.ValueOf([]byte(x))) + return nil + } + + case time.Time: + if srcVal := reflect.ValueOf(src); srcVal.Type().AssignableTo(dest.Type()) { + dest.Set(srcVal) + return nil + } + + case nil: + switch dest.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dest.Set(reflect.Zero(dest.Type())) + return nil + } + } + + return fmt.Errorf("can't scan %#v as %s", src, dest.Type()) +} diff --git a/impl/scanstruct_test.go b/reflection/scan_test.go similarity index 98% rename from impl/scanstruct_test.go rename to reflection/scan_test.go index 787e76e..d81e484 100644 --- a/impl/scanstruct_test.go +++ b/reflection/scan_test.go @@ -1,4 +1,4 @@ -package impl +package reflection // func TestGetStructFieldIndices(t *testing.T) { // type DeepEmbeddedStruct struct { diff --git a/impl/scanslice.go b/reflection/scanslice.go similarity index 55% rename from impl/scanslice.go rename to reflection/scanslice.go index 1aead27..f15cd41 100644 --- a/impl/scanslice.go +++ b/reflection/scanslice.go @@ -1,4 +1,4 @@ -package impl +package reflection import ( "context" @@ -6,19 +6,72 @@ import ( "errors" "fmt" "reflect" - "time" - sqldb "github.com/domonda/go-sqldb" "github.com/domonda/go-types/nullable" ) +// // TODO doc +// // ScanSlice scans one value per row into one slice element of dest. +// // dest must be a pointer to a slice with a row value compatible element type. +// // In case of zero rows, dest will be set to nil and no error will be returned. +// // In case of an error, dest will not be modified. +// // It is an error to query more than one column.func (s *rowsScanner) ScanSlice(dest any) error { +// err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, nil) +// if err != nil { +// return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) +// } +// return nil +// } +// // ScanStructSlice scans every row into the struct fields of dest slice elements. +// // dest must be a pointer to a slice of structs or struct pointers. +// // In case of zero rows, dest will be set to nil and no error will be returned. +// // In case of an error, dest will not be modified. +// // Every mapped struct field must have a corresponding column in the query results. +// func (s *rowsScanner) ScanStructSlice(dest any) error { +// err := reflection.ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldMapper) +// if err != nil { +// return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) +// } +// return nil +// } + +// // ScanAllRowsAsStrings scans the values of all rows as strings. +// // Byte slices will be interpreted as strings, +// // nil (SQL NULL) will be converted to an empty string, +// // all other types are converted with fmt.Sprint. +// // If true is passed for headerRow, then a row +// // with the column names will be prepended. +// func (s *rowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { +// cols, err := s.rows.Columns() +// if err != nil { +// return nil, err +// } +// if headerRow { +// rows = [][]string{cols} +// } +// stringScannablePtrs := make([]any, len(cols)) +// err = s.ForEachRow(func(rowScanner RowScanner) error { +// row := make([]string, len(cols)) +// for i := range stringScannablePtrs { +// stringScannablePtrs[i] = (*StringScannable)(&row[i]) +// } +// err := rowScanner.Scan(stringScannablePtrs...) +// if err != nil { +// return err +// } +// rows = append(rows, row) +// return nil +// }) +// return rows, err +// } + // ScanRowsAsSlice scans all srcRows as slice into dest. // The rows must either have only one column compatible with the element type of the slice, // or if multiple columns are returned then the slice element type must me a struct or struction pointer -// so that every column maps on exactly one struct field using structFieldNamer. -// In case of single column rows, nil must be passed for structFieldNamer. +// so that every column maps on exactly one struct field using structFieldMapper. +// In case of single column rows, nil must be passed for structFieldMapper. // ScanRowsAsSlice calls srcRows.Close(). -func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNamer sqldb.StructFieldMapper) error { +func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldMapper StructFieldMapper) error { defer srcRows.Close() destVal := reflect.ValueOf(dest) @@ -43,8 +96,8 @@ func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNam newSlice = reflect.Append(newSlice, reflect.Zero(sliceElemType)) target := newSlice.Index(newSlice.Len() - 1).Addr() - if structFieldNamer != nil { - err := ScanStruct(srcRows, target.Interface(), structFieldNamer) + if structFieldMapper != nil { + err := ScanStruct(srcRows, target.Interface(), structFieldMapper) if err != nil { return err } @@ -122,85 +175,3 @@ func (a *SliceScanner) scanString(src string) error { a.destSlice.Set(newSlice) return nil } - -func ScanValue(src any, dest reflect.Value) error { - if dest.Kind() == reflect.Interface { - if src != nil { - dest.Set(reflect.ValueOf(src)) - } else { - dest.Set(reflect.Zero(dest.Type())) - } - return nil - } - - if dest.Addr().Type().Implements(typeOfSQLScanner) { - return dest.Addr().Interface().(sql.Scanner).Scan(src) - } - - switch x := src.(type) { - case int64: - switch dest.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - dest.SetInt(x) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - dest.SetUint(uint64(x)) - return nil - case reflect.Float32, reflect.Float64: - dest.SetFloat(float64(x)) - return nil - } - - case float64: - switch dest.Kind() { - case reflect.Float32, reflect.Float64: - dest.SetFloat(x) - return nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - dest.SetInt(int64(x)) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - dest.SetUint(uint64(x)) - return nil - } - - case bool: - dest.SetBool(x) - return nil - - case []byte: - switch { - case dest.Kind() == reflect.String: - dest.SetString(string(x)) - return nil - case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8: - dest.Set(reflect.ValueOf(x)) - return nil - } - - case string: - switch { - case dest.Kind() == reflect.String: - dest.SetString(x) - return nil - case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8: - dest.Set(reflect.ValueOf([]byte(x))) - return nil - } - - case time.Time: - if srcVal := reflect.ValueOf(src); srcVal.Type().AssignableTo(dest.Type()) { - dest.Set(srcVal) - return nil - } - - case nil: - switch dest.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Map: - dest.Set(reflect.Zero(dest.Type())) - return nil - } - } - - return fmt.Errorf("can't scan %#v as %s", src, dest.Type()) -} diff --git a/reflection/structfieldflags.go b/reflection/structfieldflags.go new file mode 100644 index 0000000..71878a1 --- /dev/null +++ b/reflection/structfieldflags.go @@ -0,0 +1,25 @@ +package reflection + +// StructFieldFlags is a bitmask for special properties +// of how struct fields relate to database columns. +type StructFieldFlags uint + +// PrimaryKey indicates if FlagPrimaryKey is set +func (f StructFieldFlags) PrimaryKey() bool { return f&FlagPrimaryKey != 0 } + +// ReadOnly indicates if FlagReadOnly is set +func (f StructFieldFlags) ReadOnly() bool { return f&FlagReadOnly != 0 } + +// HasDefault indicates if FlagHasDefault is set +func (f StructFieldFlags) HasDefault() bool { return f&FlagHasDefault != 0 } + +const ( + // FlagPrimaryKey marks a field as primary key + FlagPrimaryKey StructFieldFlags = 1 << iota + + // FlagReadOnly marks a field as read-only + FlagReadOnly + + // FlagHasDefault marks a field as having a column default value + FlagHasDefault +) diff --git a/structfieldmapping.go b/reflection/structfieldmapping.go similarity index 74% rename from structfieldmapping.go rename to reflection/structfieldmapping.go index fde2ca9..c3a9b0f 100644 --- a/structfieldmapping.go +++ b/reflection/structfieldmapping.go @@ -1,4 +1,4 @@ -package sqldb +package reflection import ( "fmt" @@ -7,30 +7,6 @@ import ( "unicode" ) -// FieldFlag is a bitmask for special properties -// of how struct fields relate to database columns. -type FieldFlag uint - -// PrimaryKey indicates if FieldFlagPrimaryKey is set -func (f FieldFlag) PrimaryKey() bool { return f&FieldFlagPrimaryKey != 0 } - -// ReadOnly indicates if FieldFlagReadOnly is set -func (f FieldFlag) ReadOnly() bool { return f&FieldFlagReadOnly != 0 } - -// Default indicates if FieldFlagDefault is set -func (f FieldFlag) Default() bool { return f&FieldFlagDefault != 0 } - -const ( - // FieldFlagPrimaryKey marks a field as primary key - FieldFlagPrimaryKey FieldFlag = 1 << iota - - // FieldFlagReadOnly marks a field as read-only - FieldFlagReadOnly - - // FieldFlagDefault marks a field as having a column default value - FieldFlagDefault -) - // StructFieldMapper is used to map struct type fields to column names // and indicate special column properies via flags. type StructFieldMapper interface { @@ -39,7 +15,7 @@ type StructFieldMapper interface { // If false is returned for use then the field is not mapped. // An empty name and true for use indicates an embedded struct // field whose fields should be recursively mapped. - MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) + MapStructField(field reflect.StructField) (table, column string, flags StructFieldFlags, use bool) } // NewTaggedStructFieldMapping returns a default mapping. @@ -54,11 +30,6 @@ func NewTaggedStructFieldMapping() *TaggedStructFieldMapping { } } -// DefaultStructFieldMapping provides the default StructFieldTagNaming -// using "db" as NameTag and IgnoreStructField as UntaggedNameFunc. -// Implements StructFieldMapper. -var DefaultStructFieldMapping = NewTaggedStructFieldMapping() - // TaggedStructFieldMapping implements StructFieldMapper with a struct field NameTag // to be used for naming and a UntaggedNameFunc in case the NameTag is not set. type TaggedStructFieldMapping struct { @@ -79,7 +50,7 @@ type TaggedStructFieldMapping struct { UntaggedNameFunc func(fieldName string) string } -func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) { +func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (table, column string, flags StructFieldFlags, use bool) { if field.Anonymous { column, hasTag := field.Tag.Lookup(m.NameTag) if !hasTag { @@ -112,12 +83,12 @@ func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (ta case "": // Ignore empty flags case m.PrimaryKey: - flags |= FieldFlagPrimaryKey + flags |= FlagPrimaryKey table = value case m.ReadOnly: - flags |= FieldFlagReadOnly + flags |= FlagReadOnly case m.Default: - flags |= FieldFlagDefault + flags |= FlagHasDefault } } } else { diff --git a/structfieldmapping_test.go b/reflection/structfieldmapping_test.go similarity index 90% rename from structfieldmapping_test.go rename to reflection/structfieldmapping_test.go index 4fcece5..e3fe80c 100644 --- a/structfieldmapping_test.go +++ b/reflection/structfieldmapping_test.go @@ -1,4 +1,4 @@ -package sqldb +package reflection import ( "reflect" @@ -55,18 +55,18 @@ func TestTaggedStructFieldMapping_StructFieldName(t *testing.T) { structField reflect.StructField wantTable string wantColumn string - wantFlags FieldFlag + wantFlags StructFieldFlags wantOk bool }{ - {name: "index", structField: st.Field(0), wantTable: "public.my_table", wantColumn: "index", wantFlags: FieldFlagPrimaryKey, wantOk: true}, - {name: "index_b", structField: st.Field(1), wantTable: "", wantColumn: "index_b", wantFlags: FieldFlagPrimaryKey, wantOk: true}, + {name: "index", structField: st.Field(0), wantTable: "public.my_table", wantColumn: "index", wantFlags: FlagPrimaryKey, wantOk: true}, + {name: "index_b", structField: st.Field(1), wantTable: "", wantColumn: "index_b", wantFlags: FlagPrimaryKey, wantOk: true}, {name: "named_str", structField: st.Field(2), wantColumn: "named_str", wantFlags: 0, wantOk: true}, - {name: "read_only", structField: st.Field(3), wantColumn: "read_only", wantFlags: FieldFlagReadOnly, wantOk: true}, + {name: "read_only", structField: st.Field(3), wantColumn: "read_only", wantFlags: FlagReadOnly, wantOk: true}, {name: "untagged_field", structField: st.Field(4), wantColumn: "untagged_field", wantFlags: 0, wantOk: true}, {name: "ignore", structField: st.Field(5), wantColumn: "", wantFlags: 0, wantOk: false}, - {name: "pk_read_only", structField: st.Field(6), wantColumn: "pk_read_only", wantFlags: FieldFlagPrimaryKey | FieldFlagReadOnly, wantOk: true}, + {name: "pk_read_only", structField: st.Field(6), wantColumn: "pk_read_only", wantFlags: FlagPrimaryKey | FlagReadOnly, wantOk: true}, {name: "no_flag", structField: st.Field(7), wantColumn: "no_flag", wantFlags: 0, wantOk: true}, - {name: "malformed_flags", structField: st.Field(8), wantColumn: "malformed_flags", wantFlags: FieldFlagReadOnly, wantOk: true}, + {name: "malformed_flags", structField: st.Field(8), wantColumn: "malformed_flags", wantFlags: FlagReadOnly, wantOk: true}, {name: "Embedded", structField: st.Field(9), wantColumn: "", wantFlags: 0, wantOk: true}, } for _, tt := range tests { diff --git a/reflection/structmapper.go b/reflection/structmapper.go new file mode 100644 index 0000000..4210ab4 --- /dev/null +++ b/reflection/structmapper.go @@ -0,0 +1,9 @@ +package reflection + +import ( + "reflect" +) + +type StructMapper interface { + ReflectStructMapping(t reflect.Type) (*StructMapping, error) +} diff --git a/reflection/structmapping.go b/reflection/structmapping.go new file mode 100644 index 0000000..e13be7a --- /dev/null +++ b/reflection/structmapping.go @@ -0,0 +1,188 @@ +package reflection + +import ( + "errors" + "fmt" + "reflect" + "sync" +) + +type StructMapping struct { + StructType reflect.Type + Table string + Columns []*StructColumn + ColumnMap map[string]*StructColumn +} + +type StructColumn struct { + Name string + Flags StructFieldFlags + FieldIndex []int + FieldType reflect.StructField +} + +type mappingKey struct { + reflect.Type + StructMapper +} + +var ( + cachedMappings = make(map[mappingKey]*StructMapping) + cachedMappingsMtx sync.Mutex +) + +func CachedStructMapping(t reflect.Type, m StructMapper) (*StructMapping, error) { + cachedMappingsMtx.Lock() + defer cachedMappingsMtx.Unlock() + + key := mappingKey{t, m} + + if mapping, ok := cachedMappings[key]; ok { + return mapping, nil + } + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("passed type %s is not a struct", t) + } + if m == nil { + return nil, errors.New("passed nil StructMapper") + } + mapping, err := m.ReflectStructMapping(t) + if err != nil { + return nil, err + } + cachedMappings[key] = mapping + return mapping, nil +} + +func (m *StructMapping) StructColumnValues(strct any, filter ColumnFilter) ([]any, error) { + v := reflect.ValueOf(strct) + switch v.Kind() { + case reflect.Struct: + // ok + case reflect.Pointer: + if v.IsNil() { + return nil, fmt.Errorf("passed nil %T", strct) + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return nil, fmt.Errorf("passed type %T is not a struct pointer", strct) + } + default: + return nil, fmt.Errorf("passed type %T is not a struct or struct pointer", strct) + } + if v.Type() != m.StructType { + return nil, fmt.Errorf("passed struct of type %s to %s StructMapping", v.Type(), m.StructType) + } + + if filter == nil { + vals := make([]any, len(m.Columns)) + for i, col := range m.Columns { + vals[i] = v.FieldByIndex(col.FieldIndex).Interface() + } + return vals, nil + } + + vals := make([]any, 0, len(m.Columns)) + for _, col := range m.Columns { + val := v.FieldByIndex(col.FieldIndex) + if !filter.IgnoreColumn(col, val) { + vals = append(vals, val.Interface()) + } + } + return vals, nil +} + +// func (m *StructMapping) StructColumnPointers(structPtr any, filter ColumnFilter) ([]any, error) { +// v := reflect.ValueOf(structPtr) +// if v.Kind() != reflect.Pointer { +// return nil, fmt.Errorf("passed type %T is not a struct pointer", structPtr) +// } +// if v.IsNil() { +// return nil, fmt.Errorf("passed nil %T", structPtr) +// } +// v = v.Elem() +// if v.Kind() != reflect.Struct { +// return nil, fmt.Errorf("passed type %T is not a struct pointer", structPtr) +// } +// if v.Type() != m.StructType { +// return nil, fmt.Errorf("passed struct of type %s to %s StructMapping", v.Type(), m.StructType) +// } + +// if filter == nil { +// vals := make([]any, len(m.Columns)) +// for i, col := range m.Columns { +// vals[i] = v.FieldByIndex(col.FieldIndex).Addr().Interface() +// } +// return vals, nil +// } + +// vals := make([]any, 0, len(m.Columns)) +// for _, col := range m.Columns { +// val := v.FieldByIndex(col.FieldIndex) +// if !filter.IgnoreColumn(col, val) { +// vals = append(vals, val.Addr().Interface()) +// } +// } +// return vals, nil +// } + +// ScanStruct scans values of a srcRow into a destStruct which must be passed as pointer. +func (m *StructMapping) ScanStruct(srcRow Row, structPtr any, filter ColumnFilter) error { + v := reflect.ValueOf(structPtr) + // if v.Kind() != reflect.Pointer { + // return fmt.Errorf("passed type %T is not a struct pointer", structPtr) + // } + // if v.IsNil() { + // return fmt.Errorf("passed nil %T", structPtr) + // } + // v = v.Elem() + // if v.Kind() != reflect.Struct { + // return fmt.Errorf("passed type %T is not a struct pointer", structPtr) + // } + // if v.Type() != m.StructType { + // return fmt.Errorf("passed struct of type %s to %s StructMapping", v.Type(), m.StructType) + // } + + var ( + setDestStructPtr = false + destStructPtr reflect.Value + newStructPtr reflect.Value + ) + if v.Kind() == reflect.Ptr && v.IsNil() && v.CanSet() { + // Got a nil pointer that we can set with a newly allocated struct + setDestStructPtr = true + destStructPtr = v + newStructPtr = reflect.New(v.Type().Elem()) + // Continue with the newly allocated struct + v = newStructPtr.Elem() + } + if v.Kind() != reflect.Struct { + return fmt.Errorf("passed type %T is not a struct pointer", structPtr) + } + + columns, err := srcRow.Columns() + if err != nil { + return err + } + + fieldPointers := make([]any, len(columns)) + for i, name := range columns { + col, ok := m.ColumnMap[name] + if !ok { + return fmt.Errorf("no mapping for column %s to struct %s", name, m.StructType) + } + fieldPointers[i] = v.FieldByIndex(col.FieldIndex).Addr().Interface() + } + + err = srcRow.Scan(fieldPointers...) + if err != nil { + return err + } + + if setDestStructPtr { + destStructPtr.Set(newStructPtr) + } + + return nil +} diff --git a/reflection/taggedstructmapping.go b/reflection/taggedstructmapping.go new file mode 100644 index 0000000..f2a666d --- /dev/null +++ b/reflection/taggedstructmapping.go @@ -0,0 +1,142 @@ +package reflection + +import ( + "fmt" + "reflect" + "strings" +) + +// TaggedStructMapper implements StructFieldMapper with a struct field NameTag +// to be used for naming and a UntaggedNameFunc in case the NameTag is not set. +type TaggedStructMapper struct { + _Named_Fields_Required struct{} + + // NameTag is the struct field tag to be used as column name + NameTag string + + // Ignore will cause a struct field to be ignored if it has that name + Ignore string + + PrimaryKey string + ReadOnly string + Default string + + // UntaggedNameFunc will be called with the struct field name to + // return a column name in case the struct field has no tag named NameTag. + UntaggedNameFunc func(fieldName string) string +} + +// NewTaggedStructMapper returns a default mapping. +func NewTaggedStructMapper() *TaggedStructMapper { + return &TaggedStructMapper{ + NameTag: "db", + Ignore: "-", + PrimaryKey: "pk", + ReadOnly: "readonly", + Default: "default", + UntaggedNameFunc: IgnoreStructField, + } +} + +func (m *TaggedStructMapper) ReflectStructMapping(structType reflect.Type) (*StructMapping, error) { + if structType.Kind() != reflect.Struct { + return nil, fmt.Errorf("passed type %s is not a struct", structType) + } + mapping := &StructMapping{ + StructType: structType, + ColumnMap: make(map[string]*StructColumn), + } + err := m.reflectStructMapping(structType, mapping) + if err != nil { + return nil, err + } + return mapping, nil +} + +func (m *TaggedStructMapper) reflectStructMapping(structType reflect.Type, mapping *StructMapping) error { + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + fieldTable, name, flags, use := m.mapStructField(field) + if !use { + continue + } + + if name == "" { + // Embedded struct field + err := m.reflectStructMapping(field.Type, mapping) + if err != nil { + return err + } + continue + } + + if fieldTable != "" && fieldTable != mapping.Table { + if mapping.Table != "" { + return fmt.Errorf("table name not unique (%s vs %s) in struct %s", mapping.Table, fieldTable, mapping.StructType) + } + mapping.Table = fieldTable + } + + column := &StructColumn{ + Name: name, + Flags: flags, + FieldIndex: field.Index, + FieldType: field, + } + mapping.Columns = append(mapping.Columns, column) + mapping.ColumnMap[name] = column + } + return nil +} + +func (m *TaggedStructMapper) mapStructField(field reflect.StructField) (table, column string, flags StructFieldFlags, use bool) { + if field.Anonymous { + column, hasTag := field.Tag.Lookup(m.NameTag) + if !hasTag { + // Embedded struct fields are ok if not tagged with IgnoreName + return "", "", 0, true + } + if i := strings.IndexByte(column, ','); i != -1 { + column = column[:i] + } + // Embedded struct fields are ok if not tagged with IgnoreName + return "", "", 0, column != m.Ignore + } + + if !field.IsExported() { + // Not exported struct fields that are not + // anonymously embedded structs are not ok + return "", "", 0, false + } + + tag, hasTag := field.Tag.Lookup(m.NameTag) + if hasTag { + for i, part := range strings.Split(tag, ",") { + // First part is the name + if i == 0 { + column = part + continue + } + // Follow on parts are flags + flag, value, _ := strings.Cut(part, "=") + switch flag { + case "": + // Ignore empty flags + case m.PrimaryKey: + flags |= FlagPrimaryKey + table = value + case m.ReadOnly: + flags |= FlagReadOnly + case m.Default: + flags |= FlagHasDefault + } + } + } else if m.UntaggedNameFunc != nil { + column = m.UntaggedNameFunc(field.Name) + } + + if column == m.Ignore || column == "" { + return "", "", 0, false + } + return table, column, flags, true +} diff --git a/row.go b/row.go new file mode 100644 index 0000000..fd70693 --- /dev/null +++ b/row.go @@ -0,0 +1,97 @@ +package sqldb + +import ( + "context" + "database/sql" + "errors" +) + +// Row is an interface with the methods of sql.Rows +// that are needed for ScanStruct. +// Allows mocking for tests without an SQL driver. +type Row interface { + // Columns returns the column names. + Columns() ([]string, error) + // Scan copies the columns in the current row into the values pointed + // at by dest. The number of values in dest must be the same as the + // number of columns in Rows. + Scan(dest ...any) error +} + +/////////////////////////////////////////////////////////////////////////////// + +// RowWithError returns a dummy Row +// where all methods return the passed error. +func RowWithError(err error) Row { + return errRow{err} +} + +type errRow struct{ err error } + +func (e errRow) Columns() ([]string, error) { return nil, e.err } +func (e errRow) Scan(dest ...any) error { return e.err } + +/////////////////////////////////////////////////////////////////////////////// + +type sqlRow struct { + ctx context.Context // ctx is checked for every row and passed through to callbacks + rows *sql.Rows + conn Connection // for error wrapping + query string // for error wrapping + args []any // for error wrapping +} + +func NewRow(ctx context.Context, rows *sql.Rows, conn Connection, query string, args []any) Row { + return &sqlRow{ctx, rows, conn, query, args} +} + +func (r *sqlRow) Columns() ([]string, error) { + columns, err := r.rows.Columns() + if err != nil { + return nil, WrapErrorWithQuery(err, r.query, r.args, r.conn.Config().ParamPlaceholderFormatter) + } + return columns, nil +} + +func (r *sqlRow) Scan(dest ...any) (err error) { + defer func() { + err = combineTwoErrors(err, r.rows.Close()) + if err != nil { + err = WrapErrorWithQuery(err, r.query, r.args, r.conn.Config().ParamPlaceholderFormatter) + } + }() + + if r.ctx.Err() != nil { + return r.ctx.Err() + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned (not permitting + // *RawBytes in Rows.Scan), since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + if !r.rows.Next() { + if err := r.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err = r.rows.Scan(dest...) + if err != nil { + return err + } + return r.rows.Close() +} diff --git a/rows.go b/rows.go new file mode 100644 index 0000000..350320e --- /dev/null +++ b/rows.go @@ -0,0 +1,96 @@ +package sqldb + +import ( + "context" + "database/sql" +) + +type Rows interface { + // ForEachRow will call the passed callback with a RowScanner for every row. + // In case of zero rows, no error will be returned. + ForEachRow(callback func(Row) error) error + + // Close closes the Rows, preventing further enumeration. If Next is called + // and returns false and there are no further result sets, + // the Rows are closed automatically and it will suffice to check the + // result of Err. Close is idempotent and does not affect the result of Err. + Close() error +} + +/////////////////////////////////////////////////////////////////////////////// + +// RowsWithError returns dummy Rows +// where all methods return the passed error. +func RowsWithError(err error) Rows { + return errRows{err} +} + +type errRows struct{ err error } + +func (e errRows) ForEachRow(func(Row) error) error { return e.err } +func (e errRows) Close() error { return e.err } + +/////////////////////////////////////////////////////////////////////////////// + +type sqlRows struct { + ctx context.Context // ctx is checked for every row and passed through to callbacks + rows *sql.Rows + conn Connection // for error wrapping + query string // for error wrapping + args []any // for error wrapping +} + +func NewRows(ctx context.Context, rows *sql.Rows, conn Connection, query string, args []any) Rows { + return &sqlRows{ctx, rows, conn, query, args} +} + +func (r *sqlRows) ForEachRow(callback func(Row) error) (err error) { + defer func() { + err = combineTwoErrors(err, r.rows.Close()) + if err != nil { + err = WrapErrorWithQuery(err, r.query, r.args, r.conn.Config().ParamPlaceholderFormatter) + } + }() + + for r.rows.Next() { + if r.ctx.Err() != nil { + return r.ctx.Err() + } + + err := callback(r.rows) + if err != nil { + return err + } + } + return r.rows.Err() +} + +func (r *sqlRows) Close() error { + return r.rows.Close() +} + +/////////////////////////////////////////////////////////////////////////////// + +// RowAsRows returns a single Rows wrapped as a Rows implementation. +// func RowAsRows(row Row) Rows { +// return &rowAsRows{row: row, closed: false} +// } + +// type rowAsRows struct { +// row Row +// closed bool +// } + +// func (r *rowAsRows) ForEachRow(callback func(Row) error) error { +// if r.closed { +// return errors.New("Rows are closed") +// } +// err := callback(r.row) +// r.closed = true +// return err +// } + +// func (r *rowAsRows) Close() error { +// r.closed = true +// return nil +// } diff --git a/rowscanner.go b/rowscanner.go deleted file mode 100644 index 0507364..0000000 --- a/rowscanner.go +++ /dev/null @@ -1,24 +0,0 @@ -package sqldb - -// RowScanner scans the values from a single row. -type RowScanner interface { - // Scan values of a row into dest variables, which must be passed as pointers. - Scan(dest ...any) error - - // ScanStruct scans values of a row into a dest struct which must be passed as pointer. - ScanStruct(dest any) error - - // ScanValues returns the values of a row exactly how they are - // passed from the database driver to an sql.Scanner. - // Byte slices will be copied. - ScanValues() ([]any, error) - - // ScanStrings scans the values of a row as strings. - // Byte slices will be interpreted as strings, - // nil (SQL NULL) will be converted to an empty string, - // all other types are converted with fmt.Sprint(src). - ScanStrings() ([]string, error) - - // Columns returns the column names. - Columns() ([]string, error) -} diff --git a/rowsscanner.go b/rowsscanner.go deleted file mode 100644 index c318deb..0000000 --- a/rowsscanner.go +++ /dev/null @@ -1,45 +0,0 @@ -package sqldb - -// RowsScanner scans the values from multiple rows. -type RowsScanner interface { - // ScanSlice scans one value per row into one slice element of dest. - // dest must be a pointer to a slice with a row value compatible element type. - // In case of zero rows, dest will be set to nil and no error will be returned. - // In case of an error, dest will not be modified. - // It is an error to query more than one column. - ScanSlice(dest any) error - - // ScanStructSlice scans every row into the struct fields of dest slice elements. - // dest must be a pointer to a slice of structs or struct pointers. - // In case of zero rows, dest will be set to nil and no error will be returned. - // In case of an error, dest will not be modified. - // Every mapped struct field must have a corresponding column in the query results. - ScanStructSlice(dest any) error - - // ScanAllRowsAsStrings scans the values of all rows as strings. - // Byte slices will be interpreted as strings, - // nil (SQL NULL) will be converted to an empty string, - // all other types are converted with fmt.Sprint. - // If true is passed for headerRow, then a row - // with the column names will be prepended. - ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) - - // Columns returns the column names. - Columns() ([]string, error) - - // ForEachRow will call the passed callback with a RowScanner for every row. - // In case of zero rows, no error will be returned. - ForEachRow(callback func(RowScanner) error) error - - // ForEachRowCall will call the passed callback with scanned values or a struct for every row. - // If the callback function has a single struct or struct pointer argument, - // then RowScanner.ScanStruct will be used per row, - // else RowScanner.Scan will be used for all arguments of the callback. - // If the function has a context.Context as first argument, - // then the context of the query call will be passed on. - // The callback can have no result or a single error result value. - // If a non nil error is returned from the callback, then this error - // is returned immediately by this function without scanning further rows. - // In case of zero rows, no error will be returned. - ForEachRowCall(callback any) error -} diff --git a/impl/scanresult.go b/scanresult.go similarity index 87% rename from impl/scanresult.go rename to scanresult.go index c51c336..266297d 100644 --- a/impl/scanresult.go +++ b/scanresult.go @@ -1,6 +1,4 @@ -package impl - -import "github.com/domonda/go-sqldb" +package sqldb // ScanValues returns the values of a row exactly how they are // passed from the database driver to an sql.Scanner. @@ -11,7 +9,7 @@ func ScanValues(src Row) ([]any, error) { return nil, err } var ( - anys = make([]sqldb.AnyValue, len(cols)) + anys = make([]AnyValue, len(cols)) vals = make([]any, len(cols)) ) for i := range vals { @@ -41,7 +39,7 @@ func ScanStrings(src Row) ([]string, error) { args = make([]any, len(cols)) ) for i := range args { - args[i] = (*sqldb.StringScannable)(&strs[i]) + args[i] = (*StringScannable)(&strs[i]) } err = src.Scan(args...) if err != nil { diff --git a/sqliteconn/connection.go b/sqliteconn/connection.go new file mode 100644 index 0000000..a74d3a5 --- /dev/null +++ b/sqliteconn/connection.go @@ -0,0 +1,27 @@ +package sqliteconn + +import ( + "context" + "fmt" + + _ "modernc.org/sqlite" + + "github.com/domonda/go-sqldb" +) + +// New creates a new sqldb.Connection using the passed sqldb.Config +// and modernc.org/sqlite as driver implementation. +// The connection is pinged with the passed context +// and only returned when there was no error from the ping. +func New(ctx context.Context, config *sqldb.Config) (sqldb.Connection, error) { + if config.Driver != "sqlite" { + return nil, fmt.Errorf(`invalid driver %q, pqconn expects "sqlite"`, config.Driver) + } + + db, err := config.Connect(ctx) + if err != nil { + return nil, err + } + _ = db + panic("TODO") +} diff --git a/transaction.go b/transaction.go index 036568b..755505f 100644 --- a/transaction.go +++ b/transaction.go @@ -1,6 +1,7 @@ package sqldb import ( + "context" "database/sql" "errors" "fmt" @@ -14,15 +15,15 @@ import ( // are stricter than the options of the parent transaction. // Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. -func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { - if parentOpts, parentIsTx := parentConn.TransactionOptions(); parentIsTx { - err = CheckTxOptionsCompatibility(parentOpts, opts, parentConn.Config().DefaultIsolationLevel) +func Transaction(ctx context.Context, parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { + if parentConn.IsTransaction() { + err = CheckConnectionTxOptionsCompatibility(parentConn, opts) if err != nil { return err } return txFunc(parentConn) } - return IsolatedTransaction(parentConn, opts, txFunc) + return IsolatedTransaction(ctx, parentConn, opts, txFunc) } // IsolatedTransaction executes txFunc within a database transaction that is passed in to txFunc as tx Connection. @@ -30,8 +31,8 @@ func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Conn // If parentConn is already a transaction, a brand new transaction will begin on the parent's connection. // Errors and panics from txFunc will rollback the transaction. // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. -func IsolatedTransaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { - tx, e := parentConn.Begin(opts) +func IsolatedTransaction(ctx context.Context, parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { + tx, e := parentConn.Begin(ctx, opts) if e != nil { return fmt.Errorf("Transaction Begin error: %w", e) } @@ -97,3 +98,7 @@ func CheckTxOptionsCompatibility(parent, child *sql.TxOptions, defaultIsolation } return nil } + +func CheckConnectionTxOptionsCompatibility(parentTx Connection, childTxOpts *sql.TxOptions) error { + return CheckTxOptionsCompatibility(parentTx.TxOptions(), childTxOpts, parentTx.Config().DefaultIsolationLevel) +} diff --git a/txconnection.go b/txconnection.go new file mode 100644 index 0000000..d0d58cc --- /dev/null +++ b/txconnection.go @@ -0,0 +1,89 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" +) + +type TxConnection struct { + Parent Connection + Tx *sql.Tx + Opts *sql.TxOptions +} + +func (c *TxConnection) Config() *Config { + return c.Parent.Config() +} + +func (c *TxConnection) Stats() sql.DBStats { + return c.Parent.Stats() +} + +func (c *TxConnection) Ping(ctx context.Context, timeout time.Duration) error { + return c.Parent.Ping(ctx, timeout) +} + +func (c *TxConnection) Err() error { + return c.Parent.Err() +} + +func (c *TxConnection) Exec(ctx context.Context, query string, args ...any) error { + _, err := c.Tx.ExecContext(ctx, query, args...) + if err != nil { + return WrapErrorWithQuery(err, query, args, c.Config().ParamPlaceholderFormatter) + } + return nil +} + +func (c *TxConnection) QueryRow(ctx context.Context, query string, args ...any) Row { + rows, err := c.Tx.QueryContext(ctx, query, args...) + if err != nil { + return RowWithError(WrapErrorWithQuery(err, query, args, c.Config().ParamPlaceholderFormatter)) + } + return NewRow(ctx, rows, c, query, args) +} + +func (c *TxConnection) QueryRows(ctx context.Context, query string, args ...any) Rows { + rows, err := c.Tx.QueryContext(ctx, query, args...) + if err != nil { + return RowsWithError(WrapErrorWithQuery(err, query, args, c.Config().ParamPlaceholderFormatter)) + } + return NewRows(ctx, rows, c, query, args) +} + +func (c *TxConnection) IsTransaction() bool { + return true +} + +func (c *TxConnection) TxOptions() *sql.TxOptions { + return c.Opts +} + +func (c *TxConnection) Begin(ctx context.Context, opts *sql.TxOptions) (Connection, error) { + return nil, ErrWithinTransaction +} + +func (c *TxConnection) Commit() error { + return c.Tx.Commit() +} + +func (c *TxConnection) Rollback() error { + return c.Tx.Rollback() +} + +func (c *TxConnection) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return ErrWithinTransaction +} + +func (c *TxConnection) UnlistenChannel(channel string) error { + return ErrWithinTransaction +} + +func (c *TxConnection) IsListeningOnChannel(channel string) bool { + return false +} + +func (c *TxConnection) Close() error { + return c.Tx.Rollback() +}