Skip to content

Commit 89805fa

Browse files
committed
Introduce DB#InsertObtainID() method
1 parent a66da14 commit 89805fa

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

database/contracts.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
package database
22

3+
import (
4+
"context"
5+
"github.com/jmoiron/sqlx"
6+
)
7+
38
// Entity is implemented by each type that works with the database package.
49
type Entity interface {
510
Fingerprinter
@@ -54,3 +59,11 @@ type PgsqlOnConflictConstrainter interface {
5459
// PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table.
5560
PgsqlOnConflictConstraint() string
5661
}
62+
63+
// TxOrDB is just a helper interface that can represent a *[sqlx.Tx] or *[DB] instance.
64+
type TxOrDB interface {
65+
sqlx.ExtContext
66+
sqlx.PreparerContext
67+
68+
PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
69+
}

database/utils.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package database
22

33
import (
44
"context"
5+
"database/sql"
56
"database/sql/driver"
67
"github.com/go-sql-driver/mysql"
78
"github.com/icinga/icinga-go-library/com"
89
"github.com/icinga/icinga-go-library/strcase"
910
"github.com/icinga/icinga-go-library/types"
11+
"github.com/jmoiron/sqlx"
1012
"github.com/pkg/errors"
1113
)
1214

@@ -43,6 +45,49 @@ func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] {
4345
}
4446
}
4547

48+
// InsertObtainID executes the given query and fetches the last inserted ID.
49+
//
50+
// Using this method for database tables that don't define an auto-incrementing ID, or none at all,
51+
// will not work. The only supported column that can be retrieved with this method is id.
52+
// This function expects [TxOrDB] as an executor of the provided query, and is usually a *[sqlx.Tx] or *[DB] instance.
53+
// Returns the retrieved ID wrapped in [types.Int] on success and error on any database inserting/retrieving failure.
54+
func InsertObtainID(ctx context.Context, conn TxOrDB, stmt string, arg any) (types.Int, error) {
55+
var resultID int64
56+
switch conn.DriverName() {
57+
case PostgreSQL:
58+
query := stmt + " RETURNING id"
59+
ps, err := conn.PrepareNamedContext(ctx, query)
60+
if err != nil {
61+
return types.Int{}, errors.Wrapf(err, "cannot prepare %q", query)
62+
}
63+
// We're deferring the ps#Close invocation here just to be on the safe side, otherwise it's
64+
// closed manually later on and the error is handled gracefully (if any).
65+
defer func() { _ = ps.Close() }()
66+
67+
if err = ps.GetContext(ctx, &resultID, arg); err != nil {
68+
return types.Int{}, CantPerformQuery(err, query)
69+
}
70+
71+
if err = ps.Close(); err != nil {
72+
return types.Int{}, errors.Wrapf(err, "cannot close prepared statement %q", query)
73+
}
74+
case MySQL:
75+
result, err := sqlx.NamedExecContext(ctx, conn, stmt, arg)
76+
if err != nil {
77+
return types.Int{}, CantPerformQuery(err, stmt)
78+
}
79+
80+
resultID, err = result.LastInsertId()
81+
if err != nil {
82+
return types.Int{}, errors.Wrap(err, "cannot retrieve last inserted ID")
83+
}
84+
default:
85+
return types.Int{}, errors.Errorf("unsupported driver: %s", conn.DriverName())
86+
}
87+
88+
return types.Int{NullInt64: sql.NullInt64{Int64: resultID, Valid: true}}, nil
89+
}
90+
4691
// setGaleraOpts sets the "wsrep_sync_wait" variable for each session ensures that causality checks are performed
4792
// before execution and that each statement is executed on a fully synchronized node. Doing so prevents foreign key
4893
// violation when inserting into dependent tables on different MariaDB/MySQL nodes. When using MySQL single nodes,

0 commit comments

Comments
 (0)