diff --git a/_example/wal_hook/main.go b/_example/wal_hook/main.go new file mode 100644 index 00000000..80e3b3d6 --- /dev/null +++ b/_example/wal_hook/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "database/sql" + "log" + "os" + + "github.com/mattn/go-sqlite3" +) + +func main() { + sql.Register("sqlite3_with_wal_hook_example", + &sqlite3.SQLiteDriver{ + ConnectHook: func(conn *sqlite3.SQLiteConn) error { + conn.RegisterWalHook(func(dbName string, nPages int) int { + if nPages >= 1 { + if _, err := conn.Exec("PRAGMA wal_checkpoint(TRUNCATE);", nil); err != nil { + log.Fatal(err) + } + } + return sqlite3.SQLITE_OK + }) + return nil + }, + }) + defer os.Remove("./foo.db") + + db, err := sql.Open("sqlite3_with_wal_hook_example", "./foo.db?_journal=WAL") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("create table foo(id int, value text)") + if err != nil { + log.Fatal(err) + } + + tx, err := db.Begin() + if err != nil { + log.Fatal(err) + } + stmt, err := tx.Prepare("insert into foo(id, value) values(?, ?)") + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + for i := 0; i < 100; i++ { + if _, err := stmt.Exec(i, "value"); err != nil { + log.Fatal(err) + } + } + if err := tx.Commit(); err != nil { + log.Fatal(err) + } + + var busy, log_, checkpointed int + err = db.QueryRow("PRAGMA wal_checkpoint(PASSIVE);").Scan(&busy, &log_, &checkpointed) + if err != nil { + log.Fatal(err) + } + log.Printf("busy=%d log=%d checkpointed=%d\n", busy, log_, checkpointed) // busy=0 log=0 checkpointed=0 +} diff --git a/callback.go b/callback.go index 0c518fa2..5c8765ea 100644 --- a/callback.go +++ b/callback.go @@ -76,6 +76,12 @@ func updateHookTrampoline(handle unsafe.Pointer, op int, db *C.char, table *C.ch callback(op, C.GoString(db), C.GoString(table), rowid) } +//export walHookTrampoline +func walHookTrampoline(hadlePtr unsafe.Pointer, _ *C.sqlite3, name *C.char, pages int) int { + callback := lookupHandle(hadlePtr).(func(string, int) int) + return callback(C.GoString(name), pages) +} + //export authorizerTrampoline func authorizerTrampoline(handle unsafe.Pointer, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int { callback := lookupHandle(handle).(func(int, string, string, string) int) diff --git a/sqlite3.go b/sqlite3.go index a967cab0..691989b3 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -167,6 +167,7 @@ int compareTrampoline(void*, int, char*, int, char*); int commitHookTrampoline(void*); void rollbackHookTrampoline(void*); void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64); +int walHookTrampoline(void *, sqlite3*, const char*, int); int authorizerTrampoline(void*, int, char*, char*, char*, char*); @@ -590,6 +591,27 @@ func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64 } } +// RegisterWalHook sets the WAL hook for a connection. +// +// The callback is invoked after each commit in WAL mode. The parameters +// are the connection, the database name ("main" or attached), and the +// number of pages currently in the WAL. +// +// The callback should normally return SQLITE_OK (0). A non-zero return +// propagates as an error on the committing statement, though the commit +// itself still occurs. +// +// If there is an existing WAL hook for this connection, it will be +// removed. If callback is nil the existing hook (if any) will be removed +// without creating a new one. +func (c *SQLiteConn) RegisterWalHook(callback func(dbName string, nPages int) int) { + if callback == nil { + C.sqlite3_wal_hook(c.db, nil, nil) + } else { + C.sqlite3_wal_hook(c.db, (*[0]byte)(C.walHookTrampoline), newHandle(c, callback)) + } +} + // RegisterAuthorizer sets the authorizer for connection. // // The parameters to the callback are the operation (one of the constants diff --git a/sqlite3_test.go b/sqlite3_test.go index ed8fa646..0506843b 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "net/url" "os" + "path/filepath" "reflect" "regexp" "runtime" @@ -1786,6 +1787,46 @@ func TestUpdateAndTransactionHooks(t *testing.T) { } } +func TestWalHook(t *testing.T) { + var walHookCalled bool + sql.Register("sqlite3_WalHook", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + conn.RegisterWalHook(func(dbName string, nPages int) int { + walHookCalled = true + if dbName != "main" { + t.Errorf("Expected dbName to be 'main', got %q", dbName) + } + if nPages <= 0 { + t.Errorf("Expected nPages to be positive, got %d", nPages) + } + return SQLITE_OK + }) + return nil + }, + }) + + dbPath := filepath.Join(t.TempDir(), "test.db?cache=shared&_journal_mode=WAL") + db, err := sql.Open("sqlite3_WalHook", dbPath) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo (id integer primary key)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("INSERT INTO foo VALUES (1)") + if err != nil { + t.Fatal("Failed to insert:", err) + } + + if !walHookCalled { + t.Error("Expected wal hook to be called, but it wasn't") + } +} + func TestAuthorizer(t *testing.T) { var authorizerReturn = 0