Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions _example/wal_hook/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package main

import (
"database/sql"
"log"
"os"

"github.com/mattn/go-sqlite3"
)

func main() {
sql.Register("sqlite3_with_wal_hook_example",
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
conn.RegisterWalHook(func(dbName string, nPages int) int {
if nPages >= 1 {
if _, err := conn.Exec("PRAGMA wal_checkpoint(TRUNCATE);", nil); err != nil {
log.Fatal(err)
}
}
return sqlite3.SQLITE_OK
})
return nil
},
})
defer os.Remove("./foo.db")

db, err := sql.Open("sqlite3_with_wal_hook_example", "./foo.db?_journal=WAL")
if err != nil {
log.Fatal(err)
}
defer db.Close()

_, err = db.Exec("create table foo(id int, value text)")
if err != nil {
log.Fatal(err)
}

tx, err := db.Begin()
if err != nil {
log.Fatal(err)
}
stmt, err := tx.Prepare("insert into foo(id, value) values(?, ?)")
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for i := 0; i < 100; i++ {
if _, err := stmt.Exec(i, "value"); err != nil {
log.Fatal(err)
}
}
if err := tx.Commit(); err != nil {
log.Fatal(err)
}

var busy, log_, checkpointed int
err = db.QueryRow("PRAGMA wal_checkpoint(PASSIVE);").Scan(&busy, &log_, &checkpointed)
if err != nil {
log.Fatal(err)
}
log.Printf("busy=%d log=%d checkpointed=%d\n", busy, log_, checkpointed) // busy=0 log=0 checkpointed=0
}
6 changes: 6 additions & 0 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ func updateHookTrampoline(handle unsafe.Pointer, op int, db *C.char, table *C.ch
callback(op, C.GoString(db), C.GoString(table), rowid)
}

//export walHookTrampoline
func walHookTrampoline(hadlePtr unsafe.Pointer, _ *C.sqlite3, name *C.char, pages int) int {
callback := lookupHandle(hadlePtr).(func(string, int) int)
return callback(C.GoString(name), pages)
}

//export authorizerTrampoline
func authorizerTrampoline(handle unsafe.Pointer, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int {
callback := lookupHandle(handle).(func(int, string, string, string) int)
Expand Down
22 changes: 22 additions & 0 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ int compareTrampoline(void*, int, char*, int, char*);
int commitHookTrampoline(void*);
void rollbackHookTrampoline(void*);
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
int walHookTrampoline(void *, sqlite3*, const char*, int);

int authorizerTrampoline(void*, int, char*, char*, char*, char*);

Expand Down Expand Up @@ -590,6 +591,27 @@ func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64
}
}

// RegisterWalHook sets the WAL hook for a connection.
//
// The callback is invoked after each commit in WAL mode. The parameters
// are the connection, the database name ("main" or attached), and the
// number of pages currently in the WAL.
//
// The callback should normally return SQLITE_OK (0). A non-zero return
// propagates as an error on the committing statement, though the commit
// itself still occurs.
//
// If there is an existing WAL hook for this connection, it will be
// removed. If callback is nil the existing hook (if any) will be removed
// without creating a new one.
func (c *SQLiteConn) RegisterWalHook(callback func(dbName string, nPages int) int) {
if callback == nil {
C.sqlite3_wal_hook(c.db, nil, nil)
} else {
C.sqlite3_wal_hook(c.db, (*[0]byte)(C.walHookTrampoline), newHandle(c, callback))
}
}

// RegisterAuthorizer sets the authorizer for connection.
//
// The parameters to the callback are the operation (one of the constants
Expand Down
41 changes: 41 additions & 0 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"math/rand"
"net/url"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
Expand Down Expand Up @@ -1786,6 +1787,46 @@ func TestUpdateAndTransactionHooks(t *testing.T) {
}
}

func TestWalHook(t *testing.T) {
var walHookCalled bool
sql.Register("sqlite3_WalHook", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
conn.RegisterWalHook(func(dbName string, nPages int) int {
walHookCalled = true
if dbName != "main" {
t.Errorf("Expected dbName to be 'main', got %q", dbName)
}
if nPages <= 0 {
t.Errorf("Expected nPages to be positive, got %d", nPages)
}
return SQLITE_OK
})
return nil
},
})

dbPath := filepath.Join(t.TempDir(), "test.db?cache=shared&_journal_mode=WAL")
db, err := sql.Open("sqlite3_WalHook", dbPath)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()

_, err = db.Exec("CREATE TABLE foo (id integer primary key)")
if err != nil {
t.Fatal("Failed to create table:", err)
}

_, err = db.Exec("INSERT INTO foo VALUES (1)")
if err != nil {
t.Fatal("Failed to insert:", err)
}

if !walHookCalled {
t.Error("Expected wal hook to be called, but it wasn't")
}
}

func TestAuthorizer(t *testing.T) {
var authorizerReturn = 0

Expand Down