Skip to content

Commit 41a0a39

Browse files
committed
Add function to register wal hook
1 parent 8bf7a8a commit 41a0a39

File tree

4 files changed

+132
-0
lines changed

4 files changed

+132
-0
lines changed

_example/wal_hook/main.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package main
2+
3+
import (
4+
"database/sql"
5+
"log"
6+
"os"
7+
8+
"github.com/mattn/go-sqlite3"
9+
)
10+
11+
func main() {
12+
sql.Register("sqlite3_with_wal_hook_example",
13+
&sqlite3.SQLiteDriver{
14+
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
15+
conn.RegisterWalHook(func(dbName string, nPages int) int {
16+
if nPages >= 1 {
17+
if _, err := conn.Exec("PRAGMA wal_checkpoint(TRUNCATE);", nil); err != nil {
18+
log.Fatal(err)
19+
}
20+
}
21+
return sqlite3.SQLITE_OK
22+
})
23+
return nil
24+
},
25+
})
26+
defer os.Remove("./foo.db")
27+
28+
db, err := sql.Open("sqlite3_with_wal_hook_example", "./foo.db?_journal=WAL")
29+
if err != nil {
30+
log.Fatal(err)
31+
}
32+
defer db.Close()
33+
34+
_, err = db.Exec("create table foo(id int, value text)")
35+
if err != nil {
36+
log.Fatal(err)
37+
}
38+
39+
tx, err := db.Begin()
40+
if err != nil {
41+
log.Fatal(err)
42+
}
43+
stmt, err := tx.Prepare("insert into foo(id, value) values(?, ?)")
44+
if err != nil {
45+
log.Fatal(err)
46+
}
47+
defer stmt.Close()
48+
for i := 0; i < 100; i++ {
49+
if _, err := stmt.Exec(i, "value"); err != nil {
50+
log.Fatal(err)
51+
}
52+
}
53+
if err := tx.Commit(); err != nil {
54+
log.Fatal(err)
55+
}
56+
57+
var busy, log_, checkpointed int
58+
err = db.QueryRow("PRAGMA wal_checkpoint(PASSIVE);").Scan(&busy, &log_, &checkpointed)
59+
if err != nil {
60+
log.Fatal(err)
61+
}
62+
log.Printf("busy=%d log=%d checkpointed=%d\n", busy, log_, checkpointed) // busy=0 log=0 checkpointed=0
63+
}

callback.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ func updateHookTrampoline(handle unsafe.Pointer, op int, db *C.char, table *C.ch
7676
callback(op, C.GoString(db), C.GoString(table), rowid)
7777
}
7878

79+
//export walHookTrampoline
80+
func walHookTrampoline(hadlePtr unsafe.Pointer, _ *C.sqlite3, name *C.char, pages int) int {
81+
callback := lookupHandle(hadlePtr).(func(string, int) int)
82+
return callback(C.GoString(name), pages)
83+
}
84+
7985
//export authorizerTrampoline
8086
func authorizerTrampoline(handle unsafe.Pointer, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int {
8187
callback := lookupHandle(handle).(func(int, string, string, string) int)

sqlite3.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ int compareTrampoline(void*, int, char*, int, char*);
167167
int commitHookTrampoline(void*);
168168
void rollbackHookTrampoline(void*);
169169
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
170+
int walHookTrampoline(void *, sqlite3*, const char*, int);
170171
171172
int authorizerTrampoline(void*, int, char*, char*, char*, char*);
172173
@@ -590,6 +591,27 @@ func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64
590591
}
591592
}
592593

594+
// RegisterWalHook sets the WAL hook for a connection.
595+
//
596+
// The callback is invoked after each commit in WAL mode. The parameters
597+
// are the connection, the database name ("main" or attached), and the
598+
// number of pages currently in the WAL.
599+
//
600+
// The callback should normally return SQLITE_OK (0). A non-zero return
601+
// propagates as an error on the committing statement, though the commit
602+
// itself still occurs.
603+
//
604+
// If there is an existing WAL hook for this connection, it will be
605+
// removed. If callback is nil the existing hook (if any) will be removed
606+
// without creating a new one.
607+
func (c *SQLiteConn) RegisterWalHook(callback func(dbName string, nPages int) int) {
608+
if callback == nil {
609+
C.sqlite3_wal_hook(c.db, nil, nil)
610+
} else {
611+
C.sqlite3_wal_hook(c.db, (*[0]byte)(C.walHookTrampoline), newHandle(c, callback))
612+
}
613+
}
614+
593615
// RegisterAuthorizer sets the authorizer for connection.
594616
//
595617
// The parameters to the callback are the operation (one of the constants

sqlite3_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"math/rand"
1919
"net/url"
2020
"os"
21+
"path/filepath"
2122
"reflect"
2223
"regexp"
2324
"runtime"
@@ -1786,6 +1787,46 @@ func TestUpdateAndTransactionHooks(t *testing.T) {
17861787
}
17871788
}
17881789

1790+
func TestWalHook(t *testing.T) {
1791+
var walHookCalled bool
1792+
sql.Register("sqlite3_WalHook", &SQLiteDriver{
1793+
ConnectHook: func(conn *SQLiteConn) error {
1794+
conn.RegisterWalHook(func(dbName string, nPages int) int {
1795+
walHookCalled = true
1796+
if dbName != "main" {
1797+
t.Errorf("Expected dbName to be 'main', got %q", dbName)
1798+
}
1799+
if nPages <= 0 {
1800+
t.Errorf("Expected nPages to be positive, got %d", nPages)
1801+
}
1802+
return SQLITE_OK
1803+
})
1804+
return nil
1805+
},
1806+
})
1807+
1808+
dbPath := filepath.Join(t.TempDir(), "test.db?cache=shared&_journal_mode=WAL")
1809+
db, err := sql.Open("sqlite3_WalHook", dbPath)
1810+
if err != nil {
1811+
t.Fatal("Failed to open database:", err)
1812+
}
1813+
defer db.Close()
1814+
1815+
_, err = db.Exec("CREATE TABLE foo (id integer primary key)")
1816+
if err != nil {
1817+
t.Fatal("Failed to create table:", err)
1818+
}
1819+
1820+
_, err = db.Exec("INSERT INTO foo VALUES (1)")
1821+
if err != nil {
1822+
t.Fatal("Failed to insert:", err)
1823+
}
1824+
1825+
if !walHookCalled {
1826+
t.Error("Expected wal hook to be called, but it wasn't")
1827+
}
1828+
}
1829+
17891830
func TestAuthorizer(t *testing.T) {
17901831
var authorizerReturn = 0
17911832

0 commit comments

Comments
 (0)