diff --git a/.gitignore b/.gitignore index f1c181e..3b281f5 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ # Output of the go coverage tool, specifically when used with LiteIDE *.out + +pqconn/test/postgres-data \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..f34a2e0 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,104 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Commands + +- **Run tests for all modules**: `./test-workspace.sh` - This runs tests across all Go workspace modules (main, mssqlconn, mysqlconn, pqconn) +- **Run tests for specific module**: `go test ./...` in the module directory +- **Build all modules**: `go build ./...` +- **Get dependencies**: `go mod tidy` (run in each module directory as needed) + +## Go Workspace Structure + +This repository uses Go workspaces with multiple modules: +- Main module: `github.com/domonda/go-sqldb` (root) +- Database drivers: `./mssqlconn`, `./mysqlconn`, `./pqconn` +- Command tools: `./cmd/sqldb-dump` +- Examples: `./examples/user_demo` + +## Architecture Overview + +### Core Components + +- **Connection Interface**: Central abstraction for database connections and transactions (`connection.go`) +- **DB Package**: Context-based connection management pattern (`db/` directory) +- **Database Drivers**: Separate modules for PostgreSQL (`pqconn/`), MySQL (`mysqlconn/`), and SQL Server (`mssqlconn/`) +- **Query Building**: Flexible query construction with struct mapping +- **Transaction Management**: Nested transactions with savepoint support + +### Key Design Patterns + +1. **Context-Based Connection Storage**: Store connections/transactions in context for seamless function composition +2. **Struct-to-SQL Mapping**: Automatic mapping between Go structs and database rows using reflection +3. **Transaction Callbacks**: Execute transactions in callback functions that can be nested +4. **Flexible Query Interface**: Support for both raw SQL and struct-based operations + +### Important Packages + +- `sqldb` (root): Core interfaces and types +- `db/`: Context-based connection management and transaction utilities +- `information/`: Database schema introspection +- `_mockconn/`: Mock implementations for testing + +## Code Conventions + +### SQL Queries +- Write SQL string literals with backticks and prefix with `/*sql*/` comment +- Use numbered parameters (`$1`, `$2`) for PostgreSQL driver + +### Error Handling +- Use `github.com/domonda/go-errs` instead of standard `errors` package +- Use `errs.New()` instead of `errors.New()` +- Use `errs.Errorf()` instead of `fmt.Errorf()` + +### UUID Types +- Use `github.com/domonda/go-types/uu` package for UUIDs +- Single UUID: `uu.ID` +- UUID slice: `uu.IDSlice` (not `[]uu.ID`) +- Zero values: `uu.IDNil` for `uu.ID`, `uu.IDNull` for `uu.NullableID`, `nil` for `uu.IDSlice` + +### General Go Rules +- Use `any` instead of `interface{}` +- In HTTP handlers, use `http.Request.Context()` for context +- Never return actual error strings as HTTP 500 responses - use abstract descriptions + +### Struct Field Mapping +- Default tag: `db:"column_name"` +- Primary key: `db:"id,pk"` +- Ignore field: `db:"-"` + +## Testing +- Mock connections available in `_mockconn/` package +- PostgreSQL integration tests use Docker (see `pqconn/test/`) +- Use `db.ContextWithNonConnectionForTest()` for testing without real database +- Helper functions in `db/testhelper.go` + +## Common Usage Patterns + +### Transaction Management +```go +// Simple transaction +err := db.Transaction(ctx, func(ctx context.Context) error { + // All db.Conn(ctx) calls use the transaction + return db.Conn(ctx).Exec(/*sql*/ `INSERT ...`) +}) + +// Serialized transaction (for high concurrency scenarios) +err := db.SerializedTransaction(ctx, func(ctx context.Context) error { ... }) + +// Transaction with savepoints (nested transactions) +err := db.TransactionSavepoint(ctx, func(ctx context.Context) error { ... }) +``` + +### Struct Operations +```go +// Insert with struct +err := db.InsertStruct(ctx, "table_name", &structInstance) + +// Upsert (uses primary key fields) +err := db.UpsertStruct(ctx, "table_name", &structInstance) + +// Query into struct +user, err := db.QueryRowValue[User](ctx, /*sql*/ `SELECT * FROM users WHERE id = $1`, id) +``` \ No newline at end of file diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..f76c0be --- /dev/null +++ b/TODO.md @@ -0,0 +1,78 @@ +# TODO - Unfinished Work in go-sqldb + +## Major TODOs from README.md + +- [ ] Test all pkg db functions +- [ ] pkg information completion +- [ ] Test pqconn with dockerized Postgres +- [ ] Cache struct types (see commit 090e73d1d9db8534d2950dd7236d7ebe192cd512) +- [ ] Std SQL driver for mocks +- [ ] Smooth out listener for Postgres +- [ ] SQLite integration https://github.com/zombiezen/go-sqlite +- [ ] Bartch insert +```go + func BatchInsert[T any](ctx context.Context, table string, items []T, + batchSize int) error +``` + +## Code-Level TODOs and Missing Implementations + +### Performance Optimizations + +- [ ] **db/insert.go:203** - `InsertRowStructs` missing optimized batch insert (currently processes one-by-one in transaction) +- [ ] **db/insert.go:76** - Commented code for RETURNING clause needs error wrapping +- [ ] **pqconn/arrays.go:128** - Array element scanning needs type conversion improvement for different element types + +### Function Implementations + +- [ ] **db/insert.go:152** - Complete commented out `InsertStructStmt` function with TODO placeholder +- [ ] **mssqlconn/queryformatter.go:11** - Allow spaces and other characters with backtick escaping + +### API Design Questions + +- [ ] **db/scanresult.go:3** - Consider moving ScanResult to RowScanner interface +- [ ] **db/multirowscanner.go:15,97** - Resolve API design questions about single vs multi-column scanning +- [ ] **db/reflectstruct.go:168** - Clean up Connection implementation detail + +## Missing Patterns + +### 1. Batch Operations +- Current `InsertRowStructs` processes items individually in a transaction +- Need optimized batch INSERT statements that combine multiple structs +- Consider maxArgs parameter limitations + +### 2. RETURNING Clause Support +- Commented implementation exists in insert.go:76 +- Need proper error wrapping for query execution +- Should integrate with existing query building patterns + +### 3. Error Handling Standardization +- Some query error wrapping is incomplete +- Need consistent pattern across all database operations + +### 4. Type Conversion Enhancement +- Array scanning needs improvement for different element types +- String-to-type conversion challenges in pqconn/arrays.go:128 + +## Key Areas for Completion + +### High Priority +1. **Performance Optimization**: Implement batch insert operations +2. **Testing**: Comprehensive test coverage for db package functions +3. **Configuration**: Rethink and improve Config structure + +### Medium Priority +4. **Database Support**: Complete SQLite integration +5. **Error Handling**: Standardize query error wrapping patterns +6. **API Consistency**: Resolve design questions in multirowscanner + +### Low Priority +7. **Code Organization**: Move ScanResult and clean up implementation details +8. **Features**: Enhanced array type support and RETURNING clause functionality + +## Implementation Notes + +- The UpsertStruct function (db/upsert.go:14) is marked "TODO" but appears fully implemented +- Mock connection patterns are well established in _mockconn/ package +- Go workspace structure supports multiple database drivers effectively +- Context-based connection management pattern is consistently implemented \ No newline at end of file diff --git a/mockconn/boolmap.go b/_mockconn/boolmap.go similarity index 100% rename from mockconn/boolmap.go rename to _mockconn/boolmap.go diff --git a/_mockconn/connection.go b/_mockconn/connection.go new file mode 100644 index 0000000..d1eab9d --- /dev/null +++ b/_mockconn/connection.go @@ -0,0 +1,136 @@ +package mockconn + +/* +import ( + "context" + "database/sql" + "errors" + "fmt" + "io" + "time" + + "github.com/domonda/go-sqldb" +) + +var DefaultArgFmt = "?" + +func New(ctx context.Context, queryWriter io.Writer, rowsProvider RowsProvider) sqldb.Connection { + return &connection{ + ctx: ctx, + queryWriter: queryWriter, + listening: newBoolMap(), + rowsProvider: rowsProvider, + structFieldNamer: sqldb.DefaultStructFieldMapping, + argFmt: DefaultArgFmt, + } +} + +type connection struct { + ctx context.Context + queryWriter io.Writer + listening *boolMap + rowsProvider RowsProvider + structFieldNamer sqldb.StructFieldMapper + argFmt string +} + + + +func (conn *connection) Stats() sql.DBStats { + return sql.DBStats{} +} + +func (conn *connection) Config() *sqldb.Config { + return &sqldb.Config{Driver: "mockconn", Host: "localhost", Database: "mock"} +} + +func (conn *connection) Ping(time.Duration) error { + return nil +} + +func (conn *connection) Exec(query string, args ...any) error { + if conn.queryWriter != nil { + fmt.Fprint(conn.queryWriter, query) + } + return nil +} + +func (conn *connection) Query(query string, args ...any) sqldb.Rows { + if err := conn.ctx.Err(); err != nil { + return sqldb.RowsErr(err) + } + return conn.rowsProvider.Query(conn.structFieldNamer, query, args...) +} + +// func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { +// if conn.ctx.Err() != nil { +// return sqldb.RowScannerWithError(conn.ctx.Err()) +// } +// if conn.queryWriter != nil { +// fmt.Fprint(conn.queryWriter, query) +// } +// if conn.rowsProvider == nil { +// return sqldb.RowScannerWithError(nil) +// } +// return conn.rowsProvider.QueryRow(conn.structFieldNamer, query, args...) +// } + +// func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { +// if conn.ctx.Err() != nil { +// return sqldb.RowsScannerWithError(conn.ctx.Err()) +// } +// if conn.queryWriter != nil { +// fmt.Fprint(conn.queryWriter, query) +// } +// if conn.rowsProvider == nil { +// return sqldb.RowsScannerWithError(nil) +// } +// return conn.rowsProvider.QueryRows(conn.structFieldNamer, query, args...) +// } + +func (conn *connection) Transaction() (no uint64, opts *sql.TxOptions) { + return 0, nil +} + +func (conn *connection) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { + if no == 0 { + return nil, errors.New("transaction number must not be zero") + } + if conn.queryWriter != nil { + fmt.Fprint(conn.queryWriter, "BEGIN") + } + return transaction{conn, opts, no}, nil +} + +func (conn *connection) Commit() error { + return sqldb.ErrNotWithinTransaction +} + +func (conn *connection) Rollback() error { + return sqldb.ErrNotWithinTransaction +} + +func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { + conn.listening.Set(channel, true) + if conn.queryWriter != nil { + fmt.Fprint(conn.queryWriter, "LISTEN "+channel) + } + return nil +} + +func (conn *connection) UnlistenChannel(channel string) (err error) { + conn.listening.Set(channel, false) + if conn.queryWriter != nil { + fmt.Fprint(conn.queryWriter, "UNLISTEN "+channel) + } + return nil +} + +func (conn *connection) IsListeningOnChannel(channel string) bool { + return conn.listening.Get(channel) +} + +func (conn *connection) Close() error { + return nil +} +*/ diff --git a/mockconn/connection_test.go b/_mockconn/connection_test.go similarity index 68% rename from mockconn/connection_test.go rename to _mockconn/connection_test.go index 9842bf2..b7f52bd 100644 --- a/mockconn/connection_test.go +++ b/_mockconn/connection_test.go @@ -1,5 +1,6 @@ package mockconn +/* import ( "bytes" "context" @@ -33,7 +34,7 @@ type testRow struct { } func TestInsertQuery(t *testing.T) { - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + naming := &sqldb.TaggedStructReflector{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} queryOutput := bytes.NewBuffer(nil) rowProvider := NewSingleRowProvider(NewRow(struct{ True bool }{true}, naming)) conn := New(context.Background(), queryOutput, rowProvider).WithStructFieldMapper(naming) @@ -52,13 +53,13 @@ func TestInsertQuery(t *testing.T) { "bools": pq.BoolArray{true, false}, } - expected := `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + expected := `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9)` err := db.Insert(ctx, "public.table", values) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` + expected = `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` inserted, err := db.InsertUnique(ctx, "public.table", values, "id") assert.NoError(t, err) assert.True(t, inserted) @@ -67,7 +68,7 @@ func TestInsertQuery(t *testing.T) { func TestInsertStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &sqldb.TaggedStructReflector{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", @@ -80,19 +81,19 @@ func TestInsertStructQuery(t *testing.T) { row := new(testRow) - expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9)` err := db.InsertStruct(ctx, "public.table", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES($1,$2,$3)` + expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES(?1,?2,?3)` err = db.InsertStruct(ctx, "public.table", row, sqldb.OnlyColumns("id", "untagged_field", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES($1,$2,$3,$4,$5,$6)` + expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES(?1,?2,?3,?4,?5,?6)` err = db.InsertStruct(ctx, "public.table", row, sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) @@ -100,7 +101,7 @@ func TestInsertStructQuery(t *testing.T) { func TestInsertUniqueStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &sqldb.TaggedStructReflector{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", @@ -114,21 +115,21 @@ func TestInsertUniqueStructQuery(t *testing.T) { row := new(testRow) - expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` + expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` inserted, err := db.InsertUniqueStruct(ctx, "public.table", row, "(id)") assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES($1,$2,$3) ON CONFLICT (id, untagged_field) DO NOTHING RETURNING TRUE` + expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES(?1,?2,?3) ON CONFLICT (id, untagged_field) DO NOTHING RETURNING TRUE` inserted, err = db.InsertUniqueStruct(ctx, "public.table", row, "(id, untagged_field)", sqldb.OnlyColumns("id", "untagged_field", "bools")) assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES($1,$2,$3,$4,$5,$6) ON CONFLICT (id) DO NOTHING RETURNING TRUE` + expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES(?1,?2,?3,?4,?5,?6) ON CONFLICT (id) DO NOTHING RETURNING TRUE` inserted, err = db.InsertUniqueStruct(ctx, "public.table", row, "(id)", sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) assert.NoError(t, err) assert.True(t, inserted) @@ -137,8 +138,9 @@ func TestInsertUniqueStructQuery(t *testing.T) { func TestUpdateQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + naming := &sqldb.TaggedStructReflector{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) str := "Hello World!" values := sqldb.Values{ @@ -152,22 +154,25 @@ func TestUpdateQuery(t *testing.T) { "bools": pq.BoolArray{true, false}, } - expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1` - err := conn.Update("public.table", values, "id = $1", 1) + // Passing one varidic arg as ?1, moves the index of the rest of the args by 1 + expected := `UPDATE public.table SET "bool"=?2, "bools"=?3, "created_at"=?4, "int"=?5, "nil_ptr"=?6, "str"=?7, "str_ptr"=?8, "untagged_field"=?9 WHERE id = ?1` + err := db.Update(ctx, "public.table", values, "id = ?1", 1) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `UPDATE public.table SET "bool"=$3,"bools"=$4,"created_at"=$5,"int"=$6,"nil_ptr"=$7,"str"=$8,"str_ptr"=$9,"untagged_field"=$10 WHERE a = $1 AND b = $2` - err = conn.Update("public.table", values, "a = $1 AND b = $2", 1, 2) + // Passing two varidic args as ?1 and ?2, moves the index of the rest of the args by 2 + expected = `UPDATE public.table SET "bool"=?3, "bools"=?4, "created_at"=?5, "int"=?6, "nil_ptr"=?7, "str"=?8, "str_ptr"=?9, "untagged_field"=?10 WHERE a = ?1 AND b = ?2` + err = db.Update(ctx, "public.table", values, "a = ?1 AND b = ?2", 1, 2) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestUpdateReturningQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + naming := &sqldb.TaggedStructReflector{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) str := "Hello World!" values := sqldb.Values{ @@ -181,21 +186,23 @@ func TestUpdateReturningQuery(t *testing.T) { "bools": pq.BoolArray{true, false}, } - expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING *` - err := conn.UpdateReturningRow("public.table", values, "*", "id = $1", 1).Scan() + // Passing one varidic arg as ?1, moves the index of the rest of the args by 1 + expected := `UPDATE public.table SET "bool"=?2, "bools"=?3, "created_at"=?4, "int"=?5, "nil_ptr"=?6, "str"=?7, "str_ptr"=?8, "untagged_field"=?9 WHERE id = ?1 RETURNING *` + err := db.UpdateReturningRow(ctx, "public.table", values, "*", "id = ?1", 1).Scan() assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING created_at,untagged_field` - err = conn.UpdateReturningRows("public.table", values, "created_at,untagged_field", "id = $1", 1, 2).ScanSlice(nil) + // Passing two varidic args as ?1 and ?2, moves the index of the rest of the args by 2 + expected = `UPDATE public.table SET "bool"=?3, "bools"=?4, "created_at"=?5, "int"=?6, "nil_ptr"=?7, "str"=?8, "str_ptr"=?9, "untagged_field"=?10 WHERE id = ?1 RETURNING created_at,untagged_field` + err = db.UpdateReturningRows(ctx, "public.table", values, "created_at,untagged_field", "id = ?1", 1, 2).ScanSlice(nil) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestUpdateStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &sqldb.TaggedStructReflector{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", @@ -204,30 +211,31 @@ func TestUpdateStructQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) row := new(testRow) - expected := `UPDATE public.table SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9 WHERE "id"=$1` - err := conn.UpdateStruct("public.table", row) + expected := `UPDATE public.table SET "int"=?2, "bool"=?3, "str"=?4, "str_ptr"=?5, "nil_ptr"=?6, "untagged_field"=?7, "created_at"=?8, "bools"=?9 WHERE "id"=?1` + err := db.UpdateStruct(ctx, "public.table", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `UPDATE public.table SET "bool"=$2,"str"=$3,"created_at"=$4 WHERE "id"=$1` - err = conn.UpdateStruct("public.table", row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) + expected = `UPDATE public.table SET "bool"=?2, "str"=?3, "created_at"=?4 WHERE "id"=?1` + err = db.UpdateStruct(ctx, "public.table", row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `UPDATE public.table SET "int"=$2,"bool"=$3,"str_ptr"=$4,"nil_ptr"=$5,"created_at"=$6 WHERE "id"=$1` - err = conn.UpdateStruct("public.table", row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) + expected = `UPDATE public.table SET "int"=?2, "bool"=?3, "str_ptr"=?4, "nil_ptr"=?5, "created_at"=?6 WHERE "id"=?1` + err = db.UpdateStruct(ctx, "public.table", row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestUpsertStructQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &sqldb.TaggedStructReflector{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", @@ -236,12 +244,13 @@ func TestUpsertStructQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) row := new(testRow) - expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + - ` ON CONFLICT("id") DO UPDATE SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9` + expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9)` + + ` ON CONFLICT("id") DO UPDATE SET "int"=?2, "bool"=?3, "str"=?4, "str_ptr"=?5, "nil_ptr"=?6, "untagged_field"=?7, "created_at"=?8, "bools"=?9` - err := conn.UpsertStruct("public.table", row) + err := db.UpsertStruct(ctx, "public.table", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -256,7 +265,7 @@ type multiPrimaryKeyRow struct { func TestUpsertStructMultiPKQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &sqldb.TaggedStructReflector{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", @@ -265,18 +274,19 @@ func TestUpsertStructMultiPKQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) row := new(multiPrimaryKeyRow) - expected := `INSERT INTO public.multi_pk("first_id","second_id","third_id","created_at") VALUES($1,$2,$3,$4) ON CONFLICT("first_id","second_id","third_id") DO UPDATE SET "created_at"=$4` + expected := `INSERT INTO public.multi_pk("first_id","second_id","third_id","created_at") VALUES(?1,?2,?3,?4) ON CONFLICT("first_id","second_id","third_id") DO UPDATE SET "created_at"=?4` - err := conn.UpsertStruct("public.multi_pk", row) + err := db.UpsertStruct(ctx, "public.multi_pk", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } func TestUpdateStructMultiPKQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) - naming := &sqldb.TaggedStructFieldMapping{ + naming := &sqldb.TaggedStructReflector{ NameTag: "db", Ignore: "-", PrimaryKey: "pk", @@ -285,11 +295,13 @@ func TestUpdateStructMultiPKQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) row := new(multiPrimaryKeyRow) - expected := `UPDATE public.multi_pk SET "created_at"=$4 WHERE "first_id"=$1 AND "second_id"=$2 AND "third_id"=$3` + expected := `UPDATE public.multi_pk SET "created_at"=?4 WHERE "first_id"=?1 AND "second_id"=?2 AND "third_id"=?3` - err := conn.UpdateStruct("public.multi_pk", row) + err := db.UpdateStruct(ctx, "public.multi_pk", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } +*/ diff --git a/mockconn/errors.go b/_mockconn/errors.go similarity index 100% rename from mockconn/errors.go rename to _mockconn/errors.go diff --git a/mockconn/onetimerowsprovider.go b/_mockconn/onetimerowsprovider.go similarity index 69% rename from mockconn/onetimerowsprovider.go rename to _mockconn/onetimerowsprovider.go index 51ca5ce..5d384a0 100644 --- a/mockconn/onetimerowsprovider.go +++ b/_mockconn/onetimerowsprovider.go @@ -44,25 +44,25 @@ func (p *OneTimeRowsProvider) AddRowsScannerQuery(scanner sqldb.RowsScanner, que p.rowsScanners[key] = scanner } -func (p *OneTimeRowsProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner { - p.mtx.Lock() - defer p.mtx.Unlock() +// func (p *OneTimeRowsProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner { +// p.mtx.Lock() +// defer p.mtx.Unlock() - key := uniqueQueryString(query, args) - scanner := p.rowScanners[key] - delete(p.rowScanners, key) - return scanner -} +// key := uniqueQueryString(query, args) +// scanner := p.rowScanners[key] +// delete(p.rowScanners, key) +// return scanner +// } -func (p *OneTimeRowsProvider) QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { - p.mtx.Lock() - defer p.mtx.Unlock() +// func (p *OneTimeRowsProvider) QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { +// p.mtx.Lock() +// defer p.mtx.Unlock() - key := uniqueQueryString(query, args) - scanner := p.rowsScanners[key] - delete(p.rowScanners, key) - return scanner -} +// key := uniqueQueryString(query, args) +// scanner := p.rowsScanners[key] +// delete(p.rowScanners, key) +// return scanner +// } func uniqueQueryString(query string, args []any) string { var b strings.Builder diff --git a/mockconn/row.go b/_mockconn/row.go similarity index 97% rename from mockconn/row.go rename to _mockconn/row.go index 579f754..f240958 100644 --- a/mockconn/row.go +++ b/_mockconn/row.go @@ -9,17 +9,15 @@ import ( "slices" "strconv" "time" - - sqldb "github.com/domonda/go-sqldb" ) // Row implements impl.Row with the fields of a struct as column values. type Row struct { rowStructVal reflect.Value - columnNamer sqldb.StructFieldMapper + columnNamer StructReflector } -func NewRow(rowStruct any, columnNamer sqldb.StructFieldMapper) *Row { +func NewRow(rowStruct any, columnNamer StructReflector) *Row { val := reflect.ValueOf(rowStruct) for val.Kind() == reflect.Ptr { val = val.Elem() @@ -30,7 +28,7 @@ func NewRow(rowStruct any, columnNamer sqldb.StructFieldMapper) *Row { } } -func (r *Row) StructFieldMapper() sqldb.StructFieldMapper { +func (r *Row) StructReflector() StructReflector { return r.columnNamer } @@ -38,7 +36,8 @@ func (r *Row) Columns() ([]string, error) { columns := make([]string, r.rowStructVal.NumField()) for i := range columns { field := r.rowStructVal.Type().Field(i) - _, columns[i], _, _ = r.columnNamer.MapStructField(field) + col, _ = r.columnNamer.MapStructField(field) + columns[i] = col.Name } return columns, nil } diff --git a/mockconn/row_test.go b/_mockconn/row_test.go similarity index 89% rename from mockconn/row_test.go rename to _mockconn/row_test.go index f63ec70..8ff83b4 100644 --- a/mockconn/row_test.go +++ b/_mockconn/row_test.go @@ -21,7 +21,7 @@ func TestRow(t *testing.T) { str := "Hello World!" input := Struct{"myID", 66, -1, &str, nil, pq.BoolArray{true, false}} - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + naming := &sqldb.TaggedStructReflector{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} row := NewRow(input, naming) cols, err := row.Columns() diff --git a/mockconn/rows.go b/_mockconn/rows.go similarity index 95% rename from mockconn/rows.go rename to _mockconn/rows.go index 8698fe5..e4a5ebc 100644 --- a/mockconn/rows.go +++ b/_mockconn/rows.go @@ -3,8 +3,6 @@ package mockconn import ( "errors" "reflect" - - sqldb "github.com/domonda/go-sqldb" ) type Rows struct { @@ -14,7 +12,7 @@ type Rows struct { err error } -func NewRowsFromStructs(rowStructs any, columnNamer sqldb.StructFieldMapper) *Rows { +func NewRowsFromStructs(rowStructs any, columnNamer StructReflector) *Rows { v := reflect.ValueOf(rowStructs) t := v.Type() if t.Kind() != reflect.Array && t.Kind() != reflect.Slice { diff --git a/mockconn/rows_test.go b/_mockconn/rows_test.go similarity index 93% rename from mockconn/rows_test.go rename to _mockconn/rows_test.go index 608cbb1..c73bb61 100644 --- a/mockconn/rows_test.go +++ b/_mockconn/rows_test.go @@ -26,7 +26,7 @@ func TestRows(t *testing.T) { input = append(input, &Struct{"myID", i, -1, &str, nil, pq.BoolArray{true, false, i%2 == 0}}) } - naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} + naming := &sqldb.TaggedStructReflector{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} rows := NewRowsFromStructs(input, naming) cols, err := rows.Columns() diff --git a/_mockconn/rowsprovider.go b/_mockconn/rowsprovider.go new file mode 100644 index 0000000..e056d1b --- /dev/null +++ b/_mockconn/rowsprovider.go @@ -0,0 +1,11 @@ +package mockconn + +import ( + sqldb "github.com/domonda/go-sqldb" +) + +type RowsProvider interface { + Query(structFieldNamer StructReflector, query string, args ...any) (sqldb.Rows, error) + QueryRow(structFieldNamer StructReflector, query string, args ...any) sqldb.RowScanner + QueryRows(structFieldNamer StructReflector, query string, args ...any) sqldb.RowsScanner +} diff --git a/_mockconn/singlerowprovider.go b/_mockconn/singlerowprovider.go new file mode 100644 index 0000000..873793b --- /dev/null +++ b/_mockconn/singlerowprovider.go @@ -0,0 +1,33 @@ +package mockconn + +// import ( +// "context" + +// sqldb "github.com/domonda/go-sqldb" +// "github.com/domonda/go-sqldb/impl" +// ) + +// // NewSingleRowProvider a RowsProvider implementation +// // with a single row that will be re-used for every query. +// func NewSingleRowProvider(row *Row) RowsProvider { +// return &singleRowProvider{row: row, argFmt: DefaultArgFmt} +// } + +// // SingleRowProvider implements RowsProvider with a single Row +// // that will be re-used for every query. +// type singleRowProvider struct { +// row *Row +// argFmt string +// } + +// func (p *singleRowProvider) Query(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) (sqldb.Rows, error) { +// panic("TODO") +// } + +// func (p *singleRowProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner { +// return impl.NewRowScanner(impl.RowAsRows(p.row), structFieldNamer, query, p.argFmt, args) +// } + +// func (p *singleRowProvider) QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner { +// return impl.NewRowsScanner(context.Background(), NewRows(p.row), structFieldNamer, query, p.argFmt, args) +// } diff --git a/mockconn/transaction.go b/_mockconn/transaction.go similarity index 58% rename from mockconn/transaction.go rename to _mockconn/transaction.go index 5611e23..563fcc9 100644 --- a/mockconn/transaction.go +++ b/_mockconn/transaction.go @@ -1,8 +1,10 @@ package mockconn +/* import ( "context" "database/sql" + "errors" "fmt" "github.com/domonda/go-sqldb" @@ -14,32 +16,14 @@ type transaction struct { no uint64 } -func (conn transaction) Context() context.Context { return conn.connection.ctx } - -func (conn transaction) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.connection.ctx { - return conn - } - return transaction{ - connection: conn.connection.WithContext(ctx).(*connection), - opts: conn.opts, - no: conn.no, - } -} - -func (conn transaction) IsTransaction() bool { - return true +func (conn transaction) Transaction() (no uint64, opts *sql.TxOptions) { + return conn.no, conn.opts } -func (conn transaction) TransactionNo() uint64 { - return conn.no -} - -func (conn transaction) TransactionOptions() (*sql.TxOptions, bool) { - return conn.opts, true -} - -func (conn transaction) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { +func (conn transaction) Begin(ctx context.Context, no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { + if no == 0 { + return nil, errors.New("transaction number must not be zero") + } if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, "BEGIN") } @@ -71,3 +55,4 @@ func (conn transaction) UnlistenChannel(channel string) (err error) { func (conn transaction) Close() error { return conn.Rollback() } +*/ diff --git a/columinfo.go b/columinfo.go new file mode 100644 index 0000000..cc119e0 --- /dev/null +++ b/columinfo.go @@ -0,0 +1,12 @@ +package sqldb + +type ColumnInfo struct { + Name string + PrimaryKey bool + HasDefault bool + ReadOnly bool +} + +func (c *ColumnInfo) IsEmbeddedField() bool { + return c.Name == "" +} diff --git a/columnfilter.go b/columnfilter.go deleted file mode 100644 index a70ea89..0000000 --- a/columnfilter.go +++ /dev/null @@ -1,89 +0,0 @@ -package sqldb - -import ( - "reflect" -) - -type ColumnFilter interface { - IgnoreColumn(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool -} - -type ColumnFilterFunc func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool - -func (f ColumnFilterFunc) IgnoreColumn(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return f(name, flags, fieldType, fieldValue) -} - -func IgnoreColumns(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - for _, ignore := range names { - if name == ignore { - return true - } - } - return false - }) -} - -func OnlyColumns(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - for _, include := range names { - if name == include { - return false - } - } - return true - }) -} - -func IgnoreStructFields(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - for _, ignore := range names { - if fieldType.Name == ignore { - return true - } - } - return false - }) -} - -func OnlyStructFields(names ...string) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - for _, include := range names { - if fieldType.Name == include { - return false - } - } - return true - }) -} - -func IgnoreFlags(ignore FieldFlag) ColumnFilter { - return ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags&ignore != 0 - }) -} - -var IgnoreDefault ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags.Default() -}) - -var IgnorePrimaryKey ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags.PrimaryKey() -}) - -var IgnoreReadOnly ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags.ReadOnly() -}) - -var IgnoreNull ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return IsNull(fieldValue) -}) - -var IgnoreNullOrZero ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return IsNullOrZero(fieldValue) -}) - -var IgnoreNullOrZeroDefault ColumnFilter = ColumnFilterFunc(func(name string, flags FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - return flags.Default() && IsNullOrZero(fieldValue) -}) diff --git a/config.go b/config.go index 071fa7e..d657163 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,8 @@ import ( "database/sql" "fmt" "net/url" + "strconv" + "strings" "time" ) @@ -19,6 +21,12 @@ type Config struct { Database string `json:"database"` Extra map[string]string `json:"misc,omitempty"` + // ReadOnly sets the database connection to read-only mode + // if supported by the database connection. + // + // The default is false. + ReadOnly bool `json:"readOnly,omitempty"` + // MaxOpenConns sets the maximum number of open connections to the database. // // If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than @@ -47,18 +55,49 @@ type Config struct { // // If ConnMaxLifetime <= 0, connections are not closed due to a connection's age. ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"` +} - DefaultIsolationLevel sql.IsolationLevel `json:"-"` - Err error `json:"-"` +// ParseConfig parses a connection URI string and returns a Config. +// The URI must be in the format: +// +// driver://user:password@host:port/database?key=value&key2=value2 +// +// For example: +// +// postgres://user:password@localhost:5432/database?sslmode=disable +// +// See also [Config.String] +func ParseConfig(uri string) (*Config, error) { + parsed, err := url.Parse(uri) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(parsed.Port(), 10, 16) + if err != nil { + return nil, err + } + password, _ := parsed.User.Password() + config := &Config{ + Driver: parsed.Scheme, + Host: parsed.Hostname(), + Port: uint16(port), + User: parsed.User.Username(), + Password: password, + Database: strings.TrimPrefix(parsed.Path, "/"), + } + if vals := parsed.Query(); len(vals) > 0 { + config.Extra = make(map[string]string) + for key, val := range vals { + config.Extra[key] = val[0] + } + } + return config, nil } // Validate returns Config.Err if it is not nil // or an error if the Config does not have // a Driver, Host, or Database. func (c *Config) Validate() error { - if c.Err != nil { - return c.Err - } if c.Driver == "" { return fmt.Errorf("missing sqldb.Config.Driver") } @@ -71,47 +110,67 @@ func (c *Config) Validate() error { return nil } -// ConnectURL for connecting to a database -func (c *Config) ConnectURL() string { +// URL returns a [*url.URL] with the connection parameters +// for connecting to a database based on the Config. +func (c *Config) URL() *url.URL { extra := make(url.Values) for key, val := range c.Extra { extra.Add(key, val) } - u := url.URL{ + u := &url.URL{ Scheme: c.Driver, Host: c.Host, - Path: c.Database, + Path: "/" + c.Database, RawQuery: extra.Encode(), } if c.Port != 0 { - u.Host = fmt.Sprintf("%s:%d", c.Host, c.Port) + u.Host += fmt.Sprintf(":%d", c.Port) } if c.User != "" { u.User = url.UserPassword(c.User, c.Password) } - return u.String() + return u +} + +// String returns the connection URI string for the Config +// without the password and implements the [fmt.Stringer] interface. +// +// To get the full connection URI including the password use [Config.URL]. +// +// The returned string will not include the following fields: +// - Password +// - MaxOpenConns +// - MaxIdleConns +// - ConnMaxLifetime +// - DefaultIsolationLevel +// - Err +// +// See also [ParseConfig] +func (c *Config) String() string { + uri := c.URL() + uri.User = url.User(c.User) + return uri.String() } -// Connect opens a new sql.DB connection, +// Connect opens a new [sql.DB] connection, // sets all Config values and performs a ping with ctx. -// The sql.DB will be returned if the ping was successful. +// The [sql.DB] will be returned if the ping was successful. func (c *Config) Connect(ctx context.Context) (*sql.DB, error) { err := c.Validate() if err != nil { return nil, err } - db, err := sql.Open(c.Driver, c.ConnectURL()) + db, err := sql.Open(c.Driver, c.URL().String()) if err != nil { - return nil, err + return nil, fmt.Errorf("error opening database connection: %w", err) } db.SetMaxOpenConns(c.MaxOpenConns) db.SetMaxIdleConns(c.MaxIdleConns) db.SetConnMaxLifetime(c.ConnMaxLifetime) err = db.PingContext(ctx) if err != nil { - e := db.Close() - if e != nil { - err = fmt.Errorf("%w, then %s", err, e) + if e := db.Close(); e != nil { + err = fmt.Errorf("%w, then %w", err, e) } return nil, err } diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..8f4a813 --- /dev/null +++ b/config_test.go @@ -0,0 +1,44 @@ +package sqldb + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfigURL(t *testing.T) { + tests := []struct { + uri string + wantURIWithoutPassword string + want *Config + wantErr bool + }{ + { + uri: "postgres://user:password@localhost:5432/database?sslmode=disable", + wantURIWithoutPassword: "postgres://user@localhost:5432/database?sslmode=disable", + want: &Config{ + Driver: "postgres", + Host: "localhost", + Port: 5432, + User: "user", + Password: "password", + Database: "database", + Extra: map[string]string{"sslmode": "disable"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.uri, func(t *testing.T) { + got, err := ParseConfig(tt.uri) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + assert.Equal(t, tt.uri, got.URL().String(), "convertig back to URI should match original") + assert.Equal(t, tt.wantURIWithoutPassword, got.String(), "convertig back to URI without password") + }) + } +} diff --git a/connection.go b/connection.go index 054f43e..993eef5 100644 --- a/connection.go +++ b/connection.go @@ -14,111 +14,72 @@ type ( OnUnlistenFunc func(channel string) ) -// PlaceholderFormatter is an interface for formatting query parameter placeholders -// implemented by database connections. -type PlaceholderFormatter interface { - // Placeholder formats a query parameter placeholder - // for the paramIndex starting at zero. - Placeholder(paramIndex int) string +type TransactionState struct { + ID uint64 + Opts *sql.TxOptions } -// Connection represents a database connection or transaction -type Connection interface { - PlaceholderFormatter +func (ts TransactionState) Active() bool { + return ts.ID != 0 +} + +type Preparer interface { + // Prepare a statement for execution. + Prepare(ctx context.Context, query string) (Stmt, error) +} - // Context that all connection operations use. - // See also WithContext. - Context() context.Context +type Executor interface { + // Exec executes a query with optional args. + Exec(ctx context.Context, query string, args ...any) error +} - // WithContext returns a connection that uses the passed - // context for its operations. - WithContext(ctx context.Context) Connection +type Querier interface { + // Query queries rows with optional args. + // Any error will be returned by the Rows.Err method. + Query(ctx context.Context, query string, args ...any) Rows +} - // WithStructFieldMapper returns a copy of the connection - // that will use the passed StructFieldMapper. - WithStructFieldMapper(StructFieldMapper) Connection +// Connection represents a database connection or transaction +type Connection interface { + // Config returns the configuration used + // to create this connection. + Config() *Config - // StructFieldMapper used by methods of this Connection. - StructFieldMapper() StructFieldMapper + // Stats returns the sql.DBStats of this connection. + Stats() sql.DBStats // Ping returns an error if the database // does not answer on this connection // with an optional timeout. // The passed timeout has to be greater zero // to be considered. - Ping(timeout time.Duration) error + Ping(ctx context.Context, timeout time.Duration) error - // Stats returns the sql.DBStats of this connection. - Stats() sql.DBStats + // QueryFormatter has methods for formatting parts + // of a query dependent on the database driver. + QueryFormatter - // Config returns the configuration used - // to create this connection. - Config() *Config + Preparer + Executor + Querier - // ValidateColumnName returns an error - // if the passed name is not valid for a - // column of the connection's database. - ValidateColumnName(name string) error + // DefaultIsolationLevel returns the isolation level of the database + // driver that is used when no isolation level + // is specified when beginning a new transaction. + DefaultIsolationLevel() sql.IsolationLevel - // Exec executes a query with optional args. - Exec(query string, args ...any) error - - // Update table rows(s) with values using the where statement with passed in args starting at $1. - Update(table string, values Values, where string, args ...any) error - - // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 - // and returning a single row with the columns specified in returning argument. - UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner - - // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 - // and returning multiple rows with the columns specified in returning argument. - UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner - - // UpdateStruct updates a row in a table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // UpsertStruct upserts a row to table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - // If inserting conflicts on the primary key column(s), then an update is performed. - UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // QueryRow queries a single row and returns a RowScanner for the results. - QueryRow(query string, args ...any) RowScanner - - // QueryRows queries multiple rows and returns a RowsScanner for the results. - QueryRows(query string, args ...any) RowsScanner - - // IsTransaction returns if the connection is a transaction - IsTransaction() bool - - // TransactionNo returns the globally unique number of the transaction - // or zero if the connection is not a transaction. - // Implementations should use the package function NextTransactionNo - // to aquire a new number in a threadsafe way. - TransactionNo() uint64 - - // TransactionOptions returns the sql.TxOptions of the - // current transaction and true as second result value, - // or false if the connection is not a transaction. - TransactionOptions() (*sql.TxOptions, bool) + // Transaction returns the transaction state of the connection + Transaction() TransactionState // Begin a new transaction. // If the connection is already a transaction then a brand - // new transaction will begin on the parent's connection. - // The passed no will be returnd from the transaction's - // Connection.TransactionNo method. - // Implementations should use the package function NextTransactionNo - // to aquire a new number in a threadsafe way. - Begin(opts *sql.TxOptions, no uint64) (Connection, error) + // new transaction will begin based on the connection + // that started this transaction. + // The passed id and opts will be returned from the transaction's + // Connection.Transaction method as TransactionState. + // Implementations should use the function NextTransactionID + // to aquire a new ID in a threadsafe way. + Begin(ctx context.Context, id uint64, opts *sql.TxOptions) (Connection, error) // Commit the current transaction. // Returns ErrNotWithinTransaction if the connection @@ -130,6 +91,14 @@ type Connection interface { // is not within a transaction. Rollback() error + // Close the connection. + // Transactions will be rolled back. + Close() error +} + +type ListenerConnection interface { + Connection + // ListenOnChannel will call onNotify for every channel notification // and onUnlisten if the channel gets unlistened // or the listener connection gets closed for some reason. @@ -146,8 +115,4 @@ type Connection interface { // IsListeningOnChannel returns if a channel is listened to. IsListeningOnChannel(channel string) bool - - // Close the connection. - // Transactions will be rolled back. - Close() error } diff --git a/const.go b/const.go new file mode 100644 index 0000000..eb68a94 --- /dev/null +++ b/const.go @@ -0,0 +1,19 @@ +package sqldb + +import ( + "context" + "database/sql" + "database/sql/driver" + "reflect" + "time" +) + +var ( + typeOfError = reflect.TypeFor[error]() + typeOfContext = reflect.TypeFor[context.Context]() + typeOfSQLScanner = reflect.TypeFor[sql.Scanner]() + typeOfDriverValuer = reflect.TypeFor[driver.Valuer]() + typeOfTime = reflect.TypeFor[time.Time]() + typeOfByte = reflect.TypeFor[byte]() + typeOfByteSlice = reflect.TypeFor[[]byte]() +) diff --git a/db/columnfilter.go b/db/columnfilter.go deleted file mode 100644 index 386890e..0000000 --- a/db/columnfilter.go +++ /dev/null @@ -1,30 +0,0 @@ -package db - -import ( - "github.com/domonda/go-sqldb" -) - -var ( - IgnoreDefault = sqldb.IgnoreDefault - IgnorePrimaryKey = sqldb.IgnorePrimaryKey - IgnoreReadOnly = sqldb.IgnoreReadOnly - IgnoreNull = sqldb.IgnoreNull - IgnoreNullOrZero = sqldb.IgnoreNullOrZero - IgnoreNullOrZeroDefault = sqldb.IgnoreNullOrZeroDefault -) - -func IgnoreColumns(names ...string) sqldb.ColumnFilter { - return sqldb.IgnoreColumns(names...) -} - -func OnlyColumns(names ...string) sqldb.ColumnFilter { - return sqldb.OnlyColumns(names...) -} - -func IgnoreStructFields(names ...string) sqldb.ColumnFilter { - return sqldb.IgnoreStructFields(names...) -} - -func OnlyStructFields(names ...string) sqldb.ColumnFilter { - return sqldb.OnlyStructFields(names...) -} diff --git a/db/config.go b/db/config.go index bd6c1b4..478fa5e 100644 --- a/db/config.go +++ b/db/config.go @@ -2,11 +2,24 @@ package db import ( "context" - "errors" + "database/sql" + "database/sql/driver" + "reflect" + "time" "github.com/domonda/go-sqldb" ) +var ( + typeOfError = reflect.TypeFor[error]() + typeOfContext = reflect.TypeFor[context.Context]() + typeOfSQLScanner = reflect.TypeFor[sql.Scanner]() + typeOfDriverValuer = reflect.TypeFor[driver.Valuer]() + typeOfTime = reflect.TypeFor[time.Time]() + typeOfByte = reflect.TypeFor[byte]() + typeOfByteSlice = reflect.TypeFor[[]byte]() +) + var ( // Number of retries used for a SerializedTransaction // before it fails @@ -14,10 +27,95 @@ var ( ) var ( - globalConn = sqldb.ConnectionWithError( - context.Background(), - errors.New("database connection not initialized"), - ) - globalConnCtxKey int + defaultStructReflector sqldb.StructReflector = sqldb.NewTaggedStructReflector() + structReflectorCtxKey int +) + +func GetStructReflector(ctx context.Context) sqldb.StructReflector { + if r, ok := ctx.Value(&structReflectorCtxKey).(sqldb.StructReflector); ok { + return r + } + return defaultStructReflector +} + +func SetStructReflector(reflector sqldb.StructReflector) { + if reflector == nil { + panic("can't set nil StructReflector") + } + defaultStructReflector = reflector +} + +func ContextWithStructReflector(ctx context.Context, reflector sqldb.StructReflector) context.Context { + return context.WithValue(ctx, &structReflectorCtxKey, reflector) +} + +var ( + globalConn sqldb.Connection = sqldb.NewErrConn(sqldb.ErrNoDatabaseConnection) + globalConnCtxKey int + + queryBuilderFunc sqldb.QueryBuilderFunc = sqldb.DefaultQueryBuilder + queryBuilderFuncCtxKey int + serializedTransactionCtxKey int ) + +// SetConn sets the global connection that will be returned by [Conn] +// if there is no other connection in the context passed to [Conn]. +func SetConn(c sqldb.Connection) { + if c == nil { + panic("can't set nil sqldb.Connection") // Prefer to panic early + } + globalConn = c +} + +// Conn returns a non nil sqldb.Connection from ctx +// or the global connection set with SetConn. +func Conn(ctx context.Context) sqldb.Connection { + return ConnOr(ctx, globalConn) +} + +// ConnOr returns a non nil sqldb.Connection from ctx +// or the passed defaultConn. +func ConnOr(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connection { + if c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection); c != nil { + return c + } + return defaultConn +} + +// ContextWithConn returns a new context with the passed sqldb.Connection +// added as value so it can be retrieved again using [Conn]. +// Passing a nil connection causes [Conn] to return the global connection +// configured with [SetConn]. +func ContextWithConn(ctx context.Context, conn sqldb.Connection) context.Context { + return context.WithValue(ctx, &globalConnCtxKey, conn) +} + +// ContextWithGlobalConn returns a new context with the global connection +// added as value so it can be retrieved again using Conn(ctx). +func ContextWithGlobalConn(ctx context.Context) context.Context { + return ContextWithConn(ctx, globalConn) +} + +// Close the global connection that was configured with [SetConn]. +func Close() error { + return globalConn.Close() +} + +func SetQueryBuilderFunc(f sqldb.QueryBuilderFunc) { + if f == nil { + panic("can't set nil sqldb.QueryBuilderFunc") // Prefer to panic early + } + queryBuilderFunc = f +} + +func ContextWithQueryBuilderFunc(ctx context.Context, f sqldb.QueryBuilderFunc) context.Context { + return context.WithValue(ctx, &queryBuilderFuncCtxKey, f) +} + +func QueryBuilderFuncFromContext(ctx context.Context) sqldb.QueryBuilderFunc { + if f, _ := ctx.Value(&queryBuilderFuncCtxKey).(sqldb.QueryBuilderFunc); f != nil { + return f + } + return queryBuilderFunc +} diff --git a/db/conn.go b/db/conn.go deleted file mode 100644 index 281e248..0000000 --- a/db/conn.go +++ /dev/null @@ -1,54 +0,0 @@ -package db - -import ( - "context" - - "github.com/domonda/go-sqldb" -) - -// SetConn sets the global connection returned by Conn -// if there is no other connection in the context passed to Conn. -func SetConn(c sqldb.Connection) { - if c == nil { - panic("must not set nil sqldb.Connection") - } - globalConn = c -} - -// Conn returns a non nil sqldb.Connection from ctx -// or the global connection set with SetConn. -// The returned connection will use the passed context. -// See sqldb.Connection.WithContext -func Conn(ctx context.Context) sqldb.Connection { - return ConnDefault(ctx, globalConn) -} - -// ConnDefault returns a non nil sqldb.Connection from ctx -// or the passed defaultConn. -// The returned connection will use the passed context. -// See sqldb.Connection.WithContext -func ConnDefault(ctx context.Context, defaultConn sqldb.Connection) sqldb.Connection { - c, _ := ctx.Value(&globalConnCtxKey).(sqldb.Connection) - if c == nil { - c = defaultConn - } - if c.Context() == ctx { - return c - } - return c.WithContext(ctx) -} - -// ContextWithConn returns a new context with the passed sqldb.Connection -// added as value so it can be retrieved again using Conn(ctx). -// Passing a nil connection causes Conn(ctx) -// to return the global connection set with SetConn. -func ContextWithConn(ctx context.Context, conn sqldb.Connection) context.Context { - return context.WithValue(ctx, &globalConnCtxKey, conn) -} - -// IsTransaction indicates if the connection from the context, -// or the default connection if the context has none, -// is a transaction. -func IsTransaction(ctx context.Context) bool { - return Conn(ctx).IsTransaction() -} diff --git a/db/errors.go b/db/errors.go index 7b761cf..6719912 100644 --- a/db/errors.go +++ b/db/errors.go @@ -1,39 +1,16 @@ package db -import ( - "fmt" +import "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" -) - -// // WrapNonNilErrorWithQuery wraps non nil errors with a formatted query -// // if the error was not already wrapped with a query. -// // If the passed error is nil, then nil will be returned. -// func WrapNonNilErrorWithQuery(err error, query string, args []any, argFmt sqldb.PlaceholderFormatter) error { -// if err == nil { -// return nil -// } -// var wrapped errWithQuery -// if errors.As(err, &wrapped) { -// return err // already wrapped -// } -// return errWithQuery{err, query, args, argFmt} -// } - -func wrapErrorWithQuery(err error, query string, args []any, argFmt sqldb.PlaceholderFormatter) error { - return errWithQuery{err, query, args, argFmt} +// ReplaceErrNoRows returns the passed replacement error +// if errors.Is(err, sql.ErrNoRows), +// else the passed err is returned unchanged. +func ReplaceErrNoRows(err, replacement error) error { + return sqldb.ReplaceErrNoRows(err, replacement) } -type errWithQuery struct { - err error - query string - args []any - argFmt sqldb.PlaceholderFormatter -} - -func (e errWithQuery) Unwrap() error { return e.err } - -func (e errWithQuery) Error() string { - return fmt.Sprintf("%s from query: %s", e.err, impl.FormatQuery2(e.query, e.argFmt, e.args...)) +// IsOtherThanErrNoRows returns true if the passed error is not nil +// and does not unwrap to, or is sql.ErrNoRows. +func IsOtherThanErrNoRows(err error) bool { + return sqldb.IsOtherThanErrNoRows(err) } diff --git a/db/exec.go b/db/exec.go new file mode 100644 index 0000000..dc7caaa --- /dev/null +++ b/db/exec.go @@ -0,0 +1,18 @@ +package db + +import ( + "context" + + "github.com/domonda/go-sqldb" +) + +// Exec executes a query with optional args. +func Exec(ctx context.Context, query string, args ...any) error { + return sqldb.Exec(ctx, Conn(ctx), query, args...) +} + +// ExecStmt returns a function that can be used to execute a prepared statement +// with optional args. +func ExecStmt(ctx context.Context, query string) (execFunc func(ctx context.Context, args ...any) error, closeStmt func() error, err error) { + return sqldb.ExecStmt(ctx, Conn(ctx), query) +} diff --git a/impl/foreachrow.go b/db/foreachrow.go similarity index 75% rename from impl/foreachrow.go rename to db/foreachrow.go index ccaaf24..717f1d2 100644 --- a/impl/foreachrow.go +++ b/db/foreachrow.go @@ -1,27 +1,16 @@ -package impl +package db +/* import ( "context" - "database/sql" - "database/sql/driver" "fmt" "reflect" - "time" - sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb" ) -var ( - typeOfError = reflect.TypeFor[error]() - typeOfContext = reflect.TypeFor[context.Context]() - typeOfSQLScanner = reflect.TypeFor[sql.Scanner]() - typeOfDriverValuer = reflect.TypeFor[driver.Valuer]() - typeOfTime = reflect.TypeFor[time.Time]() - typeOfByte = reflect.TypeFor[byte]() - typeOfByteSlice = reflect.TypeFor[[]byte]() -) - -// ForEachRowCallFunc will call the passed callback with scanned values or a struct for every row. +// ForEachRowCallFunc returns a function that will call the +// passed callback with scanned values or a struct for every row. // If the callback function has a single struct or struct pointer argument, // then RowScanner.ScanStruct will be used per row, // else RowScanner.Scan will be used for all arguments of the callback. @@ -31,7 +20,7 @@ var ( // If a non nil error is returned from the callback, then this error // is returned immediately by this function without scanning further rows. // In case of zero rows, no error will be returned. -func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScanner) error, err error) { +func ForEachRowCallFunc(ctx context.Context, reflector StructReflector, callback any) (f func(sqldb.Row) error, err error) { val := reflect.ValueOf(callback) typ := val.Type() if typ.Kind() != reflect.Func { @@ -58,7 +47,7 @@ func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScan } switch t.Kind() { case reflect.Struct: - if t.Implements(typeOfSQLScanner) || reflect.PtrTo(t).Implements(typeOfSQLScanner) { + if t.Implements(typeOfSQLScanner) || reflect.PointerTo(t).Implements(typeOfSQLScanner) { continue } if structArg { @@ -76,14 +65,14 @@ func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScan return nil, fmt.Errorf("ForEachRowCall callback function result must be of type error: %s", typ) } - f = func(row sqldb.RowScanner) (err error) { + f = func(row sqldb.Row) (err error) { // First scan row scannedValPtrs := make([]any, typ.NumIn()-firstArg) for i := range scannedValPtrs { scannedValPtrs[i] = reflect.New(typ.In(firstArg + i)).Interface() } if structArg { - err = row.ScanStruct(scannedValPtrs[0]) + err = scanStruct(row, reflector, reflect.ValueOf(scannedValPtrs[0])) } else { err = row.Scan(scannedValPtrs...) } @@ -107,3 +96,4 @@ func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScan } return f, nil } +*/ diff --git a/impl/foreachrow_test.go b/db/foreachrow_test.go similarity index 98% rename from impl/foreachrow_test.go rename to db/foreachrow_test.go index 7509553..e3491ec 100644 --- a/impl/foreachrow_test.go +++ b/db/foreachrow_test.go @@ -1,4 +1,4 @@ -package impl +package db import ( "testing" diff --git a/db/insert.go b/db/insert.go index be33483..7675103 100644 --- a/db/insert.go +++ b/db/insert.go @@ -4,176 +4,143 @@ import ( "context" "fmt" "reflect" - "strings" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) -func writeInsertQuery(w *strings.Builder, table string, names []string, format sqldb.PlaceholderFormatter) { - fmt.Fprintf(w, `INSERT INTO %s(`, table) - for i, name := range names { - if i > 0 { - w.WriteByte(',') +// todo remove +func derefStruct(v reflect.Value) (reflect.Value, error) { + strct := v + for strct.Kind() == reflect.Ptr { + if strct.IsNil() { + return reflect.Value{}, fmt.Errorf("nil pointer %s", v.Type()) } - w.WriteByte('"') - w.WriteString(name) - w.WriteByte('"') + strct = strct.Elem() } - w.WriteString(`) VALUES(`) - for i := range names { - if i > 0 { - w.WriteByte(',') - } - w.WriteString(format.Placeholder(i)) + if strct.Kind() != reflect.Struct { + return reflect.Value{}, fmt.Errorf("expected struct or pointer to struct, but got %s", v.Type()) } - w.WriteByte(')') + return strct, nil } -func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) - } +// todo remove +func pkColumnsOfStruct(reflector sqldb.StructReflector, t reflect.Type) (columns []string, err error) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + column, ok := reflector.MapStructField(field) + if !ok { + continue + } - columns, _, vals = impl.ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - return columns, vals, nil + if column.Name == "" { + columnsEmbed, err := pkColumnsOfStruct(reflector, field.Type) + if err != nil { + return nil, err + } + columns = append(columns, columnsEmbed...) + } else if column.PrimaryKey { + // if err = conn.ValidateColumnName(column); err != nil { + // return nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) + // } + columns = append(columns, column.Name) + } + } + return columns, nil } // Insert a new row into table using the values. func Insert(ctx context.Context, table string, values sqldb.Values) error { - if len(values) == 0 { - return fmt.Errorf("Insert into table %s: no values", table) - } - conn := Conn(ctx) - - var query strings.Builder - names, vals := values.Sorted() - writeInsertQuery(&query, table, names, conn) - - err := conn.Exec(query.String(), vals...) - if err != nil { - return wrapErrorWithQuery(err, query.String(), vals, conn) - } - return nil + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + ) + return sqldb.Insert(ctx, conn, queryBuilder, table, values) } // InsertUnique inserts a new row into table using the passed values // or does nothing if the onConflict statement applies. // Returns if a row was inserted. func InsertUnique(ctx context.Context, table string, values sqldb.Values, onConflict string) (inserted bool, err error) { - if len(values) == 0 { - return false, fmt.Errorf("InsertUnique into table %s: no values", table) - } - conn := Conn(ctx) - - if strings.HasPrefix(onConflict, "(") && strings.HasSuffix(onConflict, ")") { - onConflict = onConflict[1 : len(onConflict)-1] - } - - var query strings.Builder - names, vals := values.Sorted() - writeInsertQuery(&query, table, names, conn) - fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - - err = conn.QueryRow(query.String(), vals...).Scan(&inserted) - err = sqldb.ReplaceErrNoRows(err, nil) - if err != nil { - return false, wrapErrorWithQuery(err, query.String(), vals, conn) - } - return inserted, err + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + ) + return sqldb.InsertUnique(ctx, conn, queryBuilder, table, values, onConflict) } -// InsertReturning inserts a new row into table using values -// and returns values from the inserted row listed in returning. -func InsertReturning(ctx context.Context, table string, values sqldb.Values, returning string) sqldb.RowScanner { - if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) - } - conn := Conn(ctx) - - var query strings.Builder - names, vals := values.Sorted() - writeInsertQuery(&query, table, names, conn) - query.WriteString(" RETURNING ") - query.WriteString(returning) - return conn.QueryRow(query.String(), vals...) // TODO wrap error with query -} - -// InsertStruct inserts a new row into table using the connection's -// StructFieldMapper to map struct fields to column names. +// // InsertReturning inserts a new row into table using values +// // and returns values from the inserted row listed in returning. +// func InsertReturning(ctx context.Context, table string, values Values, returning string) sqldb.RowScanner { +// if len(values) == 0 { +// return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) +// } +// conn := Conn(ctx) + +// var query strings.Builder +// names, vals := values.Sorted() +// err = writeInsert(&query, table, names, conn) +// query.WriteString(" RETURNING ") +// query.WriteString(returning) +// return conn.QueryRow(query.String(), vals...) // TODO wrap error with query +// } + +// InsertRowStruct inserts a new row into table. // Optional ColumnFilter can be passed to ignore mapped columns. -func InsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - conn := Conn(ctx) - columns, vals, err := insertStructValues(table, rowStruct, conn.StructFieldMapper(), ignoreColumns) - if err != nil { - return err - } - - var query strings.Builder - writeInsertQuery(&query, table, columns, conn) +func InsertRowStruct(ctx context.Context, rowStruct sqldb.StructWithTableName, options ...sqldb.QueryOption) error { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.InsertRowStruct(ctx, conn, queryBuilder, reflector, rowStruct, options...) +} - err = conn.Exec(query.String(), vals...) - if err != nil { - return wrapErrorWithQuery(err, query.String(), vals, conn) - } - return nil +func InsertRowStructStmt[S sqldb.StructWithTableName](ctx context.Context, options ...sqldb.QueryOption) (insertFunc func(ctx context.Context, rowStruct S) error, closeFunc func() error, err error) { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.InsertRowStructStmt[S](ctx, conn, queryBuilder, reflector, options...) } -// InsertUniqueStruct inserts a new row into table using the connection's -// StructFieldMapper to map struct fields to column names. +// func InsertStructStmt[S StructWithTableName](ctx context.Context, query string) (stmtFunc func(ctx context.Context, rowStruct S) error, closeFunc func() error, err error) { +// conn := Conn(ctx) +// stmt, err := conn.Prepare(ctx, query) +// if err != nil { +// return nil, nil, err +// } +// stmtFunc = func(ctx context.Context, rowStruct S) error { +// TODO +// if err != nil { +// return sqldb.WrapErrorWithQuery(err, query, args, conn) +// } +// return nil +// } +// return stmtFunc, stmt.Close, nil +// } + +// InsertUniqueRowStruct inserts a new row with unique private key. // Optional ColumnFilter can be passed to ignore mapped columns. // Does nothing if the onConflict statement applies -// and returns if a row was inserted. -func InsertUniqueStruct(ctx context.Context, table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { - conn := Conn(ctx) - columns, vals, err := insertStructValues(table, rowStruct, conn.StructFieldMapper(), ignoreColumns) - if err != nil { - return false, err - } - - if strings.HasPrefix(onConflict, "(") && strings.HasSuffix(onConflict, ")") { - onConflict = onConflict[1 : len(onConflict)-1] - } - - var query strings.Builder - writeInsertQuery(&query, table, columns, conn) - fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - - err = conn.QueryRow(query.String(), vals...).Scan(&inserted) - err = sqldb.ReplaceErrNoRows(err, nil) - if err != nil { - return false, wrapErrorWithQuery(err, query.String(), vals, conn) - } - return inserted, err +// and returns true if a row was inserted. +func InsertUniqueRowStruct(ctx context.Context, rowStruct sqldb.StructWithTableName, onConflict string, options ...sqldb.QueryOption) (inserted bool, err error) { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.InsertUniqueRowStruct(ctx, conn, queryBuilder, reflector, rowStruct, onConflict, options...) } -// InsertStructs inserts a slice or array of structs -// as new rows into table using the connection's -// StructFieldMapper to map struct fields to column names. +// InsertRowStructs inserts a slice structs +// as new rows into table using the DefaultStructReflector. // Optional ColumnFilter can be passed to ignore mapped columns. -// -// TODO optimized version with single query if possible -// split into multiple queries depending or maxArgs for query -func InsertStructs(ctx context.Context, table string, rowStructs any, ignoreColumns ...sqldb.ColumnFilter) error { - v := reflect.ValueOf(rowStructs) - if k := v.Type().Kind(); k != reflect.Slice && k != reflect.Array { - return fmt.Errorf("InsertStructs expects a slice or array as rowStructs, got %T", rowStructs) - } - numRows := v.Len() - return Transaction(ctx, func(ctx context.Context) error { - for i := 0; i < numRows; i++ { - err := InsertStruct(ctx, table, v.Index(i).Interface(), ignoreColumns...) - if err != nil { - return err - } - } - return nil - }) +func InsertRowStructs[S sqldb.StructWithTableName](ctx context.Context, rowStructs []S, options ...sqldb.QueryOption) error { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.InsertRowStructs(ctx, conn, queryBuilder, reflector, rowStructs, options...) } diff --git a/db/insert_test.go b/db/insert_test.go new file mode 100644 index 0000000..5c51ab5 --- /dev/null +++ b/db/insert_test.go @@ -0,0 +1,120 @@ +package db + +import ( + "context" + "database/sql" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/domonda/go-sqldb" +) + +func TestInsertRowStruct(t *testing.T) { + type Struct1 struct { + sqldb.TableName `db:"my_table"` + ID int `db:"id"` + Name string `db:"name"` + } + + tests := []struct { + name string + rowStruct sqldb.StructWithTableName + options []sqldb.QueryOption + conn *sqldb.MockConn + want sqldb.QueryRecordings + wantErr bool + }{ + { + name: "simple", + rowStruct: &Struct1{ + ID: 1, + Name: "test", + }, + conn: sqldb.NewMockConn("$", nil, os.Stdout), + want: sqldb.QueryRecordings{ + Execs: []sqldb.QueryData{ + {Query: "INSERT INTO my_table(id,name) VALUES($1,$2)", Args: []any{1, "test"}}, + }, + }, + }, + // Error cases + { + name: "TableName without name tag", + rowStruct: struct { + sqldb.TableName + ID int `db:"id"` + Name string `db:"name"` + }{}, + conn: sqldb.NewMockConn("$", nil, os.Stdout), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ContextWithConn(context.Background(), tt.conn) + err := InsertRowStruct(ctx, tt.rowStruct, tt.options...) + if tt.wantErr { + require.Error(t, err, "error from InsertStruct") + return + } + require.NoError(t, err, "error from InsertStruct") + require.Equal(t, tt.want, tt.conn.Recordings, "MockConn.Recordings") + }) + } +} + +func TestInsert(t *testing.T) { + timestamp := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + tests := []struct { + name string + table string + values sqldb.Values + conn *sqldb.MockConn + want sqldb.QueryRecordings + wantErr bool + }{ + { + name: "basic", + table: "public.my_table", + values: sqldb.Values{ + "id": 1, + "name": "Test", + "created_at": timestamp, + "updated_at": sql.NullTime{}, + }, + conn: sqldb.NewMockConn("$", nil, os.Stdout), + want: sqldb.QueryRecordings{ + Execs: []sqldb.QueryData{ + { + Query: `INSERT INTO public.my_table(created_at,id,name,updated_at) VALUES($1,$2,$3,$4)`, + Args: []any{timestamp, 1, "Test", sql.NullTime{}}, + }, + }, + }, + }, + + // Error cases + { + name: "no values", + table: "public.my_table", + values: sqldb.Values{}, + conn: sqldb.NewMockConn("$", nil, os.Stdout), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ContextWithConn(context.Background(), tt.conn) + err := Insert(ctx, tt.table, tt.values) + if tt.wantErr { + require.Error(t, err, "error from Insert") + return + } + require.NoError(t, err, "error from Insert") + require.Equal(t, tt.want, tt.conn.Recordings, "MockConn.Recordings") + }) + } +} diff --git a/db/listen.go b/db/listen.go new file mode 100644 index 0000000..3c82424 --- /dev/null +++ b/db/listen.go @@ -0,0 +1,44 @@ +package db + +import ( + "context" + "errors" + "fmt" + + "github.com/domonda/go-sqldb" +) + +// ListenOnChannel will call onNotify for every channel notification +// and onUnlisten if the channel gets unlistened +// or the listener connection gets closed for some reason. +// It is valid to pass nil for onNotify or onUnlisten to not get those callbacks. +// Note that the callbacks are called in sequence from a single go routine, +// so callbacks should offload long running or potentially blocking code to other go routines. +// Panics from callbacks will be recovered and logged. +func ListenOnChannel(ctx context.Context, channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) error { + conn, ok := Conn(ctx).(sqldb.ListenerConnection) + if !ok { + return fmt.Errorf("notifications %w", errors.ErrUnsupported) + } + return conn.ListenOnChannel(channel, onNotify, onUnlisten) +} + +// UnlistenChannel will stop listening on the channel. +// An error is returned, when the channel was not listened to +// or the listener connection is closed. +func UnlistenChannel(ctx context.Context, channel string) error { + conn, ok := Conn(ctx).(sqldb.ListenerConnection) + if !ok { + return fmt.Errorf("notifications %w", errors.ErrUnsupported) + } + return conn.UnlistenChannel(channel) +} + +// IsListeningOnChannel returns if a channel is listened to. +func IsListeningOnChannel(ctx context.Context, channel string) bool { + conn, ok := Conn(ctx).(sqldb.ListenerConnection) + if !ok { + return false + } + return conn.IsListeningOnChannel(channel) +} diff --git a/db/multirowscanner.go b/db/multirowscanner.go new file mode 100644 index 0000000..1b16f71 --- /dev/null +++ b/db/multirowscanner.go @@ -0,0 +1,156 @@ +package db + +/* + +// ScanRowsAsSlice scans all srcRows as slice into dest. +// +// The sqlRows must either have only one column compatible with the element type of the slice, +// or in case of multiple columns the slice element type must be a struct or struct pointer +// so that every column maps on exactly one struct field using the passed reflector. +// +// In case of single column rows, nil must be passed for reflector. +// +// The function closes the sqlRows. +// +// TODO two different functions for single column and multi column rows? +func ScanRowsAsSlice(ctx context.Context, sqlRows sqldb.Rows, reflector StructReflector, dest any) error { + defer sqlRows.Close() + + destVal := reflect.ValueOf(dest) + if destVal.Kind() != reflect.Ptr { + return fmt.Errorf("scan dest is not a pointer but %s", destVal.Type()) + } + if destVal.IsNil() { + return errors.New("scan dest is nil") + } + slice := destVal.Elem() + if slice.Kind() != reflect.Slice { + return fmt.Errorf("scan dest is not pointer to slice but %s", destVal.Type()) + } + sliceElemType := slice.Type().Elem() + + newSlice := reflect.MakeSlice(slice.Type(), 0, 32) + + for sqlRows.Next() { + if ctx.Err() != nil { + return ctx.Err() + } + + newSlice = reflect.Append(newSlice, reflect.Zero(sliceElemType)) + target := newSlice.Index(newSlice.Len() - 1).Addr() + if reflector != nil { + err := scanStruct(sqlRows, reflector, target) + if err != nil { + return err + } + } else { + err := sqlRows.Scan(target.Interface()) + if err != nil { + return err + } + } + } + if sqlRows.Err() != nil { + return sqlRows.Err() + } + + // Assign newSlice if there were no errors + if newSlice.Len() == 0 { + slice.SetLen(0) + } else { + slice.Set(newSlice) + } + + return nil +} + +// MultiRowScanner +type MultiRowScanner struct { + ctx context.Context // ctx is checked for every row and passed through to callbacks + rows sqldb.Rows + reflector StructReflector + argFmt sqldb.PlaceholderFormatter // for error wrapping + query string // for error wrapping + args []any // for error wrapping +} + +func NewMultiRowScanner(ctx context.Context, rows sqldb.Rows, reflector StructReflector, argFmt sqldb.PlaceholderFormatter, query string, args []any) *MultiRowScanner { + return &MultiRowScanner{ctx, rows, reflector, argFmt, query, args} +} + +func (s *MultiRowScanner) Columns() ([]string, error) { + cols, err := s.rows.Columns() + if err != nil { + return nil, wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + return cols, nil +} + +func (s *MultiRowScanner) ScanSlice(dest any) error { + err := ScanRowsAsSlice(s.ctx, s.rows, dest, nil) + if err != nil { + return wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + return nil +} + +// TODO is ScanStructSlice needed besides ScanSlice? +func (s *MultiRowScanner) ScanStructSlice(dest any) error { + err := ScanRowsAsSlice(s.ctx, s.rows, dest, s.reflector) + if err != nil { + return wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + return nil +} + +func (s *MultiRowScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { + cols, err := s.rows.Columns() + if err != nil { + return nil, err + } + if headerRow { + rows = [][]string{cols} + } + stringScannablePtrs := make([]any, len(cols)) + err = s.ForEachRow(func(rowScanner sqldb.RowScanner) error { + row := make([]string, len(cols)) + for i := range stringScannablePtrs { + stringScannablePtrs[i] = (*sqldb.StringScannable)(&row[i]) + } + err := rowScanner.Scan(stringScannablePtrs...) + if err != nil { + return err + } + rows = append(rows, row) + return nil + }) + return rows, err +} + +func (s *MultiRowScanner) ForEachRow(callback func(*RowScanner) error) (err error) { + defer func() { + err = errors.Join(err, s.rows.Close()) + err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) + }() + + for s.rows.Next() { + if s.ctx.Err() != nil { + return s.ctx.Err() + } + + err := callback(CurrentRowScanner{s.rows, s.reflector}) + if err != nil { + return err + } + } + return s.rows.Err() +} + +func (s *MultiRowScanner) ForEachRowCall(callback any) error { + forEachRowFunc, err := ForEachRowCallFunc(s.ctx, callback) + if err != nil { + return err + } + return s.ForEachRow(forEachRowFunc) +} +*/ diff --git a/db/query.go b/db/query.go index 791ca33..0848932 100644 --- a/db/query.go +++ b/db/query.go @@ -2,11 +2,6 @@ package db import ( "context" - "database/sql" - "errors" - "fmt" - "reflect" - "strings" "time" "github.com/domonda/go-sqldb" @@ -21,192 +16,123 @@ import ( // Useful for getting the timestamp of a // SQL transaction for use in Go code. func CurrentTimestamp(ctx context.Context) time.Time { - t, err := QueryValue[time.Time](ctx, "SELECT CURRENT_TIMESTAMP") + t, err := QueryRowValue[time.Time](ctx, `SELECT CURRENT_TIMESTAMP`) if err != nil { return time.Now() } return t } -// Exec executes a query with optional args. -func Exec(ctx context.Context, query string, args ...any) error { - return Conn(ctx).Exec(query, args...) +// QueryRow queries a single row and returns a Row for the results. +func QueryRow(ctx context.Context, query string, args ...any) *sqldb.Row { + var ( + conn = Conn(ctx) + reflector = GetStructReflector(ctx) + ) + return sqldb.QueryRow(ctx, conn, reflector, query, args...) } -// QueryRow queries a single row and returns a RowScanner for the results. -func QueryRow(ctx context.Context, query string, args ...any) sqldb.RowScanner { - return Conn(ctx).QueryRow(query, args...) +// QueryRowValue queries a single row mapped to the type T. +func QueryRowValue[T any](ctx context.Context, query string, args ...any) (val T, err error) { + var ( + conn = Conn(ctx) + reflector = GetStructReflector(ctx) + ) + return sqldb.QueryRowValue[T](ctx, conn, reflector, query, args...) } -// QueryRows queries multiple rows and returns a RowsScanner for the results. -func QueryRows(ctx context.Context, query string, args ...any) sqldb.RowsScanner { - return Conn(ctx).QueryRows(query, args...) +// QueryRowValueOr queries a single value of type T +// or returns the passed defaultVal in case of sql.ErrNoRows. +func QueryRowValueOr[T any](ctx context.Context, defaultVal T, query string, args ...any) (val T, err error) { + var ( + conn = Conn(ctx) + reflector = GetStructReflector(ctx) + ) + return sqldb.QueryRowValueOr[T](ctx, conn, reflector, defaultVal, query, args...) } -// QueryValue queries a single value of type T. -func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) - if err != nil { - return *new(T), err - } - return value, nil -} - -// QueryValueReplaceErrNoRows queries a single value of type T. -// In case of an sql.ErrNoRows error, errNoRows will be called -// and its result returned together with the default value for T. -func QueryValueReplaceErrNoRows[T any](ctx context.Context, errNoRows func() error, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) - if err != nil { - if errors.Is(err, sql.ErrNoRows) && errNoRows != nil { - return *new(T), errNoRows() - } - return *new(T), err - } - return value, nil -} - -// QueryValueOr queries a single value of type T -// or returns the passed defaultValue in case of sql.ErrNoRows. -func QueryValueOr[T any](ctx context.Context, defaultValue T, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return defaultValue, nil - } - return *new(T), err - } - return value, err -} - -// QueryRowStruct queries a row and scans it as struct. -func QueryRowStruct[S any](ctx context.Context, query string, args ...any) (row *S, err error) { - err = Conn(ctx).QueryRow(query, args...).ScanStruct(&row) - if err != nil { - return nil, err - } - return row, nil -} - -// QueryRowStructReplaceErrNoRows queries a row and scans it as struct. -// In case of an sql.ErrNoRows error, errNoRows will be called -// and its result returned as error together with nil as row. -func QueryRowStructReplaceErrNoRows[S any](ctx context.Context, errNoRows func() error, query string, args ...any) (row *S, err error) { - err = Conn(ctx).QueryRow(query, args...).ScanStruct(&row) - if err != nil { - if errors.Is(err, sql.ErrNoRows) && errNoRows != nil { - return nil, errNoRows() - } - return nil, err - } - return row, nil -} - -// QueryRowStructOrNil queries a row and scans it as struct -// or returns nil in case of sql.ErrNoRows. -func QueryRowStructOrNil[S any](ctx context.Context, query string, args ...any) (row *S, err error) { - err = Conn(ctx).QueryRow(query, args...).ScanStruct(&row) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, err - } - return row, nil +func QueryRowValueStmt[T any](ctx context.Context, query string) (queryFunc func(ctx context.Context, args ...any) (T, error), closeStmt func() error, err error) { + var ( + conn = Conn(ctx) + reflector = GetStructReflector(ctx) + ) + return sqldb.QueryRowValueStmt[T](ctx, conn, reflector, query) } -// GetRow uses the passed pkValue+pkValues to query a table row -// and scan it into a struct of type S that must have tagged fields +// ReadRowStruct uses the passed pkValue+pkValues to query a table row +// and scan it into a struct of type `*S` that must have tagged fields // with primary key flags to identify the primary key column names // for the passed pkValue+pkValues and a table name. -func GetRow[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { - // Using explicit first pkValue value - // to not be able to compile without any value - pkValues = append([]any{pkValue}, pkValues...) - t := reflect.TypeOf(row).Elem() - if t.Kind() != reflect.Struct { - return nil, fmt.Errorf("expected struct template type instead of %s", t) - } - conn := Conn(ctx) - table, pkColumns, err := pkColumnsOfStruct(conn, t) - if err != nil { - return nil, err - } - if len(pkColumns) != len(pkValues) { - return nil, fmt.Errorf("got %d primary key values, but struct %s has %d primary key fields", len(pkValues), t, len(pkColumns)) - } - var query strings.Builder - fmt.Fprintf(&query, `SELECT * FROM %s WHERE "%s" = $1`, table, pkColumns[0]) //#nosec G104 - for i := 1; i < len(pkColumns); i++ { - fmt.Fprintf(&query, ` AND "%s" = $%d`, pkColumns[i], i+1) //#nosec G104 - } - err = conn.QueryRow(query.String(), pkValues...).ScanStruct(&row) - if err != nil { - return nil, err - } - return row, nil +func ReadRowStruct[S sqldb.StructWithTableName](ctx context.Context, pkValue any, pkValues ...any) (S, error) { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.ReadRowStruct[S](ctx, conn, queryBuilder, reflector, pkValue, pkValues...) } -// GetRowOrNil uses the passed pkValue+pkValues to query a table row +// ReadRowStructOr uses the passed pkValue+pkValues to query a table row // and scan it into a struct of type S that must have tagged fields // with primary key flags to identify the primary key column names // for the passed pkValue+pkValues and a table name. // Returns nil as row and error if no row could be found with the // passed pkValue+pkValues. -func GetRowOrNil[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { - row, err = GetRow[S](ctx, pkValue, pkValues...) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, err - } - return row, nil +func ReadRowStructOr[S sqldb.StructWithTableName](ctx context.Context, defaultVal S, pkValue any, pkValues ...any) (S, error) { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.ReadRowStructOr[S](ctx, conn, queryBuilder, reflector, defaultVal, pkValue, pkValues...) } -func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) { - mapper := conn.StructFieldMapper() - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - fieldTable, column, flags, ok := mapper.MapStructField(field) - if !ok { - continue - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } +// QueryRowsAsSlice returns queried rows as slice of the generic type T +// using the passed reflector to scan column values as struct fields. +// QueryRowsAsSlice returns queried rows as slice of the generic type T. +func QueryRowsAsSlice[T any](ctx context.Context, query string, args ...any) (rows []T, err error) { + var ( + conn = Conn(ctx) + reflector = GetStructReflector(ctx) + ) + return sqldb.QueryRowsAsSlice[T](ctx, conn, reflector, query, args...) +} - if column == "" { - fieldTable, columnsEmbed, err := pkColumnsOfStruct(conn, field.Type) - if err != nil { - return "", nil, err - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } - columns = append(columns, columnsEmbed...) - } else if flags.PrimaryKey() { - if err = conn.ValidateColumnName(column); err != nil { - return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) - } - columns = append(columns, column) - } - } - return table, columns, nil +// QueryRowsAsStrings scans the query result into a table of strings +// where the first row is a header row with the column names. +// +// Byte slices will be interpreted as strings, +// nil (SQL NULL) will be converted to an empty string, +// all other types are converted with `fmt.Sprint`. +// +// If the query result has no rows, then only the header row +// and no error will be returned. +func QueryRowsAsStrings(ctx context.Context, query string, args ...any) (rows [][]string, err error) { + var ( + conn = Conn(ctx) + ) + return sqldb.QueryRowsAsStrings(ctx, conn, query, args...) } -// QueryStructSlice returns queried rows as slice of the generic type S -// which must be a struct or a pointer to a struct. -func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []S, err error) { - err = Conn(ctx).QueryRows(query, args...).ScanStructSlice(&rows) - if err != nil { - return nil, err - } - return rows, nil +// QueryCallback calls the passed callback +// with scanned values or a struct for every row. +// +// If the callback function has a single struct or struct pointer argument, +// then RowScanner.ScanStruct will be used per row, +// else RowScanner.Scan will be used for all arguments of the callback. +// If the function has a context.Context as first argument, +// then the passed ctx will be passed on. +// +// The callback can have no result or a single error result value. +// +// If a non nil error is returned from the callback, then this error +// is returned immediately by this function without scanning further rows. +// +// In case of zero rows, no error will be returned. +func QueryCallback(ctx context.Context, callback any, query string, args ...any) error { + var ( + conn = Conn(ctx) + reflector = GetStructReflector(ctx) + ) + return sqldb.QueryCallback(ctx, conn, reflector, callback, query, args...) } diff --git a/db/query_test.go b/db/query_test.go new file mode 100644 index 0000000..b9403d4 --- /dev/null +++ b/db/query_test.go @@ -0,0 +1,128 @@ +package db + +import ( + "context" + "database/sql" + "database/sql/driver" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/domonda/go-sqldb" +) + +func TestQueryValue(t *testing.T) { + query := /*sql*/ `SELECT EXISTS (SELECT FROM my_table WHERE id = $1)` + conn := sqldb.NewMockConn("$", nil, nil). + WithQueryResult( + []string{"exists"}, // columns + [][]driver.Value{{true}}, // rows + query, // query + 666, // args + ). + WithQueryResult( + []string{"exists"}, // columns + nil, // rows + query, // query + 777, // args + ) + ctx := ContextWithConn(context.Background(), conn) + + // id 666 has a row with the value true + value, err := QueryRowValue[bool](ctx, query, 666) + require.NoError(t, err) + require.Equal(t, true, value, "QueryValue[bool] result") + + // id 777 has no rows + value, err = QueryRowValue[bool](ctx, query, 777) + require.ErrorIs(t, err, sql.ErrNoRows, "QueryValue[bool] result for 777 is sql.ErrNoRows") +} + +func TestQueryValueOr(t *testing.T) { + query := /*sql*/ `SELECT EXISTS (SELECT FROM my_table WHERE id = $1)` + conn := sqldb.NewMockConn("$", nil, nil). + WithQueryResult( + []string{"exists"}, // columns + [][]driver.Value{{true}}, // rows + query, // query + 666, // args + ). + WithQueryResult( + []string{"exists"}, // columns + nil, // rows + query, // query + 777, // args + ) + ctx := ContextWithConn(context.Background(), conn) + + // id 666 has a row with the value true + value, err := QueryRowValueOr(ctx, false, query, 666) + require.NoError(t, err) + require.Equal(t, true, value, "QueryValueOr[bool] result for 666") + + // id 777 has no rows + value, err = QueryRowValueOr(ctx, false, query, 777) + require.NoError(t, err) + require.Equal(t, false, value, "QueryValueOr[bool] result for 777") +} + +func TestQueryStrings(t *testing.T) { + query := /*sql*/ `SELECT test_no, col1, col2, col3 FROM my_table WHERE test_no = $1` + tests := []struct { + name string + query string + args []any + wantRows [][]string + wantErr bool + }{ + { + name: "test_no 0: no rows", + query: query, + args: []any{0}, + wantRows: [][]string{ + {"test_no", "col1", "col2", "col3"}, + }, + }, + { + name: "test_no 1: 3 rows", + query: query, + args: []any{1}, + wantRows: [][]string{ + {"test_no", "col1", "col2", "col3"}, + {"1", "row0_col1", "row0_col2", "2025-01-02 03:04:05 +0000 UTC"}, + {"1", "row1_col1", "", "0001-01-01 00:00:00 +0000 UTC"}, + {"1", "row2_col1", "bytes", "2025-01-02 03:04:05 +0000 UTC"}, + }, + }, + } + conn := sqldb.NewMockConn("$", nil, nil). + WithQueryResult( + []string{"test_no", "col1", "col2", "col3"}, + [][]driver.Value{}, + query, + 0, + ). + WithQueryResult( + []string{"test_no", "col1", "col2", "col3"}, + [][]driver.Value{ + {int64(1), "row0_col1", "row0_col2", "2025-01-02 03:04:05 +0000 UTC"}, + {int64(1), "row1_col1", nil, time.Time{}}, + {int64(1), "row2_col1", []byte("bytes"), time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)}, + }, + query, + 1, + ) + ctx := ContextWithConn(context.Background(), conn) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotRows, err := QueryRowsAsStrings(ctx, tt.query, tt.args...) + if tt.wantErr { + require.Error(t, err, "QueryStrings() error") + return + } + require.NoError(t, err, "QueryStrings() error") + require.Equal(t, tt.wantRows, gotRows, "QueryStrings() result") + }) + } +} diff --git a/db/scanresult.go b/db/scanresult.go new file mode 100644 index 0000000..6da0d91 --- /dev/null +++ b/db/scanresult.go @@ -0,0 +1,54 @@ +package db + +// TODO move to RowScanner ? + +// // scanValues returns the values of a row exactly how they are +// // passed from the database driver to an sql.Scanner. +// // Byte slices will be copied. +// func scanValues(src sqldb.Rows) ([]any, error) { +// cols, err := src.Columns() +// if err != nil { +// return nil, err +// } +// var ( +// anys = make([]sqldb.AnyValue, len(cols)) +// result = make([]any, len(cols)) +// ) +// // result elements hold pointer to sqldb.AnyValue for scanning +// for i := range result { +// result[i] = &anys[i] +// } +// err = src.Scan(result...) +// if err != nil { +// return nil, err +// } +// // don't return pointers to sqldb.AnyValue +// // but what internal value has been scanned +// for i := range result { +// result[i] = anys[i].Val +// } +// return result, nil +// } + +// // scanStrings scans the values of a row as strings. +// // Byte slices will be interpreted as strings, +// // nil (SQL NULL) will be converted to an empty string, +// // all other types are converted with fmt.Sprint. +// func scanStrings(src sqldb.Rows) ([]string, error) { +// cols, err := src.Columns() +// if err != nil { +// return nil, err +// } +// var ( +// result = make([]string, len(cols)) +// resultPtrs = make([]any, len(cols)) +// ) +// for i := range resultPtrs { +// resultPtrs[i] = (*sqldb.StringScannable)(&result[i]) +// } +// err = src.Scan(resultPtrs...) +// if err != nil { +// return nil, err +// } +// return result, nil +// } diff --git a/db/statement.go b/db/statement.go new file mode 100644 index 0000000..9d8faf2 --- /dev/null +++ b/db/statement.go @@ -0,0 +1,39 @@ +package db + +import ( + "context" + + "github.com/domonda/go-sqldb" +) + +// Prepare a statement for execution +// with the given query string. +func Prepare(ctx context.Context, query string) (sqldb.Stmt, error) { + conn := Conn(ctx) + stmt, err := conn.Prepare(ctx, query) + if err != nil { + return nil, err + } + return stmtWithErrWrapping{stmt, conn}, nil +} + +type stmtWithErrWrapping struct { + sqldb.Stmt + fmt sqldb.QueryFormatter +} + +func (s stmtWithErrWrapping) Exec(ctx context.Context, args ...any) error { + err := s.Stmt.Exec(ctx, args...) + if err != nil { + return sqldb.WrapErrorWithQuery(err, s.PreparedQuery(), args, s.fmt) + } + return nil +} + +func (s stmtWithErrWrapping) Query(ctx context.Context, args ...any) sqldb.Rows { + rows := s.Stmt.Query(ctx, args...) + if rows.Err() != nil { + return sqldb.NewErrRows(sqldb.WrapErrorWithQuery(rows.Err(), s.PreparedQuery(), args, s.fmt)) + } + return rows +} diff --git a/db/testhelper.go b/db/testhelper.go new file mode 100644 index 0000000..c1c1248 --- /dev/null +++ b/db/testhelper.go @@ -0,0 +1,103 @@ +package db + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/domonda/go-sqldb" +) + +// TypeMapper is used to map Go types to SQL column types. +type TypeMapper interface { + // ColumnType returns the SQL column type for the given Go type. + ColumnType(reflect.Type) string +} + +// StagedTypeMapper is a TypeMapper that +// first tries to map a reflect.Type using the Types map, +// if that fails it tries to map the reflect.Kind using the Kinds map, +// and if that fails it calls Default function if it is not nil. +type StagedTypeMapper struct { + Types map[reflect.Type]string + Kinds map[reflect.Kind]string + Default func(reflect.Type) string +} + +func (m *StagedTypeMapper) ColumnType(t reflect.Type) string { + if columnType, ok := m.Types[t]; ok { + return columnType + } + if columnType, ok := m.Kinds[t.Kind()]; ok { + return columnType + } + if m.Default != nil { + return m.Default(t) + } + return "" +} + +// CreateTableForStruct is mostly used to create tests. +func CreateTableForStruct(ctx context.Context, typeMap TypeMapper, rowStruct sqldb.StructWithTableName) error { + conn := Conn(ctx) + v := reflect.ValueOf(rowStruct) + tableName, err := defaultStructReflector.TableNameForStruct(v.Type()) + if err != nil { + return err + } + tableName, err = conn.FormatTableName(tableName) + if err != nil { + return err + } + columns, fields := sqldb.ReflectStructColumnsAndFields(v, defaultStructReflector) + if len(columns) == 0 { + return fmt.Errorf("CreateTableForStruct %s: no columns at struct %T", tableName, rowStruct) + } + + var query strings.Builder + fmt.Fprintf(&query, "CREATE TABLE %s (\n ", tableName) + for i := range columns { + fieldType := fields[i] + columnName, err := conn.FormatColumnName(columns[i].Name) + if err != nil { + return err + } + columnType := typeMap.ColumnType(fieldType) + if columnType == "" { + return fmt.Errorf("CreateTableForStruct %s: no column type for field %s of type %s", tableName, columnName, fieldType) + } + if i > 0 { + query.WriteString(",\n ") + } + fmt.Fprint(&query, columnName, " ", columnType) + if columns[i].PrimaryKey { + query.WriteString(" PRIMARY KEY") + } else if !sqldb.IsNullable(fieldType) { + query.WriteString(" NOT NULL") + } + } + query.WriteString("\n)") + + return Exec(ctx, query.String()) +} + +// CreateTablesAndInsertStructs is mostly used to create tests. +func CreateTablesAndInsertStructs(ctx context.Context, typeMap TypeMapper, tables ...[]sqldb.StructWithTableName) error { + for _, rows := range tables { + if len(rows) == 0 { + continue + } + err := CreateTableForStruct(ctx, typeMap, rows[0]) + if err != nil { + return err + } + for _, row := range rows { + err := InsertRowStruct(ctx, row) + if err != nil { + return err + } + } + } + return nil +} diff --git a/db/testing.go b/db/testing.go new file mode 100644 index 0000000..8cacebf --- /dev/null +++ b/db/testing.go @@ -0,0 +1,19 @@ +package db + +import ( + "context" + "testing" + + "github.com/domonda/go-sqldb" +) + +// ContextWithNonConnectionForTest returns a new context with a sqldb.Connection +// intended for unit tests that should work without an actual database connection +// by mocking any SQL related functionality so that the connection won't be used. +// +// The transaction related methods of that connection +// simulate a transaction without any actual transaction handling. +// All other methods except Close will cause the test to fail. +func ContextWithNonConnectionForTest(ctx context.Context, t *testing.T) context.Context { + return ContextWithConn(ctx, sqldb.NonConnForTest(t)) +} diff --git a/db/transaction.go b/db/transaction.go index fb85419..e73e4ab 100644 --- a/db/transaction.go +++ b/db/transaction.go @@ -12,27 +12,45 @@ import ( "github.com/domonda/go-sqldb" ) -// ValidateWithinTransaction returns sqldb.ErrNotWithinTransaction +var noTransactionsCtxKey int + +func ContextWithoutTransactions(ctx context.Context) context.Context { + return context.WithValue(ctx, &noTransactionsCtxKey, struct{}{}) +} + +func IsContextWithoutTransactions(ctx context.Context) bool { + return ctx.Value(&noTransactionsCtxKey) != nil +} + +// IsTransaction indicates if the connection from the context, +// or the default connection if the context has none, +// is a transaction. +func IsTransaction(ctx context.Context) bool { + if IsContextWithoutTransactions(ctx) { + return false + } + return Conn(ctx).Transaction().Active() +} + +// ValidateWithinTransaction returns [sqldb.ErrNotWithinTransaction] // if the database connection from the context is not a transaction. func ValidateWithinTransaction(ctx context.Context) error { - conn := Conn(ctx) - if err := conn.Config().Err; err != nil { - return err + if IsContextWithoutTransactions(ctx) { + return sqldb.ErrNotWithinTransaction } - if !conn.IsTransaction() { + if !Conn(ctx).Transaction().Active() { return sqldb.ErrNotWithinTransaction } return nil } -// ValidateNotWithinTransaction returns sqldb.ErrWithinTransaction +// ValidateNotWithinTransaction returns [sqldb.ErrWithinTransaction] // if the database connection from the context is a transaction. func ValidateNotWithinTransaction(ctx context.Context) error { - conn := Conn(ctx) - if err := conn.Config().Err; err != nil { - return err + if IsContextWithoutTransactions(ctx) { + return nil } - if conn.IsTransaction() { + if Conn(ctx).Transaction().Active() { return sqldb.ErrWithinTransaction } return nil @@ -44,37 +62,93 @@ func DebugNoTransaction(ctx context.Context, nonTxFunc func(context.Context) err return nonTxFunc(ctx) } -// IsolatedTransaction executes txFunc within a database transaction that is passed in to txFunc as tx Connection. +// IsolatedTransaction executes txFunc within a database transaction that is passed in to txFunc as tx [Connection]. // IsolatedTransaction returns all errors from txFunc or transaction commit errors happening after txFunc. // If parentConn is already a transaction, a brand new transaction will begin on the parent's connection. // Errors and panics from txFunc will rollback the transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. +// Recovered panics are re-paniced and rollback errors after a panic are logged with [ErrLogger]. func IsolatedTransaction(ctx context.Context, txFunc func(context.Context) error) error { - return sqldb.IsolatedTransaction(Conn(ctx), nil, func(tx sqldb.Connection) error { + if IsContextWithoutTransactions(ctx) { + return txFunc(ctx) + } + return sqldb.IsolatedTransaction(ctx, Conn(ctx), nil, func(tx sqldb.Connection) error { return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } +// IsolatedTransactionResult executes txFunc within a database transaction that is passed in to txFunc as tx Connection. +// IsolatedTransactionResult returns all errors from txFunc or transaction commit errors happening after txFunc. +// If parentConn is already a transaction, a brand new transaction will begin on the parent's connection. +// Errors and panics from txFunc will rollback the transaction. +// Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. +func IsolatedTransactionResult[T any](ctx context.Context, txFunc func(context.Context) (T, error)) (result T, err error) { + err = IsolatedTransaction(ctx, func(ctx context.Context) error { + result, err = txFunc(ctx) + return err + }) + return result, err +} + // Transaction executes txFunc within a database transaction that is passed in to txFunc via the context. -// Use db.Conn(ctx) to get the transaction connection within txFunc. +// Use `db.Conn(ctx)` to get the transaction connection within txFunc. // Transaction returns all errors from txFunc or transaction commit errors happening after txFunc. -// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx sqldb.Connection +// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx [sqldb.Connection]. // and no parentConn.Begin, Commit, or Rollback calls will occour within this Transaction call. // Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. +// Recovered panics are re-paniced and rollback errors after a panic are logged with [sqldb.ErrLogger]. func Transaction(ctx context.Context, txFunc func(context.Context) error) error { - return sqldb.Transaction(Conn(ctx), nil, func(tx sqldb.Connection) error { + if IsContextWithoutTransactions(ctx) { + return txFunc(ctx) + } + return sqldb.Transaction(ctx, Conn(ctx), nil, func(tx sqldb.Connection) error { return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } +// TransactionResult executes txFunc within a database transaction and returns the result of txFunc. +// Use db.Conn(ctx) to get the transaction connection within txFunc. +// Transaction returns all errors from txFunc or transaction commit errors happening after txFunc. +// If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx sqldb.Connection +// and no parentConn.Begin, Commit, or Rollback calls will occour within this TransactionResult call. +// Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. +// Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. +func TransactionResult[T any](ctx context.Context, txFunc func(context.Context) (T, error)) (result T, err error) { + err = Transaction(ctx, func(ctx context.Context) error { + result, err = txFunc(ctx) + return err + }) + return result, err +} + +// OptionalTransaction executes txFunc within a database transaction if useTransaction is true. +// If useTransaction is false, then txFunc is executed without a transaction. +// +// See [Transaction] for more details. +func OptionalTransaction(ctx context.Context, useTransaction bool, txFunc func(context.Context) error) error { + if !useTransaction { + return txFunc(ctx) + } + return Transaction(ctx, txFunc) +} + +// OptionalTransactionResult executes txFunc within a database transaction if useTransaction is true. +// If useTransaction is false, then txFunc is executed without a transaction. +// +// See [TransactionResult] for more details. +func OptionalTransactionResult[T any](ctx context.Context, useTransaction bool, txFunc func(context.Context) (T, error)) (result T, err error) { + if !useTransaction { + return txFunc(ctx) + } + return TransactionResult(ctx, txFunc) +} + // SerializedTransaction executes txFunc "serially" within a database transaction that is passed in to txFunc via the context. // Use db.Conn(ctx) to get the transaction connection within txFunc. // Transaction returns all errors from txFunc or transaction commit errors happening after txFunc. // If parentConn is already a transaction, then it is passed through to txFunc unchanged as tx sqldb.Connection // and no parentConn.Begin, Commit, or Rollback calls will occour within this Transaction call. // Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. -// Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. +// Recovered panics are re-paniced and rollback errors after a panic are logged with [sqldb.ErrLogger]. // // Serialized transactions are typically necessary when an insert depends on a previous select within // the transaction, but that pre-insert select can't lock the table like it's possible with SELECT FOR UPDATE. @@ -104,8 +178,12 @@ func Transaction(ctx context.Context, txFunc func(context.Context) error) error // // Because of the retryable nature, please be careful with the size of the transaction and the retry cost. func SerializedTransaction(ctx context.Context, txFunc func(context.Context) error) error { + if IsContextWithoutTransactions(ctx) { + return txFunc(ctx) + } + // Pass nested serialized transactions through - if Conn(ctx).IsTransaction() { + if IsTransaction(ctx) { if ctx.Value(&serializedTransactionCtxKey) == nil { return errors.New("SerializedTransaction called from within a non-serialized transaction") } @@ -134,7 +212,10 @@ func SerializedTransaction(ctx context.Context, txFunc func(context.Context) err // Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. // Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. func TransactionOpts(ctx context.Context, opts *sql.TxOptions, txFunc func(context.Context) error) error { - return sqldb.Transaction(Conn(ctx), opts, func(tx sqldb.Connection) error { + if IsContextWithoutTransactions(ctx) { + return txFunc(ctx) + } + return sqldb.Transaction(ctx, Conn(ctx), opts, func(tx sqldb.Connection) error { return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } @@ -147,8 +228,11 @@ func TransactionOpts(ctx context.Context, opts *sql.TxOptions, txFunc func(conte // Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. // Recovered panics are re-paniced and rollback errors after a panic are logged with sqldb.ErrLogger. func TransactionReadOnly(ctx context.Context, txFunc func(context.Context) error) error { + if IsContextWithoutTransactions(ctx) { + return txFunc(ctx) + } opts := sql.TxOptions{ReadOnly: true} - return sqldb.Transaction(Conn(ctx), &opts, func(tx sqldb.Connection) error { + return sqldb.Transaction(ctx, Conn(ctx), &opts, func(tx sqldb.Connection) error { return txFunc(context.WithValue(ctx, &globalConnCtxKey, tx)) }) } @@ -165,8 +249,12 @@ func TransactionReadOnly(ctx context.Context, txFunc func(context.Context) error // Panics from txFunc are not recovered to rollback to the savepoint, // they should behandled by the parent Transaction function. func TransactionSavepoint(ctx context.Context, txFunc func(context.Context) error) error { + if IsContextWithoutTransactions(ctx) { + return txFunc(ctx) + } + conn := Conn(ctx) - if !conn.IsTransaction() { + if !conn.Transaction().Active() { // If not already in a transaction, then execute txFunc // within a as transaction instead of using savepoints: return Transaction(ctx, txFunc) @@ -176,14 +264,14 @@ func TransactionSavepoint(ctx context.Context, txFunc func(context.Context) erro if err != nil { return err } - err = conn.Exec("savepoint " + savepoint) + err = conn.Exec(ctx, "savepoint "+savepoint) if err != nil { return err } err = txFunc(ctx) if err != nil { - e := conn.Exec("rollback to " + savepoint) + e := conn.Exec(ctx, "rollback to "+savepoint) if e != nil && !errors.Is(e, sql.ErrTxDone) { // Double error situation, wrap err with e so it doesn't get lost err = fmt.Errorf("TransactionSavepoint error (%s) from rollback after error: %w", e, err) @@ -191,7 +279,7 @@ func TransactionSavepoint(ctx context.Context, txFunc func(context.Context) erro return err } - return conn.Exec("release savepoint " + savepoint) + return conn.Exec(ctx, "release savepoint "+savepoint) } func randomSavepoint() (string, error) { diff --git a/db/transaction_test.go b/db/transaction_test.go index a6e16ac..32e133f 100644 --- a/db/transaction_test.go +++ b/db/transaction_test.go @@ -1,5 +1,6 @@ package db +/* import ( "context" "errors" @@ -13,7 +14,7 @@ func TestSerializedTransaction(t *testing.T) { globalConn = mockconn.New(context.Background(), os.Stdout, nil) expectSerialized := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { + if !IsTransaction(ctx) { panic("not in transaction") } if ctx.Value(&serializedTransactionCtxKey) == nil { @@ -23,7 +24,7 @@ func TestSerializedTransaction(t *testing.T) { } expectSerializedWithError := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { + if !IsTransaction(ctx) { panic("not in transaction") } if ctx.Value(&serializedTransactionCtxKey) == nil { @@ -67,7 +68,7 @@ func TestTransaction(t *testing.T) { globalConn = mockconn.New(context.Background(), os.Stdout, nil) expectNonSerialized := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { + if !IsTransaction(ctx) { panic("not in transaction") } if ctx.Value(&serializedTransactionCtxKey) != nil { @@ -77,7 +78,7 @@ func TestTransaction(t *testing.T) { } expectNonSerializedWithError := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { + if !IsTransaction(ctx) { panic("not in transaction") } if ctx.Value(&serializedTransactionCtxKey) != nil { @@ -116,3 +117,4 @@ func TestTransaction(t *testing.T) { }) } } +*/ diff --git a/db/update.go b/db/update.go new file mode 100644 index 0000000..96c73a2 --- /dev/null +++ b/db/update.go @@ -0,0 +1,31 @@ +package db + +import ( + "context" + + "github.com/domonda/go-sqldb" +) + +// Update table rows(s) with values using the where statement with passed in args starting at $1. +func Update(ctx context.Context, table string, values sqldb.Values, where string, args ...any) error { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + ) + return sqldb.Update(ctx, conn, queryBuilder, table, values, where, args...) +} + +// UpdateStruct updates a row in a table using the exported fields +// of rowStruct which have a `db` tag that is not "-". +// If restrictToColumns are provided, then only struct fields with a `db` tag +// matching any of the passed column names will be used. +// The struct must have at least one field with a `db` tag value having a ",pk" suffix +// to mark primary key column(s). +func UpdateStruct(ctx context.Context, table string, rowStruct any, options ...sqldb.QueryOption) error { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.UpdateStruct(ctx, conn, queryBuilder, reflector, table, rowStruct, options...) +} diff --git a/db/upsert.go b/db/upsert.go new file mode 100644 index 0000000..f52dc4b --- /dev/null +++ b/db/upsert.go @@ -0,0 +1,36 @@ +package db + +import ( + "context" + + "github.com/domonda/go-sqldb" +) + +// UpsertStruct TODO +// If inserting conflicts on the primary key column(s), then an update is performed. +func UpsertStruct(ctx context.Context, rowStruct sqldb.StructWithTableName, options ...sqldb.QueryOption) error { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.UpsertStruct(ctx, conn, queryBuilder, reflector, rowStruct, options...) +} + +func UpsertStructStmt[S sqldb.StructWithTableName](ctx context.Context, options ...sqldb.QueryOption) (upsert func(ctx context.Context, rowStruct S) error, done func() error, err error) { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.UpsertStructStmt[S](ctx, conn, queryBuilder, reflector, options...) +} + +func UpsertStructs[S sqldb.StructWithTableName](ctx context.Context, rowStructs []S, options ...sqldb.QueryOption) error { + var ( + conn = Conn(ctx) + queryBuilder = QueryBuilderFuncFromContext(ctx)(conn) + reflector = GetStructReflector(ctx) + ) + return sqldb.UpsertStructs(ctx, conn, queryBuilder, reflector, rowStructs, options...) +} diff --git a/db/upsert_test.go b/db/upsert_test.go new file mode 100644 index 0000000..8553eac --- /dev/null +++ b/db/upsert_test.go @@ -0,0 +1,58 @@ +package db_test + +import ( + "context" + "os" + + "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/db" +) + +func ExampleUpsertStructStmt() { + type User struct { + sqldb.TableName `db:"public.user"` + ID int64 `db:"id,primarykey"` + Name string `db:"name"` + Email string `db:"email"` + } + + users := []User{ + {ID: 1, Name: "Alice", Email: "alice@example.com"}, + {ID: 2, Name: "Bob", Email: "bob@example.com"}, + {ID: 3, Name: "Charlie", Email: "charlie@example.com"}, + } + + conn := &sqldb.MockConn{ + QueryFormatter: sqldb.NewQueryFormatter("$"), + QueryLog: os.Stdout, + } + ctx := db.ContextWithConn(context.Background(), conn) + + err := db.Transaction(ctx, func(ctx context.Context) error { + upsert, done, err := db.UpsertStructStmt[User](ctx) + if err != nil { + return err + } + defer done() + + for _, user := range users { + err = upsert(ctx, user) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + panic(err) + } + + // Output: + // BEGIN; + // PREPARE stmt1 AS INSERT INTO public.user(id,name,email) VALUES($1,$2,$3) ON CONFLICT(id) DO UPDATE SET name=$2, email=$3; + // INSERT INTO public.user(id,name,email) VALUES(1,'Alice','alice@example.com') ON CONFLICT(id) DO UPDATE SET name='Alice', email='alice@example.com'; + // INSERT INTO public.user(id,name,email) VALUES(2,'Bob','bob@example.com') ON CONFLICT(id) DO UPDATE SET name='Bob', email='bob@example.com'; + // INSERT INTO public.user(id,name,email) VALUES(3,'Charlie','charlie@example.com') ON CONFLICT(id) DO UPDATE SET name='Charlie', email='charlie@example.com'; + // DEALLOCATE PREPARE stmt1; + // COMMIT; +} diff --git a/db/utils.go b/debug.go similarity index 79% rename from db/utils.go rename to debug.go index 3d2c210..0270765 100644 --- a/db/utils.go +++ b/debug.go @@ -1,4 +1,4 @@ -package db +package sqldb import ( "bytes" @@ -9,37 +9,21 @@ import ( "os" "time" "unicode/utf8" - - "github.com/domonda/go-sqldb" ) -// ReplaceErrNoRows returns the passed replacement error -// if errors.Is(err, sql.ErrNoRows), -// else the passed err is returned unchanged. -func ReplaceErrNoRows(err, replacement error) error { - return sqldb.ReplaceErrNoRows(err, replacement) -} - -// IsOtherThanErrNoRows returns true if the passed error is not nil -// and does not unwrap to, or is sql.ErrNoRows. -func IsOtherThanErrNoRows(err error) bool { - return sqldb.IsOtherThanErrNoRows(err) -} - // DebugPrintConn prints a line to stderr using the passed args // and appending the transaction state of the connection // and the current time of the database using `select now()` // or an error if the time could not be queried. -func DebugPrintConn(ctx context.Context, args ...any) { - opts, isTx := Conn(ctx).TransactionOptions() - if isTx { +func DebugPrintConn(ctx context.Context, conn Connection, args ...any) { + if tx := conn.Transaction(); tx.Active() { args = append(args, "SQL-Transaction") - if optsStr := TxOptionsString(opts); optsStr != "" { + if optsStr := TxOptionsString(tx.Opts); optsStr != "" { args = append(args, "Isolation", optsStr) } } var t time.Time - err := Conn(ctx).QueryRow("SELECT CURRENT_TIMESTAMP").Scan(&t) + err := conn.Query(ctx, "SELECT CURRENT_TIMESTAMP").Scan(&t) if err == nil { args = append(args, "CURRENT_TIMESTAMP:", t) } else { diff --git a/errconn.go b/errconn.go new file mode 100644 index 0000000..a7f15ef --- /dev/null +++ b/errconn.go @@ -0,0 +1,93 @@ +package sqldb + +import ( + "context" + "database/sql" + "time" +) + +// ErrConn implements ListenerConnection +var _ Connection = ErrConn{} + +// NewErrConn returns an ErrConn with the passed error. +func NewErrConn(err error) ErrConn { + if err == nil { + panic("NewErrConn expects non nil error") + } + return ErrConn{Err: err} +} + +// ErrConn is a dummy ListenerConnection +// where all methods except Close return Err. +type ErrConn struct { + StdQueryFormatter + Err error +} + +func (e ErrConn) Ping(context.Context, time.Duration) error { + return e.Err +} + +func (e ErrConn) Stats() sql.DBStats { + return sql.DBStats{} +} + +func (e ErrConn) Config() *Config { + return &Config{ + Driver: "error", + Host: "error", + Database: "error", + Extra: map[string]string{"error": e.Err.Error()}, + } +} + +func (e ErrConn) Exec(ctx context.Context, query string, args ...any) error { + return e.Err +} + +func (e ErrConn) Query(ctx context.Context, query string, args ...any) Rows { + return NewErrRows(e.Err) +} + +func (e ErrConn) Prepare(ctx context.Context, query string) (Stmt, error) { + return nil, e.Err +} + +func (e ErrConn) DefaultIsolationLevel() sql.IsolationLevel { + return sql.LevelDefault +} + +func (e ErrConn) Transaction() TransactionState { + return TransactionState{ + ID: 0, + Opts: nil, + } +} + +func (e ErrConn) Begin(ctx context.Context, id uint64, opts *sql.TxOptions) (Connection, error) { + return nil, e.Err +} + +func (e ErrConn) Commit() error { + return e.Err +} + +func (e ErrConn) Rollback() error { + return e.Err +} + +func (e ErrConn) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return e.Err +} + +func (e ErrConn) UnlistenChannel(channel string) error { + return e.Err +} + +func (e ErrConn) IsListeningOnChannel(channel string) bool { + return false +} + +func (e ErrConn) Close() error { + return nil +} diff --git a/errors.go b/errors.go index 3427ca8..753746d 100644 --- a/errors.go +++ b/errors.go @@ -1,17 +1,9 @@ package sqldb import ( - "context" "database/sql" "errors" "fmt" - "time" -) - -var ( - _ Connection = connectionWithError{} - _ RowScanner = rowScannerWithError{} - _ RowsScanner = rowsScannerWithError{} ) // ReplaceErrNoRows returns the passed replacement error @@ -45,6 +37,8 @@ func (s sentinelError) Error() string { // Transaction errors const ( + ErrNoDatabaseConnection sentinelError = "no database connection" + // ErrWithinTransaction is returned by methods // that are not allowed within DB transactions // when the DB connection is a transaction. @@ -167,196 +161,29 @@ func (e ErrExclusionViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } -// ConnectionWithError - -// ConnectionWithError returns a dummy Connection -// where all methods return the passed error. -func ConnectionWithError(ctx context.Context, err error) Connection { +// WrapErrorWithQuery wraps an errors with a formatted query +// if the error was not already wrapped with a query. +// If the passed error is nil, then nil will be returned. +func WrapErrorWithQuery(err error, query string, args []any, queryFmt QueryFormatter) error { if err == nil { - panic("ConnectionWithError needs an error") + return nil } - return connectionWithError{ctx, err} -} - -type connectionWithError struct { - ctx context.Context - err error -} - -func (e connectionWithError) Context() context.Context { return e.ctx } - -func (e connectionWithError) WithContext(ctx context.Context) Connection { - return connectionWithError{ctx: ctx, err: e.err} -} - -func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection { - return e -} - -func (e connectionWithError) StructFieldMapper() StructFieldMapper { - return DefaultStructFieldMapping -} - -func (e connectionWithError) Ping(time.Duration) error { - return e.err -} - -func (e connectionWithError) Stats() sql.DBStats { - return sql.DBStats{} -} - -func (e connectionWithError) Config() *Config { - return &Config{Err: e.err} -} - -func (e connectionWithError) Placeholder(paramIndex int) string { - return fmt.Sprintf("$%d", paramIndex+1) -} - -func (e connectionWithError) ValidateColumnName(name string) error { - return e.err -} - -func (e connectionWithError) Exec(query string, args ...any) error { - return e.err -} - -func (e connectionWithError) Update(table string, values Values, where string, args ...any) error { - return e.err -} - -func (e connectionWithError) UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) IsTransaction() bool { - return false -} - -func (e connectionWithError) TransactionNo() uint64 { - return 0 -} - -func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (e connectionWithError) Begin(opts *sql.TxOptions, no uint64) (Connection, error) { - return nil, e.err -} - -func (e connectionWithError) Commit() error { - return e.err -} - -func (e connectionWithError) Rollback() error { - return e.err -} - -func (e connectionWithError) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { - return e.err -} - -func (e connectionWithError) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { - return e.err -} - -func (e connectionWithError) UnlistenChannel(channel string) error { - return e.err -} - -func (e connectionWithError) IsListeningOnChannel(channel string) bool { - return false -} - -func (e connectionWithError) Close() error { - return e.err -} - -// RowScannerWithError - -// RowScannerWithError returns a dummy RowScanner -// where all methods return the passed error. -func RowScannerWithError(err error) RowScanner { - return rowScannerWithError{err} -} - -type rowScannerWithError struct { - err error -} - -func (e rowScannerWithError) Scan(dest ...any) error { - return e.err -} - -func (e rowScannerWithError) ScanStruct(dest any) error { - return e.err -} - -func (e rowScannerWithError) ScanValues() ([]any, error) { - return nil, e.err -} - -func (e rowScannerWithError) ScanStrings() ([]string, error) { - return nil, e.err -} - -func (e rowScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -// RowsScannerWithError - -// RowsScannerWithError returns a dummy RowsScanner -// where all methods return the passed error. -func RowsScannerWithError(err error) RowsScanner { - return rowsScannerWithError{err} -} - -type rowsScannerWithError struct { - err error -} - -func (e rowsScannerWithError) ScanSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) ScanStructSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) Columns() ([]string, error) { - return nil, e.err + var wrapped errWithQuery + if errors.As(err, &wrapped) { + return err // already wrapped + } + return errWithQuery{err, query, args, queryFmt} } -func (e rowsScannerWithError) ScanAllRowsAsStrings(headerRow bool) ([][]string, error) { - return nil, e.err +type errWithQuery struct { + err error + query string + args []any + queryFmt QueryFormatter } -func (e rowsScannerWithError) ForEachRow(callback func(RowScanner) error) error { - return e.err -} +func (e errWithQuery) Unwrap() error { return e.err } -func (e rowsScannerWithError) ForEachRowCall(callback any) error { - return e.err +func (e errWithQuery) Error() string { + return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.queryFmt, e.query, e.args...)) } diff --git a/impl/errors_test.go b/errors_test.go similarity index 52% rename from impl/errors_test.go rename to errors_test.go index 0582e77..7e53269 100644 --- a/impl/errors_test.go +++ b/errors_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql" @@ -7,12 +7,12 @@ import ( "testing" ) -func TestWrapNonNilErrorWithQuery(t *testing.T) { +func TestWrapErrorWithQuery(t *testing.T) { type args struct { - err error - query string - argFmt string - args []any + err error + query string + args []any + queryFmt QueryFormatter } tests := []struct { name string @@ -23,12 +23,12 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { { name: "select no rows", args: args{ - err: sql.ErrNoRows, - query: `SELECT * FROM table WHERE b = $2 and a = $1`, - argFmt: "$%d", - args: []any{1, "2"}, + err: sql.ErrNoRows, + query: `SELECT * FROM table WHERE b = $2 AND a = $1`, + queryFmt: StdQueryFormatter{PlaceholderPosPrefix: "$"}, + args: []any{1, "2"}, }, - wantError: fmt.Sprintf("%s from query: %s", sql.ErrNoRows, `SELECT * FROM table WHERE b = '2' and a = 1`), + wantError: fmt.Sprintf("%s from query: %s", sql.ErrNoRows, `SELECT * FROM table WHERE b = '2' AND a = 1`), }, { name: "multi line", @@ -38,9 +38,9 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { SELECT * FROM table WHERE b = $2 - and a = $1`, - argFmt: "$%d", - args: []any{1, "2"}, + AND a = $1`, + queryFmt: StdQueryFormatter{PlaceholderPosPrefix: "$"}, + args: []any{1, "2"}, }, wantError: fmt.Sprintf( "%s from query: %s", @@ -48,18 +48,18 @@ func TestWrapNonNilErrorWithQuery(t *testing.T) { `SELECT * FROM table WHERE b = '2' - and a = 1`, + AND a = 1`, ), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := WrapNonNilErrorWithQuery(tt.args.err, tt.args.query, tt.args.argFmt, tt.args.args) + err := WrapErrorWithQuery(tt.args.err, tt.args.query, tt.args.args, tt.args.queryFmt) if tt.wantError == "" && err != nil || tt.wantError != "" && (err == nil || err.Error() != tt.wantError) { - t.Errorf("WrapNonNilErrorWithQuery() error = \n%s\nwantErr\n%s", err, tt.wantError) + t.Errorf("WrapErrorWithQuery() error = \n%s\nwantErr\n%s", err, tt.wantError) } if !errors.Is(err, tt.args.err) { - t.Errorf("WrapNonNilErrorWithQuery() error = %v does not wrap %v", err, tt.args.err) + t.Errorf("WrapErrorWithQuery() error = %v does not wrap %v", err, tt.args.err) } }) } diff --git a/examples/user_demo/go.mod b/examples/user_demo/go.mod index 93ce6b3..621ee78 100644 --- a/examples/user_demo/go.mod +++ b/examples/user_demo/go.mod @@ -1,33 +1,41 @@ module github.com/domonda/go-sqldb/examples/user_demo -go 1.23 +go 1.24 -toolchain go1.23.1 - -replace github.com/domonda/go-sqldb => ../.. +replace ( + github.com/domonda/go-sqldb => ../.. + github.com/domonda/go-sqldb/pqconn => ../../pqconn +) -require github.com/domonda/go-sqldb v0.0.0-00010101000000-000000000000 // replaced +require ( + github.com/domonda/go-sqldb v0.0.0-00010101000000-000000000000 // replaced + github.com/domonda/go-sqldb/pqconn v0.0.0-00010101000000-000000000000 // replaced +) -require github.com/domonda/go-types v0.0.0-20240924082825-270782de7296 +require github.com/domonda/go-types v0.0.0-20250725104804-d473d8b9dd16 require ( + github.com/DataDog/go-sqllexer v0.1.6 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cention-sany/utf7 v0.0.0-20170124080048-26cad61bd60a // indirect - github.com/domonda/go-errs v0.0.0-20240702051036-0e696c849b5f // indirect - github.com/domonda/go-pretty v0.0.0-20240110134850-17385799142f // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/domonda/go-errs v0.0.0-20250603150208-71d6de0c48ea // indirect + github.com/domonda/go-pretty v0.0.0-20250602142956-1b467adc6387 // indirect github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 // indirect github.com/jhillyerd/enmime v1.3.0 // indirect github.com/lib/pq v1.10.9 // indirect + github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect github.com/teamwork/tnef v0.0.0-20200108124832-7deabccfdb32 // indirect - github.com/ungerik/go-fs v0.0.0-20240919125757-1b6f933a416d // indirect - golang.org/x/net v0.29.0 // indirect - golang.org/x/sys v0.25.0 // indirect - golang.org/x/text v0.18.0 // indirect - mvdan.cc/xurls/v2 v2.5.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + golang.org/x/net v0.41.0 // indirect + golang.org/x/text v0.26.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + mvdan.cc/xurls/v2 v2.6.0 // indirect ) diff --git a/examples/user_demo/go.sum b/examples/user_demo/go.sum index e3fef5e..16ad65e 100644 --- a/examples/user_demo/go.sum +++ b/examples/user_demo/go.sum @@ -1,36 +1,40 @@ +github.com/DataDog/go-sqllexer v0.1.6 h1:skEXpWEVCpeZFIiydoIa2f2rf+ymNpjiIMqpW4w3YAk= +github.com/DataDog/go-sqllexer v0.1.6/go.mod h1:GGpo1h9/BVSN+6NJKaEcJ9Jn44Hqc63Rakeb+24Mjgo= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cention-sany/utf7 v0.0.0-20170124080048-26cad61bd60a h1:MISbI8sU/PSK/ztvmWKFcI7UGb5/HQT7B+i3a2myKgI= github.com/cention-sany/utf7 v0.0.0-20170124080048-26cad61bd60a/go.mod h1:2GxOXOlEPAMFPfp014mK1SWq8G8BN8o7/dfYqJrVGn8= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/domonda/go-errs v0.0.0-20240301142737-8fde935c9bd4 h1:qidwzgjM8qrKy326iXVNHNN/qB89o1lfiAi7pMuNbQU= -github.com/domonda/go-errs v0.0.0-20240301142737-8fde935c9bd4/go.mod h1:NnvsIo+bzAany1nQLMViGDgJ8kx3k5N/D1+UJz3hEXc= -github.com/domonda/go-errs v0.0.0-20240702051036-0e696c849b5f h1:9CZgqCVP/7eixUjU+A+ozHo+oxRKJSkFgRtakoB5byc= -github.com/domonda/go-errs v0.0.0-20240702051036-0e696c849b5f/go.mod h1:qLWt1z3aIg12+Dbxu9bMydFOHEi92vWE7vAHcHLd8n8= -github.com/domonda/go-pretty v0.0.0-20240110134850-17385799142f h1:5eA74m451PqlqCXyJzWXp95Quj4PZ6Lm/ndKBuiNhe4= -github.com/domonda/go-pretty v0.0.0-20240110134850-17385799142f/go.mod h1:3QkM8UJdyJMeKZiIo7hYzSkQBpRS3k0gOHw4ysyEIB4= -github.com/domonda/go-types v0.0.0-20240309180027-0196423b3d5b h1:KHVyZmDdpN2PsjoEeBfxX9F1AvIbZubbsrFoB0wyVn4= -github.com/domonda/go-types v0.0.0-20240309180027-0196423b3d5b/go.mod h1:iLpZ3myjpxZgM1Q8Z+Jg8WDCzHjuVj5U3WZyh+2QBac= -github.com/domonda/go-types v0.0.0-20240924082825-270782de7296 h1:9NOpTrmdnxFMdZft6idPmaqcIyxbKZYl7Ija7bFaric= -github.com/domonda/go-types v0.0.0-20240924082825-270782de7296/go.mod h1:QfZG5NrNWDrwcqOp3ZlNh2XaLjZI1ncNpGPAa9MIUUE= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/domonda/go-errs v0.0.0-20250603150208-71d6de0c48ea h1:jJkN+JvDKnzxM0yu+ob0sOLCyN95gevMeYF5VBKDg6w= +github.com/domonda/go-errs v0.0.0-20250603150208-71d6de0c48ea/go.mod h1:d1vM8jnNOby2gJSsbnCYPE/WadNbdxHTCE0sDUTMVSs= +github.com/domonda/go-pretty v0.0.0-20250602142956-1b467adc6387 h1:ZSMYEHfFwpMlVJ+yzPXOSOfikWBNdzcnC0YxxNQxkDk= +github.com/domonda/go-pretty v0.0.0-20250602142956-1b467adc6387/go.mod h1:3QkM8UJdyJMeKZiIo7hYzSkQBpRS3k0gOHw4ysyEIB4= +github.com/domonda/go-types v0.0.0-20250725104804-d473d8b9dd16 h1:Vf9f93nsItPIaLPD2/vjsMmSakEjdkkMPEJK6zJv1vg= +github.com/domonda/go-types v0.0.0-20250725104804-d473d8b9dd16/go.mod h1:5esmMaEB57phklyiGu9a9/ttw338cZBZDCcxoO8A7kY= github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg= github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f h1:3BSP1Tbs2djlpprl7wCLuiqMaUh5SJkkzI2gDs+FgLs= github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f/go.mod h1:Pcatq5tYkCW2Q6yrR2VRHlbHpZ/R4/7qyL1TCF7vl14= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 h1:iCHtR9CQyktQ5+f3dMVZfwD2KWJUgm7M0gdL9NGr8KA= github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056/go.mod h1:CVKlgaMiht+LXvHG173ujK6JUhZXKb2u/BQtjPDIvyk= -github.com/jhillyerd/enmime v1.2.0 h1:dIu1IPEymQgoT2dzuB//ttA/xcV40NMPpQtmd4wslHk= -github.com/jhillyerd/enmime v1.2.0/go.mod h1:FRFuUPCLh8PByQv+8xRcLO9QHqaqTqreYhopv5eyk4I= github.com/jhillyerd/enmime v1.3.0 h1:LV5kzfLidiOr8qRGIpYYmUZCnhrPbcFAnAFUnWn99rw= github.com/jhillyerd/enmime v1.3.0/go.mod h1:6c6jg5HdRRV2FtvVL69LjiX1M8oE0xDX9VEhV3oy4gs= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= @@ -42,35 +46,30 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/teamwork/test v0.0.0-20200108114543-02621bae84ad h1:25sEr0awm0ZPancg5W5H5VvN7PWsJloUBpii10a9isw= github.com/teamwork/test v0.0.0-20200108114543-02621bae84ad/go.mod h1:TIbx7tx6WHBjQeLRM4eWQZBL7kmBZ7/KI4x4v7Y5YmA= github.com/teamwork/tnef v0.0.0-20200108124832-7deabccfdb32 h1:j15wq0XPAY/HR/0+dtwUrIrF2ZTKbk7QIES2p4dAG+k= github.com/teamwork/tnef v0.0.0-20200108124832-7deabccfdb32/go.mod h1:v7dFaQrF/4+curx7UTH9rqTkHTgXqghfI3thANW150o= github.com/teamwork/utils v0.0.0-20220314153103-637fa45fa6cc h1:BidxxRk9kopF5IGEyosTRtanaYVYTUbGJh9eULOhv04= github.com/teamwork/utils v0.0.0-20220314153103-637fa45fa6cc/go.mod h1:3Fn0qxFeRNpvsg/9T1+btOOOKkd1qG2nPYKKcOmNpcs= -github.com/ungerik/go-fs v0.0.0-20240118121925-91844f9bdba8 h1:LkAUtMadwzxaMYrdOpWlPJ4jdquUl5xafd0cQwRPqVw= -github.com/ungerik/go-fs v0.0.0-20240118121925-91844f9bdba8/go.mod h1:uJoyhNruti7dh2/DTNIF+N8s/sCd9uIhCBT8gzk6190= -github.com/ungerik/go-fs v0.0.0-20240919125757-1b6f933a416d h1:71JniF82NUc6v7nBx23OMSzdYiV5phxvTIU8XsRMdnU= -github.com/ungerik/go-fs v0.0.0-20240919125757-1b6f933a416d/go.mod h1:nMIa35zyLzk4K3tTLL+AAsOZ9Q+0lgX/lxYubEwCZSY= -github.com/ungerik/go-reflection v0.0.0-20240110134735-61cada706fec h1:QiS/w0cXNtHs0xhs+Pa2Pp71CTeM9z7zVgbxV+CvezM= -github.com/ungerik/go-reflection v0.0.0-20240110134735-61cada706fec/go.mod h1:6mOx6LfN4Xbb4fyHO6syugCjbx88cgpbxekcx4W1mpM= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +github.com/ungerik/go-reflection v0.0.0-20250602142243-03da83aecd0d h1:ctOx9QLFjuGij9QUMk3XoJWnbeC/O8kR8SRRNK9TK9U= +github.com/ungerik/go-reflection v0.0.0-20250602142243-03da83aecd0d/go.mod h1:2HaymCMIvGNYIy+2JDI9RdPytWuP/Q8fJSGcS+2mb20= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8= -mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE= +mvdan.cc/xurls/v2 v2.6.0 h1:3NTZpeTxYVWNSokW3MKeyVkz/j7uYXYiMtXRUfmjbgI= +mvdan.cc/xurls/v2 v2.6.0/go.mod h1:bCvEZ1XvdA6wDnxY7jPPjEmigDtvtvPXAD/Exa9IMSk= diff --git a/examples/user_demo/user_demo.go b/examples/user_demo/user_demo.go index b543f2c..a244282 100644 --- a/examples/user_demo/user_demo.go +++ b/examples/user_demo/user_demo.go @@ -16,7 +16,9 @@ import ( ) type User struct { - ID uu.ID `db:"id,pk,default"` + sqldb.TableName `db:"public.user"` + + ID uu.ID `db:"id,primarykey,default"` Email email.NullableAddress `db:"email"` Title nullable.NonEmptyString `db:"title"` @@ -38,14 +40,14 @@ func main() { Extra: map[string]string{"sslmode": "disable"}, } - fmt.Println("Connecting to:", config.ConnectURL()) + fmt.Println("Connecting to:", config) conn, err := pqconn.New(context.Background(), config) if err != nil { panic(err) } - conn = conn.WithStructFieldMapper(&sqldb.TaggedStructFieldMapping{ + conn = conn.WithStructFieldMapper(&sqldb.TaggedStructReflector{ NameTag: "col", Ignore: "ignore", UntaggedNameFunc: sqldb.ToSnakeCase, @@ -88,7 +90,7 @@ func main() { } newUser := &User{ /* ... */ } - err = conn.InsertStruct("public.user", newUser) + err = db.InsertRowStruct(ctx, newUser) if err != nil { panic(err) } @@ -98,7 +100,7 @@ func main() { panic(err) } - err = conn.Insert("public.user", sqldb.Values{ + err = conn.InsertStruct(sqldb.Values{ "name": "Erik Unger", "email": "erik@domonda.com", }) diff --git a/exec.go b/exec.go new file mode 100644 index 0000000..2167b41 --- /dev/null +++ b/exec.go @@ -0,0 +1,29 @@ +package sqldb + +import "context" + +// Exec executes a query with optional args. +func Exec(ctx context.Context, conn Executor, query string, args ...any) error { + err := conn.Exec(ctx, query, args...) + if err != nil { + return WrapErrorWithQuery(err, query, args, GetQueryFormatter(conn)) + } + return nil +} + +// ExecStmt returns a function that can be used to execute a prepared statement +// with optional args. +func ExecStmt(ctx context.Context, conn Preparer, query string) (execFunc func(ctx context.Context, args ...any) error, closeStmt func() error, err error) { + stmt, err := conn.Prepare(ctx, query) + if err != nil { + return nil, nil, WrapErrorWithQuery(err, query, nil, GetQueryFormatter(conn)) + } + execFunc = func(ctx context.Context, args ...any) error { + err := stmt.Exec(ctx, args...) + if err != nil { + return WrapErrorWithQuery(err, query, args, GetQueryFormatter(conn)) + } + return nil + } + return execFunc, stmt.Close, nil +} diff --git a/impl/format.go b/format.go similarity index 70% rename from impl/format.go rename to format.go index 98dbdc3..6d83cd3 100644 --- a/impl/format.go +++ b/format.go @@ -1,16 +1,15 @@ -package impl +package sqldb import ( "database/sql/driver" "encoding/hex" "fmt" "reflect" + "slices" "strings" "time" "unicode" "unicode/utf8" - - "github.com/domonda/go-sqldb" ) const timeFormat = "'2006-01-02 15:04:05.999999Z07:00:00'" @@ -79,61 +78,53 @@ func FormatValue(val any) (string, error) { return fmt.Sprint(val), nil } -func FormatQuery(query, argFmt string, args ...any) string { - for i := len(args) - 1; i >= 0; i-- { - placeholder := fmt.Sprintf(argFmt, i+1) - value, err := FormatValue(args[i]) - if err != nil { - value = "FORMATERROR:" + err.Error() - } - query = strings.ReplaceAll(query, placeholder, value) +func NormalizeAndFormatQuery(normalize NormalizeQueryFunc, formatter QueryFormatter, query string, args ...any) (string, error) { + if normalize == nil { + normalize = NoChangeNormalizeQuery } - - lines := strings.Split(query, "\n") - if len(lines) == 1 { - return strings.TrimSpace(query) + if formatter == nil { + formatter = NewQueryFormatter("$") } - - // Trim whitespace at end of line and remove empty lines - for i := 0; i < len(lines); i++ { - lines[i] = strings.TrimRightFunc(lines[i], unicode.IsSpace) - if lines[i] == "" { - lines = append(lines[:i], lines[i+1:]...) - i-- - } + query, err := normalize(query) + if err != nil { + return "", err } + return FormatQuery(formatter, query, args...), nil +} - // Remove identical whitespace at beginning of each line - firstLineRune, runeSize := utf8.DecodeRuneInString(lines[0]) - for unicode.IsSpace(firstLineRune) { - identical := true - for i := 1; i < len(lines); i++ { - lineRune, _ := utf8.DecodeRuneInString(lines[i]) - if lineRune != firstLineRune { - identical = false - break - } - } - if !identical { - break - } - for i := range lines { - lines[i] = lines[i][runeSize:] - } - firstLineRune, _ = utf8.DecodeRuneInString(lines[0]) +func MustNormalizeAndFormatQuery(normalize NormalizeQueryFunc, formatter QueryFormatter, query string, args ...any) string { + query, err := NormalizeAndFormatQuery(normalize, formatter, query, args...) + if err != nil { + panic("NormalizeAndFormatQuery error: " + err.Error()) } - - return strings.Join(lines, "\n") + return query } -func FormatQuery2(query string, argFmt sqldb.PlaceholderFormatter, args ...any) string { - for i := len(args) - 1; i >= 0; i-- { - placeholder := argFmt.Placeholder(i) - value, err := FormatValue(args[i]) - if err != nil { - value = "FORMATERROR:" + err.Error() +func FormatQuery(f QueryFormatter, query string, args ...any) string { + if len(args) > 0 { + if placeholder := f.FormatPlaceholder(0); f.FormatPlaceholder(1) == placeholder { + // Uniform placeholders, replace every instance with one arg + for _, arg := range args { + value, err := FormatValue(arg) + if err != nil { + value = "FORMATERROR:" + err.Error() + } + // Note that this will replace placeholders in comments and strings + query = strings.Replace(query, placeholder, value, 1) + } + } else { + // Numbered placeholders, replace in reverse order + // to avoid replacing shorter placeholders contained in longer ones + for i := len(args) - 1; i >= 0; i-- { + placeholder := f.FormatPlaceholder(i) + value, err := FormatValue(args[i]) + if err != nil { + value = "FORMATERROR:" + err.Error() + } + // Note that this will replace placeholders in comments and strings + query = strings.ReplaceAll(query, placeholder, value) + } } - query = strings.ReplaceAll(query, placeholder, value) } lines := strings.Split(query, "\n") @@ -145,7 +136,7 @@ func FormatQuery2(query string, argFmt sqldb.PlaceholderFormatter, args ...any) for i := 0; i < len(lines); i++ { lines[i] = strings.TrimRightFunc(lines[i], unicode.IsSpace) if lines[i] == "" { - lines = append(lines[:i], lines[i+1:]...) + lines = slices.Delete(lines, i, i+1) i-- } } diff --git a/impl/format_test.go b/format_test.go similarity index 86% rename from impl/format_test.go rename to format_test.go index 8d6ab7c..2f85156 100644 --- a/impl/format_test.go +++ b/format_test.go @@ -1,4 +1,4 @@ -package impl +package sqldb import ( "database/sql/driver" @@ -80,17 +80,17 @@ WHERE tests := []struct { name string + argFmt QueryFormatter query string - argFmt string args []any want string }{ - {name: "query1", query: query1, argFmt: "$%d", args: []any{createdAt, true, `Erik's Test`}, want: query1formatted}, - {name: "query2", query: query2, argFmt: "$%d", args: []any{"", 2, "3"}, want: query2formatted}, + {name: "query1", query: query1, argFmt: NewQueryFormatter("$"), args: []any{createdAt, true, `Erik's Test`}, want: query1formatted}, + {name: "query2", query: query2, argFmt: NewQueryFormatter("$"), args: []any{"", 2, "3"}, want: query2formatted}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := FormatQuery(tt.query, tt.argFmt, tt.args...); got != tt.want { + if got := FormatQuery(tt.argFmt, tt.query, tt.args...); got != tt.want { t.Errorf("FormatQuery():\n%q\nWant:\n%q", got, tt.want) } }) diff --git a/genericconn.go b/genericconn.go new file mode 100644 index 0000000..3b734de --- /dev/null +++ b/genericconn.go @@ -0,0 +1,98 @@ +package sqldb + +import ( + "context" + "database/sql" + "errors" + "time" +) + +// NewGenericConn returns a generic Connection implementation +// for an existing sql.DB connection. +func NewGenericConn(db *sql.DB, config *Config, queryFmt QueryFormatter, defaultIsolationLevel sql.IsolationLevel) Connection { + return &genericConn{ + QueryFormatter: queryFmt, + db: db, + config: config, + defaultIsolationLevel: defaultIsolationLevel, + } +} + +type genericConn struct { + QueryFormatter + db *sql.DB + config *Config + defaultIsolationLevel sql.IsolationLevel +} + +func (conn *genericConn) Ping(ctx context.Context, timeout time.Duration) error { + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + return conn.db.PingContext(ctx) +} + +func (conn *genericConn) Stats() sql.DBStats { + return conn.db.Stats() +} + +func (conn *genericConn) Config() *Config { + return conn.config +} + +func (conn *genericConn) Exec(ctx context.Context, query string, args ...any) error { + _, err := conn.db.ExecContext(ctx, query, args...) + return err +} + +func (conn *genericConn) Query(ctx context.Context, query string, args ...any) Rows { + rows, err := conn.db.QueryContext(ctx, query, args...) + if err != nil { + return NewErrRows(err) + } + return rows +} + +func (conn *genericConn) Prepare(ctx context.Context, query string) (Stmt, error) { + stmt, err := conn.db.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return NewStmt(stmt, query), nil +} + +func (conn *genericConn) DefaultIsolationLevel() sql.IsolationLevel { + return conn.defaultIsolationLevel +} + +func (conn *genericConn) Transaction() TransactionState { + return TransactionState{ + ID: 0, + Opts: nil, + } +} + +func (conn *genericConn) Begin(ctx context.Context, id uint64, opts *sql.TxOptions) (Connection, error) { + if id == 0 { + return nil, errors.New("transaction ID must not be zero") + } + tx, err := conn.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return newGenericTx(conn, tx, opts, id), nil +} + +func (conn *genericConn) Commit() error { + return ErrNotWithinTransaction +} + +func (conn *genericConn) Rollback() error { + return ErrNotWithinTransaction +} + +func (conn *genericConn) Close() error { + return conn.db.Close() +} diff --git a/generictx.go b/generictx.go new file mode 100644 index 0000000..26432ea --- /dev/null +++ b/generictx.go @@ -0,0 +1,89 @@ +package sqldb + +import ( + "context" + "database/sql" + "errors" + "time" +) + +type genericTx struct { + QueryFormatter + // The parent non-transaction connection is needed + // for Ping(), Stats(), Config(), and Begin(). + parent *genericConn + tx *sql.Tx + opts *sql.TxOptions + id uint64 +} + +func newGenericTx(parent *genericConn, tx *sql.Tx, opts *sql.TxOptions, id uint64) *genericTx { + return &genericTx{ + QueryFormatter: parent.QueryFormatter, + parent: parent, + tx: tx, + opts: opts, + id: id, + } +} + +func (conn *genericTx) Ping(ctx context.Context, timeout time.Duration) error { + return conn.parent.Ping(ctx, timeout) +} +func (conn *genericTx) Stats() sql.DBStats { return conn.parent.Stats() } +func (conn *genericTx) Config() *Config { return conn.parent.Config() } + +func (conn *genericTx) Exec(ctx context.Context, query string, args ...any) error { + _, err := conn.tx.ExecContext(ctx, query, args...) + return err +} + +func (conn *genericTx) Query(ctx context.Context, query string, args ...any) Rows { + rows, err := conn.tx.QueryContext(ctx, query, args...) + if err != nil { + return NewErrRows(err) + } + return rows +} + +func (conn *genericTx) Prepare(ctx context.Context, query string) (Stmt, error) { + stmt, err := conn.tx.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return NewStmt(stmt, query), nil +} + +func (conn *genericTx) DefaultIsolationLevel() sql.IsolationLevel { + return conn.parent.defaultIsolationLevel +} + +func (conn *genericTx) Transaction() TransactionState { + return TransactionState{ + ID: conn.id, + Opts: conn.opts, + } +} + +func (conn *genericTx) Begin(ctx context.Context, id uint64, opts *sql.TxOptions) (Connection, error) { + if id == 0 { + return nil, errors.New("transaction ID must not be zero") + } + tx, err := conn.parent.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return newGenericTx(conn.parent, tx, opts, id), nil +} + +func (conn *genericTx) Commit() error { + return conn.tx.Commit() +} + +func (conn *genericTx) Rollback() error { + return conn.tx.Rollback() +} + +func (conn *genericTx) Close() error { + return conn.Rollback() +} diff --git a/go.mod b/go.mod index 18232cd..85b7a32 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,25 @@ module github.com/domonda/go-sqldb go 1.23 require ( - github.com/domonda/go-errs v0.0.0-20240702051036-0e696c849b5f - github.com/domonda/go-types v0.0.0-20241104173616-e85c6dede426 - github.com/lib/pq v1.10.9 + github.com/DataDog/go-sqllexer v0.1.6 + github.com/domonda/go-types v0.0.0-20250725104804-d473d8b9dd16 github.com/stretchr/testify v1.10.0 ) require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/domonda/go-pretty v0.0.0-20240110134850-17385799142f // indirect + github.com/domonda/go-errs v0.0.0-20250603150208-71d6de0c48ea // indirect + github.com/domonda/go-pretty v0.0.0-20250602142956-1b467adc6387 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/kr/pretty v0.3.1 // indirect + github.com/mailru/easyjson v0.9.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.10.0 // indirect - github.com/ungerik/go-reflection v0.0.0-20240905081803-708928fe0862 // indirect - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + github.com/ungerik/go-reflection v0.0.0-20250602142243-03da83aecd0d // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8068665..6e4b948 100644 --- a/go.sum +++ b/go.sum @@ -1,32 +1,45 @@ +github.com/DataDog/go-sqllexer v0.1.6 h1:skEXpWEVCpeZFIiydoIa2f2rf+ymNpjiIMqpW4w3YAk= +github.com/DataDog/go-sqllexer v0.1.6/go.mod h1:GGpo1h9/BVSN+6NJKaEcJ9Jn44Hqc63Rakeb+24Mjgo= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/domonda/go-errs v0.0.0-20240702051036-0e696c849b5f h1:9CZgqCVP/7eixUjU+A+ozHo+oxRKJSkFgRtakoB5byc= -github.com/domonda/go-errs v0.0.0-20240702051036-0e696c849b5f/go.mod h1:qLWt1z3aIg12+Dbxu9bMydFOHEi92vWE7vAHcHLd8n8= -github.com/domonda/go-pretty v0.0.0-20240110134850-17385799142f h1:5eA74m451PqlqCXyJzWXp95Quj4PZ6Lm/ndKBuiNhe4= -github.com/domonda/go-pretty v0.0.0-20240110134850-17385799142f/go.mod h1:3QkM8UJdyJMeKZiIo7hYzSkQBpRS3k0gOHw4ysyEIB4= -github.com/domonda/go-types v0.0.0-20241104173616-e85c6dede426 h1:pWWcXqt8jvIGsqpo+o2RPe1Rx5lyFRj6lUKN2sTJ+rU= -github.com/domonda/go-types v0.0.0-20241104173616-e85c6dede426/go.mod h1:QfZG5NrNWDrwcqOp3ZlNh2XaLjZI1ncNpGPAa9MIUUE= +github.com/domonda/go-errs v0.0.0-20250603150208-71d6de0c48ea h1:jJkN+JvDKnzxM0yu+ob0sOLCyN95gevMeYF5VBKDg6w= +github.com/domonda/go-errs v0.0.0-20250603150208-71d6de0c48ea/go.mod h1:d1vM8jnNOby2gJSsbnCYPE/WadNbdxHTCE0sDUTMVSs= +github.com/domonda/go-pretty v0.0.0-20250602142956-1b467adc6387 h1:ZSMYEHfFwpMlVJ+yzPXOSOfikWBNdzcnC0YxxNQxkDk= +github.com/domonda/go-pretty v0.0.0-20250602142956-1b467adc6387/go.mod h1:3QkM8UJdyJMeKZiIo7hYzSkQBpRS3k0gOHw4ysyEIB4= +github.com/domonda/go-types v0.0.0-20250725104804-d473d8b9dd16 h1:Vf9f93nsItPIaLPD2/vjsMmSakEjdkkMPEJK6zJv1vg= +github.com/domonda/go-types v0.0.0-20250725104804-d473d8b9dd16/go.mod h1:5esmMaEB57phklyiGu9a9/ttw338cZBZDCcxoO8A7kY= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/ungerik/go-reflection v0.0.0-20240905081803-708928fe0862 h1:rxp/NtuHYkx0HRYL/Y7xh/07ZwI/Pbk3VPkVoq3IUgQ= -github.com/ungerik/go-reflection v0.0.0-20240905081803-708928fe0862/go.mod h1:Ic/uip1MCECqTPItawo5lRHmyaOT6vCM0UuKrczg6LY= +github.com/ungerik/go-reflection v0.0.0-20250602142243-03da83aecd0d h1:ctOx9QLFjuGij9QUMk3XoJWnbeC/O8kR8SRRNK9TK9U= +github.com/ungerik/go-reflection v0.0.0-20250602142243-03da83aecd0d/go.mod h1:2HaymCMIvGNYIy+2JDI9RdPytWuP/Q8fJSGcS+2mb20= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go.work b/go.work index 28524b6..dfc124f 100644 --- a/go.work +++ b/go.work @@ -4,4 +4,5 @@ use ( . ./mssqlconn ./mysqlconn + ./pqconn ) diff --git a/go.work.sum b/go.work.sum index 295db62..3d7e9af 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,6 +1,27 @@ +cel.dev/expr v0.23.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= +cloud.google.com/go v0.121.2/go.mod h1:nRFlrHq39MNVWu+zESP2PosMWA0ryJw8KUBZ2iZpxbw= +cloud.google.com/go/ai v0.12.1/go.mod h1:5vIPNe1ZQsVZqCliXIPL4QnhObQQY4d9hAGHdVc4iw4= +cloud.google.com/go/auth v0.16.2/go.mod h1:sRBas2Y1fB1vZTdurouM0AzuYQBMZinrUYL8EufhtEA= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo= +cloud.google.com/go/iam v1.5.2/go.mod h1:SE1vg0N81zQqLzQEwxL2WI6yhetBdbNQuTvIKCSkUHE= +cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY= +cloud.google.com/go/monitoring v1.24.0/go.mod h1:Bd1PRK5bmQBQNnuGwHBfUamAV1ys9049oEPHnn4pcsc= +cloud.google.com/go/storage v1.53.0/go.mod h1:7/eO2a/srr9ImZW9k5uufcNahT2+fPb8w5it1i5boaA= +cloud.google.com/go/translate v1.10.3/go.mod h1:GW0vC1qvPtd3pgtypCv4k4U8B7EdgK9/QEF2aJEUovs= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0/go.mod h1:yAZHSGnqScoU556rBOVkwLze6WP5N+U11RHuWaGVxwY= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0/go.mod h1:BnBReJLvVYx2CS/UHOgVz2BXKXD9wsQPxZug20nZhd0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0/go.mod h1:otE2jQekW/PqXk1Awf5lmfokJx4uwuqcj1ab5SpGeW0= github.com/Strum355/go-difflib v1.1.0/go.mod h1:r1cVg1JkGsTWkaR7At56v7hfuMgiUL8meTLwxFzOmvE= +github.com/ccojocar/zxcvbn-go v1.0.4/go.mod h1:3GxGX+rHmueTUMvm5ium7irpyjmm7ikxYFOSJB21Das= github.com/cention-sany/utf7 v0.0.0-20170124080048-26cad61bd60a h1:MISbI8sU/PSK/ztvmWKFcI7UGb5/HQT7B+i3a2myKgI= github.com/cention-sany/utf7 v0.0.0-20170124080048-26cad61bd60a/go.mod h1:2GxOXOlEPAMFPfp014mK1SWq8G8BN8o7/dfYqJrVGn8= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/domonda/go-errs v0.0.0-20230810132956-1b6272f9fc8f h1:OQaXlKXZc52Vsz7iH23NhddeMr0niW0tetB8Fq3k4yQ= github.com/domonda/go-errs v0.0.0-20230810132956-1b6272f9fc8f/go.mod h1:DYkFE3rxUGhTCMmR5MpQ2NTtoCPiORdjBATGkIEeGKM= @@ -8,14 +29,44 @@ github.com/domonda/go-errs v0.0.0-20230920094343-6b122da4d22f h1:ECYzMHlxXTmVwOY github.com/domonda/go-errs v0.0.0-20230920094343-6b122da4d22f/go.mod h1:DYkFE3rxUGhTCMmR5MpQ2NTtoCPiORdjBATGkIEeGKM= github.com/domonda/go-errs v0.0.0-20240301142737-8fde935c9bd4 h1:qidwzgjM8qrKy326iXVNHNN/qB89o1lfiAi7pMuNbQU= github.com/domonda/go-errs v0.0.0-20240301142737-8fde935c9bd4/go.mod h1:NnvsIo+bzAany1nQLMViGDgJ8kx3k5N/D1+UJz3hEXc= +github.com/domonda/go-errs v0.0.0-20250509130707-0373cd8156d7/go.mod h1:d1vM8jnNOby2gJSsbnCYPE/WadNbdxHTCE0sDUTMVSs= +github.com/domonda/go-types v0.0.0-20240822142828-3b45a403e1e2/go.mod h1:Voo8dh8EeexiIXpvvmFM3WPyXtiukUjtwLpgOr9rKNM= +github.com/domonda/go-types v0.0.0-20250527155803-62531baa899e h1:L5L+wACuTwmBBE9MHzb2obR81wYaQt4/ZMRRvoVhiP0= +github.com/domonda/go-types v0.0.0-20250527155803-62531baa899e/go.mod h1:JUvMIVKkftP8sJ41yWJ6p3ZNSc5nuQxVpWyPEpDML0s= +github.com/envoyproxy/go-control-plane v0.13.4/go.mod h1:kDfuBlDVsSj2MjrLEtRWtHlsWIFcGyB2RMO44Dc5GZA= +github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= +github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f h1:3BSP1Tbs2djlpprl7wCLuiqMaUh5SJkkzI2gDs+FgLs= github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f/go.mod h1:Pcatq5tYkCW2Q6yrR2VRHlbHpZ/R4/7qyL1TCF7vl14= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/generative-ai-go v0.20.1/go.mod h1:TjOnZJmZKzarWbjUJgy+r3Ee7HGBRVLhOIgupnwR4Bg= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-pkcs11 v0.3.0/go.mod h1:6eQoGcuNJpa7jnd5pMGdkSaQpNDYvPlXWMcjXXThLlY= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= +github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a/go.mod h1:5hDyRhoBCxViHszMt12TnOpEI4VVi+U8Gm9iphldiMA= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.14.2/go.mod h1:ON64QhlJkhVtSqp4v1uaK92VyZ2gmvDQsweuyLV+8+w= +github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 h1:iCHtR9CQyktQ5+f3dMVZfwD2KWJUgm7M0gdL9NGr8KA= github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056/go.mod h1:CVKlgaMiht+LXvHG173ujK6JUhZXKb2u/BQtjPDIvyk= @@ -30,25 +81,32 @@ github.com/jhillyerd/enmime v1.0.0/go.mod h1:EktNOa/V6ka9yCrfoB2uxgefp1lno6OVdsz github.com/jhillyerd/enmime v1.2.0 h1:dIu1IPEymQgoT2dzuB//ttA/xcV40NMPpQtmd4wslHk= github.com/jhillyerd/enmime v1.2.0/go.mod h1:FRFuUPCLh8PByQv+8xRcLO9QHqaqTqreYhopv5eyk4I= github.com/jhillyerd/enmime v1.3.0/go.mod h1:6c6jg5HdRRV2FtvVL69LjiX1M8oE0xDX9VEhV3oy4gs= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mozilla/tls-observatory v0.0.0-20210609171429-7bc42856d2e5/go.mod h1:FUqVoUPHSEdDR0MnFM3Dh8AU0pZHLXUD127SAJGER/s= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= +github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/securego/gosec/v2 v2.22.5/go.mod h1:AWfgrFsVewk5LKobsPWlygCHt8K91boVPyL6GUZG5NY= +github.com/spiffe/go-spiffe/v2 v2.5.0/go.mod h1:P+NxobPc6wXhVtINNtFjNWGBTreew1GBUCwT2wPmb7g= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/teamwork/test v0.0.0-20190410143529-8897d82f8d46/go.mod h1:TIbx7tx6WHBjQeLRM4eWQZBL7kmBZ7/KI4x4v7Y5YmA= github.com/teamwork/test v0.0.0-20200108114543-02621bae84ad h1:25sEr0awm0ZPancg5W5H5VvN7PWsJloUBpii10a9isw= github.com/teamwork/test v0.0.0-20200108114543-02621bae84ad/go.mod h1:TIbx7tx6WHBjQeLRM4eWQZBL7kmBZ7/KI4x4v7Y5YmA= @@ -62,38 +120,116 @@ github.com/ungerik/go-fs v0.0.0-20240118121925-91844f9bdba8 h1:LkAUtMadwzxaMYrdO github.com/ungerik/go-fs v0.0.0-20240118121925-91844f9bdba8/go.mod h1:uJoyhNruti7dh2/DTNIF+N8s/sCd9uIhCBT8gzk6190= github.com/ungerik/go-fs v0.0.0-20240702143946-3ecb6733945d/go.mod h1:+8Ezjyw6fCooNzoVofQIhWLXLe6E23nZ/9cfQ79Wzo0= github.com/ungerik/go-fs v0.0.0-20240919065241-437d7c2c9f63/go.mod h1:nMIa35zyLzk4K3tTLL+AAsOZ9Q+0lgX/lxYubEwCZSY= +github.com/ungerik/go-fs v0.0.0-20241213130555-c93eabeaac28/go.mod h1:pKXmcBwT1D8y0dEHrOFDYGtYp4NNn1g9nDd/9vLQ6Dg= +github.com/ungerik/go-fs v0.0.0-20250123134246-3ac71b34b8e3/go.mod h1:pKXmcBwT1D8y0dEHrOFDYGtYp4NNn1g9nDd/9vLQ6Dg= +github.com/ungerik/go-fs v0.0.0-20250310161700-3b05d22755dd/go.mod h1:pKXmcBwT1D8y0dEHrOFDYGtYp4NNn1g9nDd/9vLQ6Dg= +github.com/ungerik/go-fs v0.0.0-20250410112719-c5187364824d/go.mod h1:l1gaYFOhIK4gUWy7xuaskrq/zP+lK6xGyImqAyt4mSw= +github.com/ungerik/go-fs v0.0.0-20250527162931-1691110c1708/go.mod h1:4QQilZ+gupNC1QdygHgHM8tKAj6/l1ohr8b1jYiPlaA= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/detectors/gcp v1.35.0/go.mod h1:qGWP8/+ILwMRIUf9uIVLloR1uo5ZYAslM4O6OqUi1DA= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= +go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= +go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs= +go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= +go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= +go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= +go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= +golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= +golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.237.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= +google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2/go.mod h1:49MsLSx0oWMOZqcpB3uL8ZOkAh1+TndpJ8ONoCBWiZk= +google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= +google.golang.org/genproto/googleapis/bytestream v0.0.0-20250603155806-513f23925822/go.mod h1:h6yxum/C2qRb4txaZRLDHK8RyS0H/o2oEDeKY4onY/Y= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8= mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE= +mvdan.cc/xurls/v2 v2.6.0/go.mod h1:bCvEZ1XvdA6wDnxY7jPPjEmigDtvtvPXAD/Exa9IMSk= diff --git a/impl/connection.go b/impl/connection.go deleted file mode 100644 index f461d96..0000000 --- a/impl/connection.go +++ /dev/null @@ -1,174 +0,0 @@ -package impl - -import ( - "context" - "database/sql" - "errors" - "fmt" - "time" - - "github.com/domonda/go-sqldb" -) - -// Connection returns a generic sqldb.Connection implementation -// for an existing sql.DB connection. -// argFmt is the format string for argument placeholders like "?" or "$%d" -// that will be replaced error messages to format a complete query. -func Connection(ctx context.Context, db *sql.DB, config *sqldb.Config, validateColumnName func(string) error, argFmt string) sqldb.Connection { - return &connection{ - ctx: ctx, - db: db, - config: config, - structFieldNamer: sqldb.DefaultStructFieldMapping, - argFmt: argFmt, - validateColumnName: validateColumnName, - } -} - -type connection struct { - ctx context.Context - db *sql.DB - config *sqldb.Config - structFieldNamer sqldb.StructFieldMapper - argFmt string - validateColumnName func(string) error -} - -func (conn *connection) clone() *connection { - c := *conn - return &c -} - -func (conn *connection) Context() context.Context { return conn.ctx } - -func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.ctx { - return conn - } - c := conn.clone() - c.ctx = ctx - return c -} - -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - -func (conn *connection) Ping(timeout time.Duration) error { - ctx := conn.ctx - if timeout > 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - return conn.db.PingContext(ctx) -} - -func (conn *connection) Stats() sql.DBStats { - return conn.db.Stats() -} - -func (conn *connection) Config() *sqldb.Config { - return conn.config -} - -func (conn *connection) Placeholder(paramIndex int) string { - return fmt.Sprintf(conn.argFmt, paramIndex+1) -} - -func (conn *connection) ValidateColumnName(name string) error { - return conn.validateColumnName(name) -} - -func (conn *connection) Exec(query string, args ...any) error { - _, err := conn.db.ExecContext(conn.ctx, query, args...) - return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) -} - -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowScannerWithError(err) - } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.argFmt, args) -} - -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowsScannerWithError(err) - } - return NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, conn.argFmt, args) -} - -func (conn *connection) IsTransaction() bool { - return false -} - -func (conn *connection) TransactionNo() uint64 { - return 0 -} - -func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (conn *connection) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { - tx, err := conn.db.BeginTx(conn.ctx, opts) - if err != nil { - return nil, err - } - return newTransaction(conn, tx, opts, no), nil -} - -func (conn *connection) Commit() error { - return sqldb.ErrNotWithinTransaction -} - -func (conn *connection) Rollback() error { - return sqldb.ErrNotWithinTransaction -} - -func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *connection) UnlistenChannel(channel string) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *connection) IsListeningOnChannel(channel string) bool { - return false -} - -func (conn *connection) Close() error { - return conn.db.Close() -} diff --git a/impl/errors.go b/impl/errors.go deleted file mode 100644 index 6f69b5f..0000000 --- a/impl/errors.go +++ /dev/null @@ -1,33 +0,0 @@ -package impl - -import ( - "errors" - "fmt" -) - -// WrapNonNilErrorWithQuery wraps non nil errors with a formatted query -// if the error was not already wrapped with a query. -// If the passed error is nil, then nil will be returned. -func WrapNonNilErrorWithQuery(err error, query, argFmt string, args []any) error { - if err == nil { - return nil - } - var wrapped errWithQuery - if errors.As(err, &wrapped) { - return err // already wrapped - } - return errWithQuery{err, query, argFmt, args} -} - -type errWithQuery struct { - err error - query string - argFmt string - args []any -} - -func (e errWithQuery) Unwrap() error { return e.err } - -func (e errWithQuery) Error() string { - return fmt.Sprintf("%s from query: %s", e.err, FormatQuery(e.query, e.argFmt, e.args...)) -} diff --git a/impl/insert.go b/impl/insert.go deleted file mode 100644 index 8254b4f..0000000 --- a/impl/insert.go +++ /dev/null @@ -1,144 +0,0 @@ -package impl - -import ( - "fmt" - "reflect" - "strings" - - sqldb "github.com/domonda/go-sqldb" -) - -// Insert a new row into table using the values. -func Insert(conn sqldb.Connection, table, argFmt string, values sqldb.Values) error { - if len(values) == 0 { - return fmt.Errorf("Insert into table %s: no values", table) - } - - names, vals := values.Sorted() - b := strings.Builder{} - writeInsertQuery(&b, table, argFmt, names) - query := b.String() - - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) -} - -// InsertUnique inserts a new row into table using the passed values -// or does nothing if the onConflict statement applies. -// Returns if a row was inserted. -func InsertUnique(conn sqldb.Connection, table, argFmt string, values sqldb.Values, onConflict string) (inserted bool, err error) { - if len(values) == 0 { - return false, fmt.Errorf("InsertUnique into table %s: no values", table) - } - - if strings.HasPrefix(onConflict, "(") && strings.HasSuffix(onConflict, ")") { - onConflict = onConflict[1 : len(onConflict)-1] - } - - names, vals := values.Sorted() - var query strings.Builder - writeInsertQuery(&query, table, argFmt, names) - fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - - err = conn.QueryRow(query.String(), vals...).Scan(&inserted) - - err = sqldb.ReplaceErrNoRows(err, nil) - err = WrapNonNilErrorWithQuery(err, query.String(), argFmt, vals) - return inserted, err -} - -// InsertReturning inserts a new row into table using values -// and returns values from the inserted row listed in returning. -func InsertReturning(conn sqldb.Connection, table, argFmt string, values sqldb.Values, returning string) sqldb.RowScanner { - if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) - } - - names, vals := values.Sorted() - var query strings.Builder - writeInsertQuery(&query, table, argFmt, names) - query.WriteString(" RETURNING ") - query.WriteString(returning) - return conn.QueryRow(query.String(), vals...) -} - -func writeInsertQuery(w *strings.Builder, table, argFmt string, names []string) { - fmt.Fprintf(w, `INSERT INTO %s(`, table) - for i, name := range names { - if i > 0 { - w.WriteByte(',') - } - w.WriteByte('"') - w.WriteString(name) - w.WriteByte('"') - } - w.WriteString(`) VALUES(`) - for i := range names { - if i > 0 { - w.WriteByte(',') - } - fmt.Fprintf(w, argFmt, i+1) - } - w.WriteByte(')') -} - -// InsertStruct inserts a new row into table using the connection's -// StructFieldMapper to map struct fields to column names. -// Optional ColumnFilter can be passed to ignore mapped columns. -func InsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns) - if err != nil { - return err - } - - var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) - query := b.String() - - err = conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) -} - -// InsertUniqueStruct inserts a new row into table using the connection's -// StructFieldMapper to map struct fields to column names. -// Optional ColumnFilter can be passed to ignore mapped columns. -// Does nothing if the onConflict statement applies -// and returns if a row was inserted. -func InsertUniqueStruct(conn sqldb.Connection, table string, rowStruct any, onConflict string, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) (inserted bool, err error) { - columns, vals, err := insertStructValues(table, rowStruct, namer, ignoreColumns) - if err != nil { - return false, err - } - - if strings.HasPrefix(onConflict, "(") && strings.HasSuffix(onConflict, ")") { - onConflict = onConflict[1 : len(onConflict)-1] - } - - var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) - fmt.Fprintf(&b, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - query := b.String() - - err = conn.QueryRow(query, vals...).Scan(&inserted) - err = sqldb.ReplaceErrNoRows(err, nil) - - return inserted, WrapNonNilErrorWithQuery(err, query, argFmt, vals) -} - -func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) - } - - columns, _, vals = ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - return columns, vals, nil -} diff --git a/impl/now.go b/impl/now.go deleted file mode 100644 index 4a1bd2f..0000000 --- a/impl/now.go +++ /dev/null @@ -1,15 +0,0 @@ -package impl - -import ( - "time" - - "github.com/domonda/go-sqldb" -) - -func Now(conn sqldb.Connection) (now time.Time, err error) { - err = conn.QueryRow(`select now()`).Scan(&now) - if err != nil { - return time.Time{}, err - } - return now, nil -} diff --git a/impl/reflectstruct.go b/impl/reflectstruct.go deleted file mode 100644 index f224d71..0000000 --- a/impl/reflectstruct.go +++ /dev/null @@ -1,121 +0,0 @@ -package impl - -import ( - "errors" - "fmt" - "reflect" - "slices" - "strings" - - "github.com/domonda/go-sqldb" -) - -func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { - for i := 0; i < structVal.NumField(); i++ { - fieldType := structVal.Type().Field(i) - _, column, flags, use := namer.MapStructField(fieldType) - if !use { - continue - } - fieldValue := structVal.Field(i) - - if column == "" { - // Embedded struct field - columnsEmbed, pkColsEmbed, valuesEmbed := ReflectStructValues(fieldValue, namer, ignoreColumns) - for _, pkCol := range pkColsEmbed { - pkCols = append(pkCols, pkCol+len(columns)) - } - columns = append(columns, columnsEmbed...) - values = append(values, valuesEmbed...) - continue - } - - if ignoreColumn(ignoreColumns, column, flags, fieldType, fieldValue) { - continue - } - - if flags.PrimaryKey() { - pkCols = append(pkCols, len(columns)) - } - columns = append(columns, column) - values = append(values, fieldValue.Interface()) - } - return columns, pkCols, values -} - -func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { - if len(columns) == 0 { - return nil, errors.New("no columns") - } - pointers = make([]any, len(columns)) - err = reflectStructColumnPointers(structVal, namer, columns, pointers) - if err != nil { - return nil, err - } - for _, ptr := range pointers { - if ptr != nil { - continue - } - nilCols := new(strings.Builder) - for i, ptr := range pointers { - if ptr != nil { - continue - } - if nilCols.Len() > 0 { - nilCols.WriteString(", ") - } - fmt.Fprintf(nilCols, "column=%s, index=%d", columns[i], i) - } - return nil, fmt.Errorf("columns have no mapped struct fields in %s: %s", structVal.Type(), nilCols) - } - return pointers, nil -} - -func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string, pointers []any) error { - structType := structVal.Type() - for i := 0; i < structType.NumField(); i++ { - field := structType.Field(i) - _, column, _, use := namer.MapStructField(field) - if !use { - continue - } - fieldValue := structVal.Field(i) - - if column == "" { - // Embedded struct field - err := reflectStructColumnPointers(fieldValue, namer, columns, pointers) - if err != nil { - return err - } - continue - } - - colIndex := slices.Index(columns, column) - if colIndex == -1 { - continue - } - - if pointers[colIndex] != nil { - return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, field.Name, structType) - } - - pointer := fieldValue.Addr().Interface() - // If field is a slice or array that does not implement sql.Scanner - // and it's not a string scannable []byte type underneath - // then wrap it with WrapForArray to make it scannable - if NeedsArrayWrappingForScanning(fieldValue) { - pointer = WrapArray(pointer) - } - pointers[colIndex] = pointer - } - return nil -} - -func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool { - for _, filter := range filters { - if filter.IgnoreColumn(name, flags, fieldType, fieldValue) { - return true - } - } - return false -} diff --git a/impl/row.go b/impl/row.go deleted file mode 100644 index a3b34c7..0000000 --- a/impl/row.go +++ /dev/null @@ -1,13 +0,0 @@ -package impl - -// Row is an interface with the methods of sql.Rows -// that are needed for ScanStruct. -// Allows mocking for tests without an SQL driver. -type Row interface { - // Columns returns the column names. - Columns() ([]string, error) - // Scan copies the columns in the current row into the values pointed - // at by dest. The number of values in dest must be the same as the - // number of columns in Rows. - Scan(dest ...any) error -} diff --git a/impl/rowscanner.go b/impl/rowscanner.go deleted file mode 100644 index 6bc5826..0000000 --- a/impl/rowscanner.go +++ /dev/null @@ -1,129 +0,0 @@ -package impl - -import ( - "database/sql" - "errors" - - sqldb "github.com/domonda/go-sqldb" -) - -var ( - _ sqldb.RowScanner = &RowScanner{} - _ sqldb.RowScanner = CurrentRowScanner{} - _ sqldb.RowScanner = SingleRowScanner{} -) - -// RowScanner implements sqldb.RowScanner for a sql.Row -type RowScanner struct { - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowScanner(rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowScanner { - return &RowScanner{rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowScanner) Scan(dest ...any) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return s.rows.Scan(dest...) -} - -func (s *RowScanner) ScanStruct(dest any) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return ScanStruct(s.rows, dest, s.structFieldNamer) -} - -func (s *RowScanner) ScanValues() ([]any, error) { - return ScanValues(s.rows) -} - -func (s *RowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.rows) -} - -func (s *RowScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -// CurrentRowScanner calls Rows.Scan without Rows.Next and Rows.Close -type CurrentRowScanner struct { - Rows Rows - StructFieldMapper sqldb.StructFieldMapper -} - -func (s CurrentRowScanner) Scan(dest ...any) error { - return s.Rows.Scan(dest...) -} - -func (s CurrentRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Rows, dest, s.StructFieldMapper) -} - -func (s CurrentRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Rows) -} - -func (s CurrentRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Rows) -} - -func (s CurrentRowScanner) Columns() ([]string, error) { - return s.Rows.Columns() -} - -// SingleRowScanner always uses the same Row -type SingleRowScanner struct { - Row Row - StructFieldMapper sqldb.StructFieldMapper -} - -func (s SingleRowScanner) Scan(dest ...any) error { - return s.Row.Scan(dest...) -} - -func (s SingleRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Row, dest, s.StructFieldMapper) -} - -func (s SingleRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Row) -} - -func (s SingleRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Row) -} - -func (s SingleRowScanner) Columns() ([]string, error) { - return s.Row.Columns() -} diff --git a/impl/rowsscanner.go b/impl/rowsscanner.go deleted file mode 100644 index 339a0e7..0000000 --- a/impl/rowsscanner.go +++ /dev/null @@ -1,154 +0,0 @@ -package impl - -import ( - "context" - "errors" - "fmt" - "reflect" - - sqldb "github.com/domonda/go-sqldb" -) - -var _ sqldb.RowsScanner = &RowsScanner{} - -// RowsScanner implements sqldb.RowsScanner with Rows -type RowsScanner struct { - ctx context.Context // ctx is checked for every row and passed through to callbacks - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowsScanner(ctx context.Context, rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowsScanner { - return &RowsScanner{ctx, rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowsScanner) ScanSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, nil) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) ScanStructSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldNamer) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -func (s *RowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { - cols, err := s.rows.Columns() - if err != nil { - return nil, err - } - if headerRow { - rows = [][]string{cols} - } - stringScannablePtrs := make([]any, len(cols)) - err = s.ForEachRow(func(rowScanner sqldb.RowScanner) error { - row := make([]string, len(cols)) - for i := range stringScannablePtrs { - stringScannablePtrs[i] = (*sqldb.StringScannable)(&row[i]) - } - err := rowScanner.Scan(stringScannablePtrs...) - if err != nil { - return err - } - rows = append(rows, row) - return nil - }) - return rows, err -} - -func (s *RowsScanner) ForEachRow(callback func(sqldb.RowScanner) error) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - for s.rows.Next() { - if s.ctx.Err() != nil { - return s.ctx.Err() - } - - err := callback(CurrentRowScanner{s.rows, s.structFieldNamer}) - if err != nil { - return err - } - } - return s.rows.Err() -} - -func (s *RowsScanner) ForEachRowCall(callback any) error { - forEachRowFunc, err := ForEachRowCallFunc(s.ctx, callback) - if err != nil { - return err - } - return s.ForEachRow(forEachRowFunc) -} - -// ScanRowsAsSlice scans all srcRows as slice into dest. -// The rows must either have only one column compatible with the element type of the slice, -// or if multiple columns are returned then the slice element type must me a struct or struction pointer -// so that every column maps on exactly one struct field using structFieldNamer. -// In case of single column rows, nil must be passed for structFieldNamer. -// ScanRowsAsSlice calls srcRows.Close(). -func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNamer sqldb.StructFieldMapper) error { - defer srcRows.Close() - - destVal := reflect.ValueOf(dest) - if destVal.Kind() != reflect.Ptr { - return fmt.Errorf("scan dest is not a pointer but %s", destVal.Type()) - } - if destVal.IsNil() { - return errors.New("scan dest is nil") - } - slice := destVal.Elem() - if slice.Kind() != reflect.Slice { - return fmt.Errorf("scan dest is not pointer to slice but %s", destVal.Type()) - } - sliceElemType := slice.Type().Elem() - - newSlice := reflect.MakeSlice(slice.Type(), 0, 32) - - for srcRows.Next() { - if ctx.Err() != nil { - return ctx.Err() - } - - newSlice = reflect.Append(newSlice, reflect.Zero(sliceElemType)) - target := newSlice.Index(newSlice.Len() - 1).Addr() - if structFieldNamer != nil { - err := ScanStruct(srcRows, target.Interface(), structFieldNamer) - if err != nil { - return err - } - } else { - err := srcRows.Scan(target.Interface()) - if err != nil { - return err - } - } - } - if srcRows.Err() != nil { - return srcRows.Err() - } - - // Assign newSlice if there were no errors - if newSlice.Len() == 0 { - slice.SetLen(0) - } else { - slice.Set(newSlice) - } - - return nil -} diff --git a/impl/scanresult.go b/impl/scanresult.go deleted file mode 100644 index d8a4bc3..0000000 --- a/impl/scanresult.go +++ /dev/null @@ -1,54 +0,0 @@ -package impl - -import "github.com/domonda/go-sqldb" - -// ScanValues returns the values of a row exactly how they are -// passed from the database driver to an sql.Scanner. -// Byte slices will be copied. -func ScanValues(src Row) ([]any, error) { - cols, err := src.Columns() - if err != nil { - return nil, err - } - var ( - anys = make([]sqldb.AnyValue, len(cols)) - result = make([]any, len(cols)) - ) - // result elements hold pointer to sqldb.AnyValue for scanning - for i := range result { - result[i] = &anys[i] - } - err = src.Scan(result...) - if err != nil { - return nil, err - } - // don't return pointers to sqldb.AnyValue - // but what internal value has been scanned - for i := range result { - result[i] = anys[i].Val - } - return result, nil -} - -// ScanStrings scans the values of a row as strings. -// Byte slices will be interpreted as strings, -// nil (SQL NULL) will be converted to an empty string, -// all other types are converted with fmt.Sprint. -func ScanStrings(src Row) ([]string, error) { - cols, err := src.Columns() - if err != nil { - return nil, err - } - var ( - result = make([]string, len(cols)) - resultPtrs = make([]any, len(cols)) - ) - for i := range resultPtrs { - resultPtrs[i] = (*sqldb.StringScannable)(&result[i]) - } - err = src.Scan(resultPtrs...) - if err != nil { - return nil, err - } - return result, nil -} diff --git a/impl/transaction.go b/impl/transaction.go deleted file mode 100644 index d4fed91..0000000 --- a/impl/transaction.go +++ /dev/null @@ -1,155 +0,0 @@ -package impl - -import ( - "context" - "database/sql" - "errors" - "fmt" - "time" - - "github.com/domonda/go-sqldb" -) - -type transaction struct { - // The parent non-transaction connection is needed - // for its ctx, Ping(), Stats(), and Config() - parent *connection - tx *sql.Tx - opts *sql.TxOptions - no uint64 - structFieldNamer sqldb.StructFieldMapper -} - -func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions, no uint64) *transaction { - return &transaction{ - parent: parent, - tx: tx, - opts: opts, - no: no, - structFieldNamer: parent.structFieldNamer, - } -} - -func (conn *transaction) clone() *transaction { - c := *conn - return &c -} - -func (conn *transaction) Context() context.Context { return conn.parent.ctx } - -func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { - if ctx == conn.parent.ctx { - return conn - } - parent := conn.parent.clone() - parent.ctx = ctx - return newTransaction(parent, conn.tx, conn.opts, conn.no) -} - -func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { - c := conn.clone() - c.structFieldNamer = namer - return c -} - -func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper { - return conn.structFieldNamer -} - -func (conn *transaction) Ping(timeout time.Duration) error { return conn.parent.Ping(timeout) } -func (conn *transaction) Stats() sql.DBStats { return conn.parent.Stats() } -func (conn *transaction) Config() *sqldb.Config { return conn.parent.Config() } -func (conn *transaction) Placeholder(paramIndex int) string { - return conn.parent.Placeholder(paramIndex) -} - -func (conn *transaction) ValidateColumnName(name string) error { - return conn.parent.validateColumnName(name) -} - -func (conn *transaction) Exec(query string, args ...any) error { - _, err := conn.tx.Exec(query, args...) - return WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) -} - -func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.parent.argFmt, args) -} - -func (conn *transaction) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) - return sqldb.RowScannerWithError(err) - } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.parent.argFmt, args) -} - -func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { - rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) - return sqldb.RowsScannerWithError(err) - } - return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, conn.parent.argFmt, args) -} - -func (conn *transaction) IsTransaction() bool { - return true -} - -func (conn *transaction) TransactionNo() uint64 { - return conn.no -} - -func (conn *transaction) TransactionOptions() (*sql.TxOptions, bool) { - return conn.opts, true -} - -func (conn *transaction) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { - tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) - if err != nil { - return nil, err - } - return newTransaction(conn.parent, tx, opts, no), nil -} - -func (conn *transaction) Commit() error { - return conn.tx.Commit() -} - -func (conn *transaction) Rollback() error { - return conn.tx.Rollback() -} - -func (conn *transaction) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *transaction) UnlistenChannel(channel string) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *transaction) IsListeningOnChannel(channel string) bool { - return false -} - -func (conn *transaction) Close() error { - return conn.Rollback() -} diff --git a/impl/update.go b/impl/update.go deleted file mode 100644 index e5f6ca1..0000000 --- a/impl/update.go +++ /dev/null @@ -1,113 +0,0 @@ -package impl - -import ( - "fmt" - "reflect" - "slices" - "strings" - - sqldb "github.com/domonda/go-sqldb" -) - -// Update table rows(s) with values using the where statement with passed in args starting at $1. -func Update(conn sqldb.Connection, table string, values sqldb.Values, where, argFmt string, args []any) error { - if len(values) == 0 { - return fmt.Errorf("Update table %s: no values passed", table) - } - - query, vals := buildUpdateQuery(table, values, where, args) - err := conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) -} - -// UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 -// and returning a single row with the columns specified in returning argument. -func UpdateReturningRow(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) - } - - query, vals := buildUpdateQuery(table, values, where, args) - query += " RETURNING " + returning - return conn.QueryRow(query, vals...) -} - -// UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 -// and returning multiple rows with the columns specified in returning argument. -func UpdateReturningRows(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - if len(values) == 0 { - return sqldb.RowsScannerWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) - } - - query, vals := buildUpdateQuery(table, values, where, args) - query += " RETURNING " + returning - return conn.QueryRows(query, vals...) -} - -func buildUpdateQuery(table string, values sqldb.Values, where string, args []any) (string, []any) { - names, vals := values.Sorted() - - var query strings.Builder - fmt.Fprintf(&query, `UPDATE %s SET `, table) - for i := range names { - if i > 0 { - query.WriteByte(',') - } - fmt.Fprintf(&query, `"%s"=$%d`, names[i], 1+len(args)+i) - } - fmt.Fprintf(&query, ` WHERE %s`, where) - - return query.String(), append(args, vals...) -} - -// UpdateStruct updates a row of table using the exported fields -// of rowStruct which have a `db` tag that is not "-". -// Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. -// If restrictToColumns are provided, then only struct fields with a `db` tag -// matching any of the passed column names will be used. -func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return fmt.Errorf("UpdateStruct of table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return fmt.Errorf("UpdateStruct of table %s: expected struct but got %T", table, rowStruct) - } - - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - if len(pkCols) == 0 { - return fmt.Errorf("UpdateStruct of table %s: %s has no mapped primary key field", table, v.Type()) - } - - var b strings.Builder - fmt.Fprintf(&b, `UPDATE %s SET `, table) - first := true - for i := range columns { - if slices.Contains(pkCols, i) { - continue - } - if first { - first = false - } else { - b.WriteByte(',') - } - fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) - } - - b.WriteString(` WHERE `) - for i, pkCol := range pkCols { - if i > 0 { - b.WriteString(` AND `) - } - fmt.Fprintf(&b, `"%s"=$%d`, columns[pkCol], i+1) - } - - query := b.String() - - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) -} diff --git a/impl/upsert.go b/impl/upsert.go deleted file mode 100644 index ebbbfbd..0000000 --- a/impl/upsert.go +++ /dev/null @@ -1,63 +0,0 @@ -package impl - -import ( - "fmt" - "reflect" - "slices" - "strings" - - "github.com/domonda/go-sqldb" -) - -// UpsertStruct upserts a row to table using the exported fields -// of rowStruct which have a `db` tag that is not "-". -// Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. -// If restrictToColumns are provided, then only struct fields with a `db` tag -// matching any of the passed column names will be used. -// If inserting conflicts on pkColumn, then an update of the existing row is performed. -func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return fmt.Errorf("UpsertStruct to table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return fmt.Errorf("UpsertStruct to table %s: expected struct but got %T", table, rowStruct) - } - - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - if len(pkCols) == 0 { - return fmt.Errorf("UpsertStruct of table %s: %s has no mapped primary key field", table, v.Type()) - } - - var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) - b.WriteString(` ON CONFLICT(`) - for i, pkCol := range pkCols { - if i > 0 { - b.WriteByte(',') - } - fmt.Fprintf(&b, `"%s"`, columns[pkCol]) - } - - b.WriteString(`) DO UPDATE SET `) - first := true - for i := range columns { - if slices.Contains(pkCols, i) { - continue - } - if first { - first = false - } else { - b.WriteByte(',') - } - fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) - } - query := b.String() - - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) -} diff --git a/information/column.go b/information/column.go index 075105d..9ee2ecd 100644 --- a/information/column.go +++ b/information/column.go @@ -4,7 +4,6 @@ import ( "context" "strings" - "github.com/domonda/go-errs" "github.com/domonda/go-sqldb/db" ) @@ -67,28 +66,24 @@ type KeyColumnUsage struct { PositionInUniqueConstraint *int `db:"position_in_unique_constraint"` } -func ColumnExists(ctx context.Context, table, column string) (exists bool, err error) { - defer errs.WrapWithFuncParams(&err, ctx, table, column) - +func ColumnExists(ctx context.Context, table, column string) (bool, error) { tableSchema, tableName, ok := strings.Cut(table, ".") if !ok { tableSchema = "public" tableName = table } - err = db.QueryRow(ctx, - `select exists( - select from information_schema.columns - where table_schema = $1 - and table_name = $2 - and column_name = $3 - )`, + return db.QueryRowValue[bool](ctx, + /*sql*/ ` + SELECT EXISTS ( + SELECT FROM information_schema.columns + WHERE table_schema = $1 + AND table_name = $2 + AND column_name = $3 + ) + `, tableSchema, tableName, column, - ).Scan(&exists) - if err != nil { - return false, err - } - return exists, nil + ) } diff --git a/information/primarykeys.go b/information/primarykeys.go index 87f6ffb..8461d52 100644 --- a/information/primarykeys.go +++ b/information/primarykeys.go @@ -12,7 +12,6 @@ import ( "strings" "text/template" - "github.com/domonda/go-errs" "github.com/domonda/go-sqldb/db" "github.com/domonda/go-types/uu" ) @@ -25,85 +24,75 @@ type PrimaryKeyColumn struct { } func GetPrimaryKeyColumns(ctx context.Context) (cols []PrimaryKeyColumn, err error) { - defer errs.WrapWithFuncParams(&err, ctx) - - err = db.QueryRows(ctx, ` - select - tc.table_schema||'.'||tc.table_name as "table", - kc.column_name as "column", - col.data_type as "type", - (select exists( - select from information_schema.table_constraints as fk_tc - inner join information_schema.key_column_usage as fk_kc - on fk_kc.table_schema = fk_tc.table_schema - and fk_kc.table_name = fk_tc.table_name - and fk_kc.constraint_name = fk_tc.constraint_name - where fk_tc.constraint_type = 'FOREIGN KEY' - and fk_tc.table_schema = tc.table_schema - and fk_tc.table_name = tc.table_name - and fk_kc.column_name = kc.column_name - )) as "foreign_key" - from information_schema.table_constraints as tc - inner join information_schema.key_column_usage as kc - on kc.table_schema = tc.table_schema - and kc.table_name = tc.table_name - and kc.constraint_name = tc.constraint_name - inner join information_schema.columns as col - on col.table_schema = tc.table_schema - and col.table_name = tc.table_name - and col.column_name = kc.column_name - where tc.constraint_type = 'PRIMARY KEY' - and kc.ordinal_position is not null - order by + return db.QueryRowsAsSlice[PrimaryKeyColumn](ctx, + /*sql*/ ` + SELECT + tc.table_schema||'.'||tc.table_name AS "table", + kc.column_name AS "column", + col.data_type AS "type", + (SELECT EXISTS( + SELECT FROM information_schema.table_constraints AS fk_tc + JOIN information_schema.key_column_usage AS fk_kc + ON fk_kc.table_schema = fk_tc.table_schema + AND fk_kc.table_name = fk_tc.table_name + AND fk_kc.constraint_name = fk_tc.constraint_name + WHERE fk_tc.constraint_type = 'FOREIGN KEY' + AND fk_tc.table_schema = tc.table_schema + AND fk_tc.table_name = tc.table_name + AND fk_kc.column_name = kc.column_name + )) AS "foreign_key" + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kc + ON kc.table_schema = tc.table_schema + AND kc.table_name = tc.table_name + AND kc.constraint_name = tc.constraint_name + JOIN information_schema.columns AS col + ON col.table_schema = tc.table_schema + AND col.table_name = tc.table_name + AND col.column_name = kc.column_name + WHERE tc.constraint_type = 'PRIMARY KEY' + AND kc.ordinal_positiON IS NOT NULL + ORDER BY tc.table_schema, tc.table_name`, - ).ScanStructSlice(&cols) - if err != nil { - return nil, err - } - return cols, nil + ) } func GetPrimaryKeyColumnsOfType(ctx context.Context, pkType string) (cols []PrimaryKeyColumn, err error) { - defer errs.WrapWithFuncParams(&err, ctx, pkType) - - err = db.QueryRows(ctx, ` - select - tc.table_schema||'.'||tc.table_name as "table", - kc.column_name as "column", - col.data_type as "type", - (select exists( - select from information_schema.table_constraints as fk_tc - inner join information_schema.key_column_usage as fk_kc - on fk_kc.table_schema = fk_tc.table_schema - and fk_kc.table_name = fk_tc.table_name - and fk_kc.constraint_name = fk_tc.constraint_name - where fk_tc.constraint_type = 'FOREIGN KEY' - and fk_tc.table_schema = tc.table_schema - and fk_tc.table_name = tc.table_name - and fk_kc.column_name = kc.column_name - )) as "foreign_key" - from information_schema.table_constraints as tc - inner join information_schema.key_column_usage as kc - on kc.table_schema = tc.table_schema - and kc.table_name = tc.table_name - and kc.constraint_name = tc.constraint_name - inner join information_schema.columns as col - on col.table_schema = tc.table_schema - and col.table_name = tc.table_name - and col.column_name = kc.column_name - where tc.constraint_type = 'PRIMARY KEY' - and kc.ordinal_position is not null - and col.data_type = $1 - order by + return db.QueryRowsAsSlice[PrimaryKeyColumn](ctx, + /*sql*/ ` + SELECT + tc.table_schema||'.'||tc.table_name AS "table", + kc.column_name AS "column", + col.data_type AS "type", + (SELECT EXISTS( + SELECT FROM information_schema.table_constraints AS fk_tc + JOIN information_schema.key_column_usage AS fk_kc + ON fk_kc.table_schema = fk_tc.table_schema + AND fk_kc.table_name = fk_tc.table_name + AND fk_kc.constraint_name = fk_tc.constraint_name + WHERE fk_tc.constraint_type = 'FOREIGN KEY' + AND fk_tc.table_schema = tc.table_schema + AND fk_tc.table_name = tc.table_name + AND fk_kc.column_name = kc.column_name + )) AS "foreign_key" + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kc + ON kc.table_schema = tc.table_schema + AND kc.table_name = tc.table_name + AND kc.constraint_name = tc.constraint_name + JOIN information_schema.columns AS col + ON col.table_schema = tc.table_schema + AND col.table_name = tc.table_name + AND col.column_name = kc.column_name + WHERE tc.constraint_type = 'PRIMARY KEY' + AND kc.ordinal_positiON IS NOT NULL + AND col.data_type = $1 + ORDER BY tc.table_schema, tc.table_name`, pkType, - ).ScanStructSlice(&cols) - if err != nil { - return nil, err - } - return cols, nil + ) } type TableRowWithPrimaryKey struct { @@ -113,33 +102,22 @@ type TableRowWithPrimaryKey struct { } func GetTableRowsWithPrimaryKey(ctx context.Context, pkCols []PrimaryKeyColumn, pk any) (tableRows []TableRowWithPrimaryKey, err error) { - defer errs.WrapWithFuncParams(&err, ctx, pkCols, pk) - - conn := db.Conn(ctx) for _, col := range pkCols { - query := fmt.Sprintf(`select * from %s where "%s" = $1`, col.Table, col.Column) - row := conn.QueryRows(query, pk) - cols, err := row.Columns() - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } - return nil, err - } - vals, err := row.ScanAllRowsAsStrings(false) + query := fmt.Sprintf(`SELECT * FROM %s WHERE "%s" = $1`, col.Table, col.Column) + strs, err := db.QueryRowsAsStrings(ctx, query, pk) if err != nil { if errors.Is(err, sql.ErrNoRows) { continue } return nil, err } - if len(vals) == 0 { + if len(strs) < 2 { continue } tableRows = append(tableRows, TableRowWithPrimaryKey{ PrimaryKeyColumn: col, - Header: cols, - Row: vals[0], + Header: strs[0], + Row: strs[1], }) } return tableRows, nil @@ -149,12 +127,16 @@ var RenderUUIDPrimaryKeyRefsHTML = http.HandlerFunc(func(writer http.ResponseWri var ( title string mainContent any - style = []string{StyleAllMonospace, StyleDefaultTable, ``} + style = []string{ + StyleAllMonospace, + StyleDefaultTable, + ``, + } ) pk, err := uu.IDFromString(request.URL.Query().Get("pk")) if err != nil { title = "Primary Key UUID" - mainContent = ` + mainContent = /*html*/ `