diff --git a/database/contracts.go b/database/contracts.go index bf55d320..6da723a8 100644 --- a/database/contracts.go +++ b/database/contracts.go @@ -1,5 +1,10 @@ package database +import ( + "context" + "github.com/jmoiron/sqlx" +) + // Entity is implemented by each type that works with the database package. type Entity interface { Fingerprinter @@ -54,3 +59,10 @@ type PgsqlOnConflictConstrainter interface { // PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table. PgsqlOnConflictConstraint() string } + +// TxOrDB is just a helper interface that can represent a *[sqlx.Tx] or *[DB] instance. +type TxOrDB interface { + sqlx.ExtContext + + PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) +} diff --git a/database/db.go b/database/db.go index 5d3ad2f3..f75a2f4a 100644 --- a/database/db.go +++ b/database/db.go @@ -652,6 +652,26 @@ func (db *DB) Delete( return db.DeleteStreamed(ctx, entityType, idsCh, onSuccess...) } +// RunInTx allows running a function in a database transaction without requiring manual transaction handling. +// +// A new transaction is started on [DB] which is then passed to fn. After fn returns, the transaction is +// committed unless an error was returned. If fn returns an error, that error is returned or when failing +// to start or/and commit the transaction. +func (db *DB) RunInTx(ctx context.Context, fn func(tx *sqlx.Tx) error) error { + tx, err := db.BeginTxx(ctx, nil) + if err != nil { + return errors.Wrap(err, "DB.RunInTx: cannot start a database transaction") + } + // We don't expect meaningful errors from rolling back the tx other than the sql.ErrTxDone, so just ignore it. + defer func() { _ = tx.Rollback() }() + + if err := fn(tx); err != nil { + return err + } + + return errors.Wrap(tx.Commit(), "DB.RunInTx: cannot commit a database transaction") +} + func (db *DB) GetSemaphoreForTable(table string) *semaphore.Weighted { db.tableSemaphoresMu.Lock() defer db.tableSemaphoresMu.Unlock() @@ -693,3 +713,9 @@ func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) perio db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed) })) } + +var ( + // Assert TxOrDB interface compliance of the DB and sqlx.Tx types. + _ TxOrDB = (*DB)(nil) + _ TxOrDB = (*sqlx.Tx)(nil) +) diff --git a/database/utils.go b/database/utils.go index 2ae372ce..788b8310 100644 --- a/database/utils.go +++ b/database/utils.go @@ -7,6 +7,7 @@ import ( "github.com/icinga/icinga-go-library/com" "github.com/icinga/icinga-go-library/strcase" "github.com/icinga/icinga-go-library/types" + "github.com/jmoiron/sqlx" "github.com/pkg/errors" ) @@ -43,6 +44,47 @@ func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] { } } +// InsertObtainID executes the given query and fetches the last inserted ID. +// +// Using this method for database tables that don't define an auto-incrementing ID, or none at all, +// will not work. The only supported column that can be retrieved with this method is id. +// This function expects [TxOrDB] as an executor of the provided query, and is usually a *[sqlx.Tx] or *[DB] instance. +// Returns the retrieved ID on success and error on any database inserting/retrieving failure. +func InsertObtainID(ctx context.Context, conn TxOrDB, stmt string, arg any) (int64, error) { + var resultID int64 + switch conn.DriverName() { + case PostgreSQL: + query := stmt + " RETURNING id" + ps, err := conn.PrepareNamedContext(ctx, query) + if err != nil { + return 0, errors.Wrapf(err, "cannot prepare %q", query) + } + // We're deferring the ps#Close invocation here just to be on the safe side, otherwise it's + // closed manually later on and the error is handled gracefully (if any). + defer func() { _ = ps.Close() }() + + if err := ps.GetContext(ctx, &resultID, arg); err != nil { + return 0, CantPerformQuery(err, query) + } + + if err := ps.Close(); err != nil { + return 0, errors.Wrapf(err, "cannot close prepared statement %q", query) + } + default: + result, err := sqlx.NamedExecContext(ctx, conn, stmt, arg) + if err != nil { + return 0, CantPerformQuery(err, stmt) + } + + resultID, err = result.LastInsertId() + if err != nil { + return 0, errors.Wrap(err, "cannot retrieve last inserted ID") + } + } + + return resultID, nil +} + // setGaleraOpts sets the "wsrep_sync_wait" variable for each session ensures that causality checks are performed // before execution and that each statement is executed on a fully synchronized node. Doing so prevents foreign key // violation when inserting into dependent tables on different MariaDB/MySQL nodes. When using MySQL single nodes,