Skip to content

Commit 3b2195f

Browse files
authored
Merge pull request #64 from Icinga/last-insert-id
Introduce `DB#InsertObtainID()` function
2 parents 94140ed + 1317301 commit 3b2195f

File tree

4 files changed

+99
-7
lines changed

4 files changed

+99
-7
lines changed

database/contracts.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
package database
22

3+
import (
4+
"context"
5+
"github.com/jmoiron/sqlx"
6+
)
7+
38
// Entity is implemented by each type that works with the database package.
49
type Entity interface {
510
Fingerprinter
@@ -54,3 +59,10 @@ type PgsqlOnConflictConstrainter interface {
5459
// PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table.
5560
PgsqlOnConflictConstraint() string
5661
}
62+
63+
// TxOrDB is just a helper interface that can represent a *[sqlx.Tx] or *[DB] instance.
64+
type TxOrDB interface {
65+
sqlx.ExtContext
66+
67+
PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
68+
}

database/db.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,3 +904,9 @@ func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) perio
904904
db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed)
905905
}))
906906
}
907+
908+
var (
909+
// Assert TxOrDB interface compliance of the DB and sqlx.Tx types.
910+
_ TxOrDB = (*DB)(nil)
911+
_ TxOrDB = (*sqlx.Tx)(nil)
912+
)

database/utils.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/icinga/icinga-go-library/com"
99
"github.com/icinga/icinga-go-library/strcase"
1010
"github.com/icinga/icinga-go-library/types"
11+
"github.com/jmoiron/sqlx"
1112
"github.com/pkg/errors"
1213
)
1314

@@ -44,6 +45,42 @@ func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] {
4445
}
4546
}
4647

48+
// InsertObtainID executes the given query and fetches the last inserted ID.
49+
//
50+
// Using this method for database tables that don't define an auto-incrementing ID, or none at all,
51+
// will not work. The only supported column that can be retrieved with this method is id.
52+
//
53+
// This function expects [TxOrDB] as an executor of the provided query, and is usually a *[sqlx.Tx] or *[DB] instance.
54+
//
55+
// Returns the retrieved ID on success and error on any database inserting/retrieving failure.
56+
func InsertObtainID(ctx context.Context, conn TxOrDB, stmt string, arg any) (int64, error) {
57+
var resultID int64
58+
switch conn.DriverName() {
59+
case PostgreSQL:
60+
stmt = stmt + " RETURNING id"
61+
query, args, err := conn.BindNamed(stmt, arg)
62+
if err != nil {
63+
return 0, errors.Wrapf(err, "can't bind named query %q", stmt)
64+
}
65+
66+
if err := sqlx.GetContext(ctx, conn, &resultID, query, args...); err != nil {
67+
return 0, CantPerformQuery(err, query)
68+
}
69+
default:
70+
result, err := sqlx.NamedExecContext(ctx, conn, stmt, arg)
71+
if err != nil {
72+
return 0, CantPerformQuery(err, stmt)
73+
}
74+
75+
resultID, err = result.LastInsertId()
76+
if err != nil {
77+
return 0, errors.Wrap(err, "can't retrieve last inserted ID")
78+
}
79+
}
80+
81+
return resultID, nil
82+
}
83+
4784
// unsafeSetSessionVariableIfExists sets the given MySQL/MariaDB system variable for the specified database session.
4885
//
4986
// NOTE: It is unsafe to use this function with untrusted/user supplied inputs and poses an SQL injection,

database/utils_test.go

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,48 @@ import (
1616
"time"
1717
)
1818

19-
func TestSetMysqlSessionVars(t *testing.T) {
19+
func TestDatabaseUtils(t *testing.T) {
20+
t.Parallel()
21+
22+
ctx := context.Background()
23+
db := GetTestDB(ctx, t, "ICINGAGOLIBRARY")
24+
25+
t.Run("SetMySQLSessionVars", func(t *testing.T) {
26+
t.Parallel()
27+
if db.DriverName() != MySQL {
28+
t.Skipf("skipping set session vars test for %q driver", db.DriverName())
29+
}
30+
31+
setMysqlSessionVars(ctx, db, t)
32+
})
33+
34+
t.Run("InsertObtainID", func(t *testing.T) {
35+
t.Parallel()
36+
37+
defer func() {
38+
_, err := db.ExecContext(ctx, "DROP TABLE IF EXISTS igl_test_insert_obtain")
39+
assert.NoError(t, err, "dropping test database table should not fail")
40+
}()
41+
42+
var err error
43+
if db.DriverName() == PostgreSQL {
44+
_, err = db.ExecContext(ctx, "CREATE TABLE igl_test_insert_obtain (id SERIAL PRIMARY KEY, name VARCHAR(255))")
45+
} else {
46+
_, err = db.ExecContext(ctx, "CREATE TABLE igl_test_insert_obtain (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255))")
47+
}
48+
require.NoError(t, err, "creating test database table should not fail")
49+
50+
id, err := InsertObtainID(ctx, db, "INSERT INTO igl_test_insert_obtain (name) VALUES (:name)", map[string]any{"name": "test1"})
51+
require.NoError(t, err, "inserting new row into test database table should not fail")
52+
assert.Equal(t, id, int64(1))
53+
54+
id, err = InsertObtainID(ctx, db, "INSERT INTO igl_test_insert_obtain (name) VALUES (:name)", map[string]any{"name": "test2"})
55+
require.NoError(t, err, "inserting new row into test database table should not fail")
56+
assert.Equal(t, id, int64(2))
57+
})
58+
}
59+
60+
func setMysqlSessionVars(ctx context.Context, db *DB, t *testing.T) {
2061
vars := map[string][]struct {
2162
name string
2263
value string
@@ -45,14 +86,10 @@ func TestSetMysqlSessionVars(t *testing.T) {
4586
},
4687
}
4788

48-
ctx := context.Background()
49-
db := GetTestDB(ctx, t, "ICINGAGOLIBRARY")
50-
if db.DriverName() != MySQL {
51-
t.Skipf("skipping set session vars test for %q driver", db.DriverName())
52-
}
53-
5489
for name, vs := range vars {
5590
t.Run(name, func(t *testing.T) {
91+
t.Parallel()
92+
5693
for _, v := range vs {
5794
conn, err := db.DB.Conn(ctx)
5895
require.NoError(t, err, "connecting to MySQL/MariaDB database should not fail")

0 commit comments

Comments
 (0)