Skip to content

Commit d9fcf60

Browse files
committed
Driver API.
1 parent ac6dd1a commit d9fcf60

File tree

6 files changed

+96
-27
lines changed

6 files changed

+96
-27
lines changed

conn.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package sqlite3
22

33
import (
44
"context"
5-
"database/sql/driver"
65
"errors"
76
"fmt"
87
"net/url"
@@ -240,6 +239,11 @@ func (c *Conn) Changes() int64 {
240239
//
241240
// https://www.sqlite.org/c3ref/interrupt.html
242241
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
242+
// Is it the same context?
243+
if ctx == c.interrupt {
244+
return ctx
245+
}
246+
243247
// Is a waiter running?
244248
if c.waiter != nil {
245249
c.waiter <- struct{}{} // Cancel the waiter.
@@ -331,15 +335,5 @@ func (c *Conn) error(rc uint64, sql ...string) error {
331335
// [online backup]: https://www.sqlite.org/backup.html
332336
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
333337
type DriverConn interface {
334-
driver.Conn
335-
driver.ConnBeginTx
336-
driver.ExecerContext
337-
driver.ConnPrepareContext
338-
339-
SetInterrupt(ctx context.Context) (old context.Context)
340-
341-
Savepoint() Savepoint
342-
Backup(srcDB, dstURI string) error
343-
Restore(dstDB, srcURI string) error
344-
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
338+
Raw() *Conn
345339
}

driver/driver.go

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,46 @@ import (
4040
"github.com/ncruces/go-sqlite3/internal/util"
4141
)
4242

43+
// This variable can be replaced with -ldflags:
44+
//
45+
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
46+
var driverName = "sqlite3"
47+
4348
func init() {
44-
sql.Register("sqlite3", sqlite{})
49+
if driverName != "" {
50+
sql.Register(driverName, sqlite{})
51+
}
52+
}
53+
54+
// Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
55+
//
56+
// The init function is called by the driver on new connections.
57+
// The conn can be used to execute queries, register functions, etc.
58+
// Any error return closes the conn and passes the error to database/sql.
59+
func Open(dataSourceName string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*sql.DB, error) {
60+
c, err := newConnector(dataSourceName, init)
61+
if err != nil {
62+
return nil, err
63+
}
64+
return sql.OpenDB(c), nil
4565
}
4666

4767
type sqlite struct{}
4868

4969
func (sqlite) Open(name string) (driver.Conn, error) {
50-
c, err := sqlite{}.OpenConnector(name)
70+
c, err := newConnector(name, nil)
5171
if err != nil {
5272
return nil, err
5373
}
5474
return c.Connect(context.Background())
5575
}
5676

5777
func (sqlite) OpenConnector(name string) (driver.Connector, error) {
58-
c := connector{name: name}
78+
return newConnector(name, nil)
79+
}
80+
81+
func newConnector(name string, init func(ctx context.Context, conn *sqlite3.Conn) error) (*connector, error) {
82+
c := connector{name: name, init: init}
5983
if strings.HasPrefix(name, "file:") {
6084
if _, after, ok := strings.Cut(name, "?"); ok {
6185
query, err := url.ParseQuery(after)
@@ -73,6 +97,7 @@ type connector struct {
7397
name string
7498
txlock string
7599
pragmas bool
100+
init func(ctx context.Context, conn *sqlite3.Conn) error
76101
}
77102

78103
func (n *connector) Driver() driver.Driver {
@@ -126,6 +151,12 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
126151
return nil, err
127152
}
128153
}
154+
if n.init != nil {
155+
err = n.init(ctx, c.Conn)
156+
if err != nil {
157+
return nil, err
158+
}
159+
}
129160
return &c, nil
130161
}
131162

@@ -140,12 +171,17 @@ type conn struct {
140171

141172
var (
142173
// Ensure these interfaces are implemented:
143-
_ driver.ExecerContext = &conn{}
144-
_ driver.ConnBeginTx = &conn{}
145-
_ driver.Validator = &conn{}
146-
_ sqlite3.DriverConn = &conn{}
174+
_ driver.ConnPrepareContext = &conn{}
175+
_ driver.ExecerContext = &conn{}
176+
_ driver.ConnBeginTx = &conn{}
177+
_ driver.Validator = &conn{}
178+
_ sqlite3.DriverConn = &conn{}
147179
)
148180

181+
func (c *conn) Raw() *sqlite3.Conn {
182+
return c.Conn
183+
}
184+
149185
func (c *conn) IsValid() bool {
150186
return c.reusable
151187
}
@@ -190,7 +226,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
190226

191227
func (c *conn) Commit() error {
192228
err := c.Conn.Exec(c.txCommit)
193-
if err != nil && !c.GetAutocommit() {
229+
if err != nil && !c.Conn.GetAutocommit() {
194230
c.Rollback()
195231
}
196232
return err

driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func ExampleDriverConn() {
4747
}
4848

4949
err = conn.Raw(func(driverConn any) error {
50-
conn := driverConn.(sqlite3.DriverConn)
50+
conn := driverConn.(sqlite3.DriverConn).Raw()
5151
savept := conn.Savepoint()
5252
defer savept.Release(&err)
5353

gormlite/sqlite.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package gormlite
33

44
import (
55
"context"
6-
"database/sql"
76
"strconv"
87

98
"gorm.io/gorm"
@@ -13,7 +12,7 @@ import (
1312
"gorm.io/gorm/migrator"
1413
"gorm.io/gorm/schema"
1514

16-
_ "github.com/ncruces/go-sqlite3/driver"
15+
"github.com/ncruces/go-sqlite3/driver"
1716
)
1817

1918
type Dialector struct {
@@ -33,7 +32,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
3332
if dialector.Conn != nil {
3433
db.ConnPool = dialector.Conn
3534
} else {
36-
conn, err := sql.Open("sqlite3", dialector.DSN)
35+
conn, err := driver.Open(dialector.DSN, nil)
3736
if err != nil {
3837
return err
3938
}

gormlite/sqlite_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
11
package gormlite
22

33
import (
4+
"context"
45
"fmt"
56
"testing"
67

78
"gorm.io/gorm"
89

10+
"github.com/ncruces/go-sqlite3"
11+
"github.com/ncruces/go-sqlite3/driver"
912
_ "github.com/ncruces/go-sqlite3/embed"
1013
)
1114

1215
func TestDialector(t *testing.T) {
1316
// This is the DSN of the in-memory SQLite database for these tests.
1417
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
1518

19+
// Custom connection with a custom function called "my_custom_function".
20+
conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error {
21+
return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC,
22+
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
23+
ctx.ResultText("my-result")
24+
})
25+
})
26+
if err != nil {
27+
t.Fatal(err)
28+
}
29+
1630
rows := []struct {
1731
description string
1832
dialector *Dialector
@@ -29,6 +43,33 @@ func TestDialector(t *testing.T) {
2943
query: "SELECT 1",
3044
querySuccess: true,
3145
},
46+
{
47+
description: "Custom function",
48+
dialector: &Dialector{
49+
DSN: InMemoryDSN,
50+
},
51+
openSuccess: true,
52+
query: "SELECT my_custom_function()",
53+
querySuccess: false,
54+
},
55+
{
56+
description: "Custom connection",
57+
dialector: &Dialector{
58+
Conn: conn,
59+
},
60+
openSuccess: true,
61+
query: "SELECT 1",
62+
querySuccess: true,
63+
},
64+
{
65+
description: "Custom connection, custom function",
66+
dialector: &Dialector{
67+
Conn: conn,
68+
},
69+
openSuccess: true,
70+
query: "SELECT my_custom_function()",
71+
querySuccess: true,
72+
},
3273
}
3374
for rowIndex, row := range rows {
3475
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {

tests/driver_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ package tests
22

33
import (
44
"context"
5-
"database/sql"
65
"testing"
76

8-
_ "github.com/ncruces/go-sqlite3/driver"
7+
"github.com/ncruces/go-sqlite3/driver"
98
_ "github.com/ncruces/go-sqlite3/embed"
109
)
1110

@@ -15,7 +14,7 @@ func TestDriver(t *testing.T) {
1514
ctx, cancel := context.WithCancel(context.Background())
1615
defer cancel()
1716

18-
db, err := sql.Open("sqlite3", ":memory:")
17+
db, err := driver.Open(":memory:", nil)
1918
if err != nil {
2019
t.Fatal(err)
2120
}

0 commit comments

Comments
 (0)