Skip to content

Commit bacaeaa

Browse files
committed
Refactor LockForUpdate() to reuse tx if possible
1 parent 0e8dc6e commit bacaeaa

File tree

2 files changed

+43
-34
lines changed

2 files changed

+43
-34
lines changed

table.go

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -175,56 +175,65 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] {
175175

176176
// LockForUpdate locks and processes records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern
177177
// for safe concurrent processing. Each record is processed exactly once across multiple workers.
178-
// Records are automatically updated after updateFn() completes. Complete updateFn() quickly to avoid
178+
// Records are automatically updated after updateFn() completes. Keep updateFn() fast to avoid
179179
// holding the transaction. For long-running work, update status to "processing" and return early,
180-
// then process asynchronously and update status to "completed" or "failed" when done.
181-
func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(pgTx pgx.Tx, records []PT)) error {
180+
// then process asynchronously. Use defer LockOneForUpdate() to update status to "completed" or "failed".
181+
func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error {
182+
// Check if we're already in a transaction
183+
if t.DB.Query.Tx != nil {
184+
return t.lockForUpdateWithTx(ctx, t.DB.Query.Tx, cond, orderBy, limit, updateFn)
185+
}
186+
182187
return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error {
183-
if len(orderBy) == 0 {
184-
orderBy = []string{t.IDColumn}
185-
}
188+
return t.lockForUpdateWithTx(ctx, pgTx, cond, orderBy, limit, updateFn)
189+
})
190+
}
186191

187-
tx := t.WithTx(pgTx)
192+
func (t *Table[T, PT, IDT]) lockForUpdateWithTx(ctx context.Context, pgTx pgx.Tx, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error {
193+
if len(orderBy) == 0 {
194+
orderBy = []string{t.IDColumn}
195+
}
188196

189-
q := tx.SQL.
190-
Select("*").
191-
From(t.Name).
192-
Where(cond).
193-
OrderBy(orderBy...).
194-
Limit(limit).
195-
Suffix("FOR UPDATE SKIP LOCKED")
197+
q := t.SQL.
198+
Select("*").
199+
From(t.Name).
200+
Where(cond).
201+
OrderBy(orderBy...).
202+
Limit(limit).
203+
Suffix("FOR UPDATE SKIP LOCKED")
196204

197-
var records []PT
198-
if err := tx.Query.GetAll(ctx, q, &records); err != nil {
199-
return fmt.Errorf("select for update skip locked: %w", err)
200-
}
205+
txQuery := t.DB.TxQuery(pgTx)
206+
207+
var records []PT
208+
if err := txQuery.GetAll(ctx, q, &records); err != nil {
209+
return fmt.Errorf("select for update skip locked: %w", err)
210+
}
201211

202-
updateFn(pgTx, records)
212+
updateFn(records)
203213

204-
for _, record := range records {
205-
q := tx.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name)
206-
if _, err := tx.Query.Exec(ctx, q); err != nil {
207-
return fmt.Errorf("update record: %w", err)
208-
}
214+
for _, record := range records {
215+
q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name)
216+
if _, err := txQuery.Exec(ctx, q); err != nil {
217+
return fmt.Errorf("update record: %w", err)
209218
}
219+
}
210220

211-
return nil
212-
})
221+
return nil
213222
}
214223

215224
// LockOneForUpdate locks and processes one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern
216225
// for safe concurrent processing. The record is processed exactly once across multiple workers.
217-
// Records are automatically updated after updateFn() completes. Complete updateFn() quickly to avoid
226+
// The record is automatically updated after updateFn() completes. Keep updateFn() fast to avoid
218227
// holding the transaction. For long-running work, update status to "processing" and return early,
219-
// then process asynchronously and update status to "completed" or "failed" when done.
228+
// then process asynchronously. Use defer LockOneForUpdate() to update status to "completed" or "failed".
220229
//
221-
// Returns ErrNoRows if no records match the condition.
222-
func (t *Table[T, PT, IDT]) LockOneForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, updateFn func(pgTx pgx.Tx, record PT)) error {
230+
// Returns ErrNoRows if no matching records are available for locking.
231+
func (t *Table[T, PT, IDT]) LockOneForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, updateFn func(record PT)) error {
223232
var noRows bool
224233

225-
err := t.LockForUpdate(ctx, cond, orderBy, 1, func(pgTx pgx.Tx, records []PT) {
234+
err := t.LockForUpdate(ctx, cond, orderBy, 1, func(records []PT) {
226235
if len(records) > 0 {
227-
updateFn(pgTx, records[0])
236+
updateFn(records[0])
228237
} else {
229238
noRows = true
230239
}

tests/database_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ func initDB(db *pgkit.DB) *Database {
2626

2727
func (db *Database) BeginTx(ctx context.Context, fn func(tx *Database) error) error {
2828
return pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error {
29-
tx := db.WithTxQuery(pgTx)
29+
tx := db.WithTx(pgTx)
3030
return fn(tx)
3131
})
3232
}
3333

34-
func (db *Database) WithTxQuery(tx pgx.Tx) *Database {
34+
func (db *Database) WithTx(tx pgx.Tx) *Database {
3535
pgkitDB := &pgkit.DB{
3636
Conn: db.Conn,
3737
SQL: db.SQL,

0 commit comments

Comments
 (0)