diff --git a/blob_io.go b/blob_io.go new file mode 100644 index 00000000..2acf2133 --- /dev/null +++ b/blob_io.go @@ -0,0 +1,169 @@ +// Copyright (C) 2022 Yasuhiro Matsumoto . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package sqlite3 + +/* +#ifndef USE_LIBSQLITE3 +#include "sqlite3-binding.h" +#else +#include +#endif +#include +*/ +import "C" + +import ( + "errors" + "fmt" + "io" + "math" + "runtime" + "unsafe" +) + +// SQLiteBlob implements the SQLite Blob I/O interface. +type SQLiteBlob struct { + conn *SQLiteConn + blob *C.sqlite3_blob + size int + offset int +} + +// Blob opens a blob. +// +// See https://www.sqlite.org/c3ref/blob_open.html for usage. +// +// Should only be used with conn.Raw. +func (conn *SQLiteConn) Blob(database, table, column string, rowid int64, flags int) (*SQLiteBlob, error) { + databaseptr := C.CString(database) + defer C.free(unsafe.Pointer(databaseptr)) + + tableptr := C.CString(table) + defer C.free(unsafe.Pointer(tableptr)) + + columnptr := C.CString(column) + defer C.free(unsafe.Pointer(columnptr)) + + var blob *C.sqlite3_blob + ret := C.sqlite3_blob_open(conn.db, databaseptr, tableptr, columnptr, C.longlong(rowid), C.int(flags), &blob) + + if ret != C.SQLITE_OK { + return nil, conn.lastError() + } + + size := int(C.sqlite3_blob_bytes(blob)) + bb := &SQLiteBlob{conn: conn, blob: blob, size: size, offset: 0} + + runtime.SetFinalizer(bb, (*SQLiteBlob).Close) + + return bb, nil +} + +// Read implements the io.Reader interface. +func (s *SQLiteBlob) Read(b []byte) (n int, err error) { + if s.offset >= s.size { + return 0, io.EOF + } + + if len(b) == 0 { + return 0, nil + } + + n = s.size - s.offset + if len(b) < n { + n = len(b) + } + + p := &b[0] + ret := C.sqlite3_blob_read(s.blob, unsafe.Pointer(p), C.int(n), C.int(s.offset)) + if ret != C.SQLITE_OK { + return 0, s.conn.lastError() + } + + s.offset += n + + return n, nil +} + +// Write implements the io.Writer interface. +func (s *SQLiteBlob) Write(b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + + if s.offset >= s.size { + return 0, fmt.Errorf("sqlite3.SQLiteBlob.Write: insufficient space in %d-byte blob", s.size) + } + + n = s.size - s.offset + if len(b) < n { + n = len(b) + } + + if n != len(b) { + return 0, fmt.Errorf("sqlite3.SQLiteBlob.Write: insufficient space in %d-byte blob", s.size) + } + + p := &b[0] + ret := C.sqlite3_blob_write(s.blob, unsafe.Pointer(p), C.int(n), C.int(s.offset)) + if ret != C.SQLITE_OK { + return 0, s.conn.lastError() + } + + s.offset += n + + return n, nil +} + +// Seek implements the io.Seeker interface. +func (s *SQLiteBlob) Seek(offset int64, whence int) (int64, error) { + if offset > math.MaxInt32 { + return 0, fmt.Errorf("sqlite3.SQLiteBlob.Seek: invalid offset %d", offset) + } + + var abs int64 + switch whence { + case io.SeekStart: + abs = offset + case io.SeekCurrent: + abs = int64(s.offset) + offset + case io.SeekEnd: + abs = int64(s.size) + offset + default: + return 0, fmt.Errorf("sqlite3.SQLiteBlob.Seek: invalid whence %d", whence) + } + + if abs < 0 { + return 0, errors.New("sqlite.SQLiteBlob.Seek: negative position") + } + + if abs > math.MaxInt32 || abs > int64(s.size) { + return 0, errors.New("sqlite3.SQLiteBlob.Seek: overflow position") + } + + s.offset = int(abs) + + return abs, nil +} + +// Size returns the size of the blob. +func (s *SQLiteBlob) Size() int { + return s.size +} + +// Close implements the io.Closer interface. +func (s *SQLiteBlob) Close() error { + ret := C.sqlite3_blob_close(s.blob) + + s.blob = nil + runtime.SetFinalizer(s, nil) + + if ret != C.SQLITE_OK { + return s.conn.lastError() + } + + return nil +} diff --git a/blob_io_test.go b/blob_io_test.go new file mode 100644 index 00000000..01941ca1 --- /dev/null +++ b/blob_io_test.go @@ -0,0 +1,252 @@ +// Copyright (C) 2022 Yasuhiro Matsumoto . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build cgo +// +build cgo + +package sqlite3 + +import ( + "bytes" + "context" + "database/sql" + "io" + "testing" +) + +// Verify interface implementations +var _ io.Reader = &SQLiteBlob{} +var _ io.Writer = &SQLiteBlob{} +var _ io.Seeker = &SQLiteBlob{} +var _ io.Closer = &SQLiteBlob{} + +type driverConnCallback func(*testing.T, *SQLiteConn) + +func blobTestData(t *testing.T, dbname string, rowid int64, blob []byte, c driverConnCallback) { + + // This test uses :memory: for compatibility with SQLite versions < 3.37.0. + // Using memdb vfs is the right way to do this for more recent versions. + + // db, err := sql.Open("sqlite3", "file:/"+dbname+"?vfs=memdb") + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + db.SetMaxOpenConns(1) + + // Test data + query := ` + CREATE TABLE data ( + value BLOB + ); + + INSERT INTO data (_rowid_, value) + VALUES (:rowid, :value); + ` + + _, err = db.Exec(query, sql.Named("rowid", rowid), sql.Named("value", blob)) + if err != nil { + t.Fatal(err) + } + + // Get raw connection + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + var driverConn *SQLiteConn + err = conn.Raw(func(conn interface{}) error { + driverConn = conn.(*SQLiteConn) + c(t, driverConn) + return nil + }) + if err != nil { + t.Fatal(err) + } + defer driverConn.Close() +} + +func TestBlobRead(t *testing.T) { + rowid := int64(6581) + expected := []byte("I ❤️ SQLite in \x00\x01\x02…") + + blobTestData(t, "testblobread", rowid, expected, func(t *testing.T, driverConn *SQLiteConn) { + + // Open blob + blob, err := driverConn.Blob("main", "data", "value", rowid, 0) + if err != nil { + t.Error("failed", err) + } + defer blob.Close() + + // Read blob incrementally + middle := len(expected) / 2 + first := expected[:middle] + second := expected[middle:] + + // Read part Ⅰ + b1 := make([]byte, len(first)) + n1, err := blob.Read(b1) + + if err != nil || n1 != len(b1) { + t.Errorf("Failed to read %d bytes", n1) + } + + if bytes.Compare(first, b1) != 0 { + t.Error("Expected\n", first, "got\n", b1) + } + + // Read part Ⅱ + b2 := make([]byte, len(second)) + n2, err := blob.Read(b2) + + if err != nil || n2 != len(b2) { + t.Errorf("Failed to read %d bytes", n2) + } + + if bytes.Compare(second, b2) != 0 { + t.Error("Expected\n", second, "got\n", b2) + } + + // EOF + b3 := make([]byte, 10) + n3, err := blob.Read(b3) + + if err != io.EOF || n3 != 0 { + t.Error("Expected EOF", err) + } + }) +} + +func TestBlobWrite(t *testing.T) { + rowid := int64(8580) + expected := []byte{ + // Random data from /dev/urandom + 0xe5, 0x48, 0x94, 0xad, 0xa6, 0x7c, 0x81, 0xa2, 0x70, 0x07, 0x79, 0x60, + 0x33, 0xbc, 0x64, 0x33, 0x8f, 0x48, 0x43, 0xa6, 0x33, 0x5c, 0x08, 0x32, + } + + // Allocate a zero blob + data := make([]byte, len(expected)) + blobTestData(t, "testblobwrite", rowid, data, func(t *testing.T, driverConn *SQLiteConn) { + + // Open blob for read/write + blob, err := driverConn.Blob("main", "data", "value", rowid, 1) + if err != nil { + t.Error("failed", err) + } + defer blob.Close() + + // Write blob incrementally + middle := len(expected) / 2 + first := expected[:middle] + second := expected[middle:] + + // Write part Ⅰ + n1, err := blob.Write(first) + + if err != nil || n1 != len(first) { + t.Errorf("Failed to write %d bytes", n1) + } + + // Write part Ⅱ + n2, err := blob.Write(second) + + if err != nil || n2 != len(second) { + t.Errorf("Failed to write %d bytes", n2) + } + + // Insufficient space + b3 := make([]byte, 10) + n3, err := blob.Write(b3) + + if err.Error() != "sqlite3.SQLiteBlob.Write: insufficient space in 24-byte blob" || n3 != 0 { + t.Error("Expected insufficient space error", err, n3) + } + + // Verify written data + _, err = blob.Seek(0, io.SeekStart) + if err != nil { + t.Fatal("Failed to seek:", err) + } + + b4 := make([]byte, len(expected)) + n4, err := blob.Read(b4) + + if err != nil || n4 != len(b4) { + t.Errorf("Failed to read %d bytes", n4) + } + + if bytes.Compare(expected, b4) != 0 { + t.Error("Expected\n", expected, "got\n", b4) + } + + }) +} + +func TestBlobSeek(t *testing.T) { + rowid := int64(6510) + data := make([]byte, 1000) + + blobTestData(t, "testblobseek", rowid, data, func(t *testing.T, driverConn *SQLiteConn) { + + // Open blob + blob, err := driverConn.Blob("main", "data", "value", rowid, 0) + if err != nil { + t.Error("failed", err) + } + defer blob.Close() + + // Test data + begin := int64(0) + middle := int64(len(data) / 2) + end := int64(len(data) - 1) + eof := int64(len(data)) + + tests := []struct { + offset int64 + whence int + expected int64 + }{ + {offset: begin, whence: io.SeekStart, expected: begin}, + {offset: middle, whence: io.SeekStart, expected: middle}, + {offset: end, whence: io.SeekStart, expected: end}, + {offset: eof, whence: io.SeekStart, expected: eof}, + + {offset: -1, whence: io.SeekCurrent, expected: middle - 1}, + {offset: 0, whence: io.SeekCurrent, expected: middle}, + {offset: 1, whence: io.SeekCurrent, expected: middle + 1}, + {offset: -middle, whence: io.SeekCurrent, expected: begin}, + + {offset: -2, whence: io.SeekEnd, expected: end - 1}, + {offset: -1, whence: io.SeekEnd, expected: end}, + {offset: 0, whence: io.SeekEnd, expected: eof}, + {offset: -eof, whence: io.SeekEnd, expected: begin}, + } + + for _, tc := range tests { + // Start in the middle + _, err := blob.Seek(middle, io.SeekStart) + if err != nil { + t.Fatal("Failed to seek:", err) + } + + // Test + got, err := blob.Seek(tc.offset, tc.whence) + if err != nil { + t.Fatal("Failed to seek:", err) + } + + if tc.expected != got { + t.Error("For", tc, "expected", tc.expected, "got", got) + } + } + + }) +}