Skip to content

Commit 79adc56

Browse files
authored
SQLite: allow enabling foreign keys in GetConnectionString (dapr#3253)
Signed-off-by: ItalyPaleAle <[email protected]>
1 parent ca00355 commit 79adc56

File tree

4 files changed

+35
-19
lines changed

4 files changed

+35
-19
lines changed

common/authentication/sqlite/metadata.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,14 @@ func (m SqliteAuthMetadata) IsInMemoryDB() bool {
7070
return strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:")
7171
}
7272

73+
// GetConnectionStringOpts contains options for GetConnectionString
74+
type GetConnectionStringOpts struct {
75+
// Enabled foreign keys
76+
EnableForeignKeys bool
77+
}
78+
7379
// GetConnectionString returns the parsed connection string.
74-
func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, error) {
80+
func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger, opts GetConnectionStringOpts) (string, error) {
7581
// Check if we're using the in-memory database
7682
isMemoryDB := m.IsInMemoryDB()
7783

@@ -126,16 +132,20 @@ func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, err
126132

127133
// Add pragma values
128134
if len(qs["_pragma"]) == 0 {
129-
qs["_pragma"] = make([]string, 0, 2)
135+
qs["_pragma"] = make([]string, 0, 3)
130136
} else {
131137
for _, p := range qs["_pragma"] {
132138
p = strings.ToLower(p)
133-
if strings.HasPrefix(p, "busy_timeout") {
139+
switch {
140+
case strings.HasPrefix(p, "busy_timeout"):
134141
log.Error("Cannot set `_pragma=busy_timeout` option in the connection string; please use the `busyTimeout` metadata property instead")
135142
return "", errors.New("found forbidden option '_pragma=busy_timeout' in the connection string")
136-
} else if strings.HasPrefix(p, "journal_mode") {
143+
case strings.HasPrefix(p, "journal_mode"):
137144
log.Error("Cannot set `_pragma=journal_mode` option in the connection string; please use the `disableWAL` metadata property instead")
138145
return "", errors.New("found forbidden option '_pragma=journal_mode' in the connection string")
146+
case strings.HasPrefix(p, "foreign_keys"):
147+
log.Error("Cannot set `_pragma=foreign_keys` option in the connection string")
148+
return "", errors.New("found forbidden option '_pragma=foreign_keys' in the connection string")
139149
}
140150
}
141151
}
@@ -152,6 +162,9 @@ func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, err
152162
// Enable WAL
153163
qs["_pragma"] = append(qs["_pragma"], "journal_mode(WAL)")
154164
}
165+
if opts.EnableForeignKeys {
166+
qs["_pragma"] = append(qs["_pragma"], "foreign_keys(1)")
167+
}
155168

156169
// Build the final connection string
157170
connString := m.ConnectionString

nameresolution/sqlite/sqlite.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/cenkalti/backoff/v4"
2727
"github.com/google/uuid"
2828

29+
"github.com/dapr/components-contrib/common/authentication/sqlite"
2930
commonsql "github.com/dapr/components-contrib/common/component/sql"
3031
"github.com/dapr/components-contrib/nameresolution"
3132
"github.com/dapr/kit/logger"
@@ -67,7 +68,7 @@ func (s *resolver) Init(ctx context.Context, md nameresolution.Metadata) error {
6768
return err
6869
}
6970

70-
connString, err := s.metadata.GetConnectionString(s.logger)
71+
connString, err := s.metadata.GetConnectionString(s.logger, sqlite.GetConnectionStringOpts{})
7172
if err != nil {
7273
// Already logged
7374
return err

state/sqlite/sqlite_dbaccess.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
// Blank import for the underlying SQLite Driver.
3030
_ "modernc.org/sqlite"
3131

32+
"github.com/dapr/components-contrib/common/authentication/sqlite"
3233
commonsql "github.com/dapr/components-contrib/common/component/sql"
3334
"github.com/dapr/components-contrib/state"
3435
stateutils "github.com/dapr/components-contrib/state/utils"
@@ -77,7 +78,7 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error {
7778
return err
7879
}
7980

80-
connString, err := a.metadata.GetConnectionString(a.logger)
81+
connString, err := a.metadata.GetConnectionString(a.logger, sqlite.GetConnectionStringOpts{})
8182
if err != nil {
8283
// Already logged
8384
return err

state/sqlite/sqlite_test.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/stretchr/testify/assert"
2525
"github.com/stretchr/testify/require"
2626

27+
"github.com/dapr/components-contrib/common/authentication/sqlite"
2728
"github.com/dapr/components-contrib/metadata"
2829
"github.com/dapr/components-contrib/state"
2930
"github.com/dapr/kit/logger"
@@ -48,7 +49,7 @@ func TestGetConnectionString(t *testing.T) {
4849
db.metadata.reset()
4950
db.metadata.ConnectionString = "file:test.db"
5051

51-
connString, err := db.metadata.GetConnectionString(log)
52+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
5253
require.NoError(t, err)
5354

5455
values := url.Values{
@@ -64,7 +65,7 @@ func TestGetConnectionString(t *testing.T) {
6465
db.metadata.reset()
6566
db.metadata.ConnectionString = "test.db"
6667

67-
connString, err := db.metadata.GetConnectionString(log)
68+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
6869
require.NoError(t, err)
6970

7071
values := url.Values{
@@ -82,7 +83,7 @@ func TestGetConnectionString(t *testing.T) {
8283
db.metadata.reset()
8384
db.metadata.ConnectionString = ":memory:"
8485

85-
connString, err := db.metadata.GetConnectionString(log)
86+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
8687
require.NoError(t, err)
8788

8889
values := url.Values{
@@ -103,7 +104,7 @@ func TestGetConnectionString(t *testing.T) {
103104
db.metadata.reset()
104105
db.metadata.ConnectionString = "file:test.db?_txlock=immediate"
105106

106-
connString, err := db.metadata.GetConnectionString(log)
107+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
107108
require.NoError(t, err)
108109

109110
values := url.Values{
@@ -121,7 +122,7 @@ func TestGetConnectionString(t *testing.T) {
121122
db.metadata.reset()
122123
db.metadata.ConnectionString = "file:test.db?_txlock=deferred"
123124

124-
connString, err := db.metadata.GetConnectionString(log)
125+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
125126
require.NoError(t, err)
126127

127128
values := url.Values{
@@ -141,7 +142,7 @@ func TestGetConnectionString(t *testing.T) {
141142
db.metadata.reset()
142143
db.metadata.ConnectionString = "file:test.db?_pragma=busy_timeout(50)"
143144

144-
_, err := db.metadata.GetConnectionString(log)
145+
_, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
145146
require.Error(t, err)
146147
require.ErrorContains(t, err, "found forbidden option '_pragma=busy_timeout' in the connection string")
147148
})
@@ -150,7 +151,7 @@ func TestGetConnectionString(t *testing.T) {
150151
db.metadata.reset()
151152
db.metadata.ConnectionString = "file:test.db?_pragma=journal_mode(WAL)"
152153

153-
_, err := db.metadata.GetConnectionString(log)
154+
_, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
154155
require.Error(t, err)
155156
require.ErrorContains(t, err, "found forbidden option '_pragma=journal_mode' in the connection string")
156157
})
@@ -162,7 +163,7 @@ func TestGetConnectionString(t *testing.T) {
162163
db.metadata.ConnectionString = "file:test.db"
163164
db.metadata.BusyTimeout = time.Second
164165

165-
connString, err := db.metadata.GetConnectionString(log)
166+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
166167
require.NoError(t, err)
167168

168169
values := url.Values{
@@ -179,7 +180,7 @@ func TestGetConnectionString(t *testing.T) {
179180
db.metadata.ConnectionString = "file:test.db"
180181
db.metadata.DisableWAL = false
181182

182-
connString, err := db.metadata.GetConnectionString(log)
183+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
183184
require.NoError(t, err)
184185

185186
values := url.Values{
@@ -195,7 +196,7 @@ func TestGetConnectionString(t *testing.T) {
195196
db.metadata.ConnectionString = "file:test.db"
196197
db.metadata.DisableWAL = true
197198

198-
connString, err := db.metadata.GetConnectionString(log)
199+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
199200
require.NoError(t, err)
200201

201202
values := url.Values{
@@ -210,7 +211,7 @@ func TestGetConnectionString(t *testing.T) {
210211
db.metadata.reset()
211212
db.metadata.ConnectionString = "file::memory:"
212213

213-
connString, err := db.metadata.GetConnectionString(log)
214+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
214215
require.NoError(t, err)
215216

216217
values := url.Values{
@@ -226,7 +227,7 @@ func TestGetConnectionString(t *testing.T) {
226227
db.metadata.reset()
227228
db.metadata.ConnectionString = "file:test.db?mode=ro"
228229

229-
connString, err := db.metadata.GetConnectionString(log)
230+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
230231
require.NoError(t, err)
231232

232233
values := url.Values{
@@ -242,7 +243,7 @@ func TestGetConnectionString(t *testing.T) {
242243
db.metadata.reset()
243244
db.metadata.ConnectionString = "file:test.db?immutable=1"
244245

245-
connString, err := db.metadata.GetConnectionString(log)
246+
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
246247
require.NoError(t, err)
247248

248249
values := url.Values{

0 commit comments

Comments
 (0)