Skip to content

Commit e7db4cf

Browse files
Move SQLite "auth" metadata to a separate package internal/authentication/sqlite (#3135)
Signed-off-by: ItalyPaleAle <[email protected]> Co-authored-by: Bernd Verst <[email protected]>
1 parent a874485 commit e7db4cf

File tree

9 files changed

+464
-212
lines changed

9 files changed

+464
-212
lines changed
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package sqlite
15+
16+
import (
17+
"errors"
18+
"fmt"
19+
"net/url"
20+
"strings"
21+
"time"
22+
23+
"github.com/dapr/kit/logger"
24+
)
25+
26+
const (
27+
DefaultTimeout = 20 * time.Second // Default timeout for database requests, in seconds
28+
DefaultBusyTimeout = 2 * time.Second
29+
)
30+
31+
// SqliteAuthMetadata contains the auth metadata for a SQLite component.
32+
type SqliteAuthMetadata struct {
33+
ConnectionString string `mapstructure:"connectionString" mapstructurealiases:"url"`
34+
Timeout time.Duration `mapstructure:"timeout" mapstructurealiases:"timeoutInSeconds"`
35+
BusyTimeout time.Duration `mapstructure:"busyTimeout"`
36+
DisableWAL bool `mapstructure:"disableWAL"` // Disable WAL journaling. You should not use WAL if the database is stored on a network filesystem (or data corruption may happen). This is ignored if the database is in-memory.
37+
}
38+
39+
// Reset the object
40+
func (m *SqliteAuthMetadata) Reset() {
41+
m.ConnectionString = ""
42+
m.Timeout = DefaultTimeout
43+
m.BusyTimeout = DefaultBusyTimeout
44+
m.DisableWAL = false
45+
}
46+
47+
func (m *SqliteAuthMetadata) Validate() error {
48+
// Validate and sanitize input
49+
if m.ConnectionString == "" {
50+
return errors.New("missing connection string")
51+
}
52+
if m.Timeout < time.Second {
53+
return errors.New("invalid value for 'timeout': must be greater than 1s")
54+
}
55+
56+
// Busy timeout
57+
// Truncate values to milliseconds. Values <= 0 do not set any timeout
58+
m.BusyTimeout = m.BusyTimeout.Truncate(time.Millisecond)
59+
60+
return nil
61+
}
62+
63+
func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, error) {
64+
// Check if we're using the in-memory database
65+
lc := strings.ToLower(m.ConnectionString)
66+
isMemoryDB := strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:")
67+
68+
// Get the "query string" from the connection string if present
69+
idx := strings.IndexRune(m.ConnectionString, '?')
70+
var qs url.Values
71+
if idx > 0 {
72+
qs, _ = url.ParseQuery(m.ConnectionString[(idx + 1):])
73+
}
74+
if len(qs) == 0 {
75+
qs = make(url.Values, 2)
76+
}
77+
78+
// If the database is in-memory, we must ensure that cache=shared is set
79+
if isMemoryDB {
80+
qs["cache"] = []string{"shared"}
81+
}
82+
83+
// Check if the database is read-only or immutable
84+
isReadOnly := false
85+
if len(qs["mode"]) > 0 {
86+
// Keep the first value only
87+
qs["mode"] = []string{
88+
qs["mode"][0],
89+
}
90+
if qs["mode"][0] == "ro" {
91+
isReadOnly = true
92+
}
93+
}
94+
if len(qs["immutable"]) > 0 {
95+
// Keep the first value only
96+
qs["immutable"] = []string{
97+
qs["immutable"][0],
98+
}
99+
if qs["immutable"][0] == "1" {
100+
isReadOnly = true
101+
}
102+
}
103+
104+
// We do not want to override a _txlock if set, but we'll show a warning if it's not "immediate"
105+
if len(qs["_txlock"]) > 0 {
106+
// Keep the first value only
107+
qs["_txlock"] = []string{
108+
strings.ToLower(qs["_txlock"][0]),
109+
}
110+
if qs["_txlock"][0] != "immediate" {
111+
log.Warn("Database connection is being created with a _txlock different from the recommended value 'immediate'")
112+
}
113+
} else {
114+
qs["_txlock"] = []string{"immediate"}
115+
}
116+
117+
// Add pragma values
118+
if len(qs["_pragma"]) == 0 {
119+
qs["_pragma"] = make([]string, 0, 2)
120+
} else {
121+
for _, p := range qs["_pragma"] {
122+
p = strings.ToLower(p)
123+
if strings.HasPrefix(p, "busy_timeout") {
124+
log.Error("Cannot set `_pragma=busy_timeout` option in the connection string; please use the `busyTimeout` metadata property instead")
125+
return "", errors.New("found forbidden option '_pragma=busy_timeout' in the connection string")
126+
} else if strings.HasPrefix(p, "journal_mode") {
127+
log.Error("Cannot set `_pragma=journal_mode` option in the connection string; please use the `disableWAL` metadata property instead")
128+
return "", errors.New("found forbidden option '_pragma=journal_mode' in the connection string")
129+
}
130+
}
131+
}
132+
if m.BusyTimeout > 0 {
133+
qs["_pragma"] = append(qs["_pragma"], fmt.Sprintf("busy_timeout(%d)", m.BusyTimeout.Milliseconds()))
134+
}
135+
if isMemoryDB {
136+
// For in-memory databases, set the journal to MEMORY, the only allowed option besides OFF (which would make transactions ineffective)
137+
qs["_pragma"] = append(qs["_pragma"], "journal_mode(MEMORY)")
138+
} else if m.DisableWAL || isReadOnly {
139+
// Set the journaling mode to "DELETE" (the default) if WAL is disabled or if the database is read-only
140+
qs["_pragma"] = append(qs["_pragma"], "journal_mode(DELETE)")
141+
} else {
142+
// Enable WAL
143+
qs["_pragma"] = append(qs["_pragma"], "journal_mode(WAL)")
144+
}
145+
146+
// Build the final connection string
147+
connString := m.ConnectionString
148+
if idx > 0 {
149+
connString = connString[:idx]
150+
}
151+
connString += "?" + qs.Encode()
152+
153+
// If the connection string doesn't begin with "file:", add the prefix
154+
if !strings.HasPrefix(lc, "file:") {
155+
log.Debug("prefix 'file:' added to the connection string")
156+
connString = "file:" + connString
157+
}
158+
159+
return connString, nil
160+
}
161+
162+
// Validates an identifier, such as table or DB name.
163+
func ValidIdentifier(v string) bool {
164+
if v == "" {
165+
return false
166+
}
167+
168+
// Loop through the string as byte slice as we only care about ASCII characters
169+
b := []byte(v)
170+
for i := 0; i < len(b); i++ {
171+
if (b[i] >= '0' && b[i] <= '9') ||
172+
(b[i] >= 'a' && b[i] <= 'z') ||
173+
(b[i] >= 'A' && b[i] <= 'Z') ||
174+
b[i] == '_' {
175+
continue
176+
}
177+
return false
178+
}
179+
return true
180+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package sqlite
15+
16+
import (
17+
"testing"
18+
"time"
19+
20+
"github.com/stretchr/testify/assert"
21+
"github.com/stretchr/testify/require"
22+
23+
"github.com/dapr/components-contrib/metadata"
24+
"github.com/dapr/components-contrib/state"
25+
)
26+
27+
func TestSqliteMetadata(t *testing.T) {
28+
stateMetadata := func(props map[string]string) state.Metadata {
29+
return state.Metadata{Base: metadata.Base{Properties: props}}
30+
}
31+
32+
t.Run("default options", func(t *testing.T) {
33+
md := &SqliteAuthMetadata{}
34+
md.Reset()
35+
36+
err := metadata.DecodeMetadata(stateMetadata(map[string]string{
37+
"connectionString": "file:data.db",
38+
}), &md)
39+
require.NoError(t, err)
40+
41+
err = md.Validate()
42+
43+
require.NoError(t, err)
44+
assert.Equal(t, "file:data.db", md.ConnectionString)
45+
assert.Equal(t, DefaultTimeout, md.Timeout)
46+
assert.Equal(t, DefaultBusyTimeout, md.BusyTimeout)
47+
assert.False(t, md.DisableWAL)
48+
})
49+
50+
t.Run("empty connection string", func(t *testing.T) {
51+
md := &SqliteAuthMetadata{}
52+
md.Reset()
53+
54+
err := metadata.DecodeMetadata(stateMetadata(map[string]string{}), &md)
55+
require.NoError(t, err)
56+
57+
err = md.Validate()
58+
59+
require.Error(t, err)
60+
assert.ErrorContains(t, err, "missing connection string")
61+
})
62+
63+
t.Run("invalid timeout", func(t *testing.T) {
64+
md := &SqliteAuthMetadata{}
65+
md.Reset()
66+
67+
err := metadata.DecodeMetadata(stateMetadata(map[string]string{
68+
"connectionString": "file:data.db",
69+
"timeout": "500ms",
70+
}), &md)
71+
require.NoError(t, err)
72+
73+
err = md.Validate()
74+
75+
require.Error(t, err)
76+
assert.ErrorContains(t, err, "timeout")
77+
})
78+
79+
t.Run("aliases", func(t *testing.T) {
80+
md := &SqliteAuthMetadata{}
81+
md.Reset()
82+
83+
err := metadata.DecodeMetadata(stateMetadata(map[string]string{
84+
"url": "file:data.db",
85+
"timeoutinseconds": "1200",
86+
}), &md)
87+
require.NoError(t, err)
88+
89+
err = md.Validate()
90+
91+
require.NoError(t, err)
92+
assert.Equal(t, "file:data.db", md.ConnectionString)
93+
assert.Equal(t, 20*time.Minute, md.Timeout)
94+
})
95+
}

metadata/utils.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func DecodeMetadata(input any, result any) error {
158158
}
159159

160160
// Handle aliases
161-
err = resolveAliases(inputMap, result)
161+
err = resolveAliases(inputMap, reflect.TypeOf(result))
162162
if err != nil {
163163
return fmt.Errorf("failed to resolve aliases: %w", err)
164164
}
@@ -183,7 +183,7 @@ func DecodeMetadata(input any, result any) error {
183183
return err
184184
}
185185

186-
func resolveAliases(md map[string]string, result any) error {
186+
func resolveAliases(md map[string]string, t reflect.Type) error {
187187
// Get the list of all keys in the map
188188
keys := make(map[string]string, len(md))
189189
for k := range md {
@@ -199,7 +199,6 @@ func resolveAliases(md map[string]string, result any) error {
199199
}
200200

201201
// Error if result is not pointer to struct, or pointer to pointer to struct
202-
t := reflect.TypeOf(result)
203202
if t.Kind() != reflect.Pointer {
204203
return fmt.Errorf("not a pointer: %s", t.Kind().String())
205204
}
@@ -211,7 +210,14 @@ func resolveAliases(md map[string]string, result any) error {
211210
return fmt.Errorf("not a struct: %s", t.Kind().String())
212211
}
213212

214-
// Iterate through all the properties of result to see if anyone has the "mapstructurealiases" property
213+
// Iterate through all the properties, possibly recursively
214+
resolveAliasesInType(md, keys, t)
215+
216+
return nil
217+
}
218+
219+
func resolveAliasesInType(md map[string]string, keys map[string]string, t reflect.Type) {
220+
// Iterate through all the properties of the type to see if anyone has the "mapstructurealiases" property
215221
for i := 0; i < t.NumField(); i++ {
216222
currentField := t.Field(i)
217223

@@ -221,6 +227,12 @@ func resolveAliases(md map[string]string, result any) error {
221227
continue
222228
}
223229

230+
// Check if this is an embedded struct
231+
if mapstructureTag == ",squash" {
232+
resolveAliasesInType(md, keys, currentField.Type)
233+
continue
234+
}
235+
224236
// If the current property has a value in the metadata, then we don't need to handle aliases
225237
_, ok := keys[strings.ToLower(mapstructureTag)]
226238
if ok {
@@ -246,8 +258,6 @@ func resolveAliases(md map[string]string, result any) error {
246258
break
247259
}
248260
}
249-
250-
return nil
251261
}
252262

253263
func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {

0 commit comments

Comments
 (0)