Skip to content

Commit 811e6e6

Browse files
committed
Adiantum pragmas.
1 parent 3c21784 commit 811e6e6

File tree

3 files changed

+69
-36
lines changed

3 files changed

+69
-36
lines changed

tests/parallel/parallel_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func Test_adiantum(t *testing.T) {
7575
name := "file:" +
7676
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db")) +
7777
"?vfs=adiantum" +
78-
"&hexkey=e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
78+
"&_pragma=hexkey(e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855)"
7979
testParallel(t, name, iter)
8080
testIntegrity(t, name)
8181
}

vfs/adiantum/hbsh.go

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,33 @@ func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag,
2121
}
2222

2323
func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
24-
var hbsh *hbsh.HBSH
25-
26-
// Encrypt everything except super journals.
27-
if flags&vfs.OPEN_SUPER_JOURNAL == 0 {
28-
if f, ok := name.DatabaseFile().(*hbshFile); ok {
29-
hbsh = f.hbsh
30-
} else {
31-
var key []byte
32-
if params := name.URIParameters(); name == nil {
33-
key = h.hbsh.KDF("") // Temporary files get a random key.
34-
} else if t, ok := params["key"]; ok {
35-
key = []byte(t[0])
36-
} else if t, ok := params["hexkey"]; ok {
37-
key, _ = hex.DecodeString(t[0])
38-
} else if t, ok := params["textkey"]; ok {
39-
key = h.hbsh.KDF(t[0])
40-
}
41-
if hbsh = h.hbsh.HBSH(key); hbsh == nil {
42-
// Can't open without a valid key.
43-
return nil, flags, sqlite3.CANTOPEN
44-
}
45-
}
46-
}
47-
4824
if h, ok := h.VFS.(vfs.VFSFilename); ok {
4925
file, flags, err = h.OpenFilename(name, flags)
5026
} else {
5127
file, flags, err = h.Open(name.String(), flags)
5228
}
53-
if err != nil || hbsh == nil || flags&vfs.OPEN_MEMORY != 0 {
54-
// Error, or no encryption (super journals, memory files).
29+
// Encrypt everything except super journals and memory files.
30+
if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 {
5531
return file, flags, err
5632
}
57-
return &hbshFile{File: file, hbsh: hbsh}, flags, err
33+
34+
var hbsh *hbsh.HBSH
35+
if f, ok := name.DatabaseFile().(*hbshFile); ok {
36+
hbsh = f.hbsh
37+
} else {
38+
var key []byte
39+
if params := name.URIParameters(); name == nil {
40+
key = h.hbsh.KDF("") // Temporary files get a random key.
41+
} else if t, ok := params["key"]; ok {
42+
key = []byte(t[0])
43+
} else if t, ok := params["hexkey"]; ok {
44+
key, _ = hex.DecodeString(t[0])
45+
} else if t, ok := params["textkey"]; ok {
46+
key = h.hbsh.KDF(t[0])
47+
}
48+
hbsh = h.hbsh.HBSH(key)
49+
}
50+
return &hbshFile{File: file, hbsh: hbsh, reset: h.hbsh}, flags, err
5851
}
5952

6053
const (
@@ -65,11 +58,43 @@ const (
6558
type hbshFile struct {
6659
vfs.File
6760
hbsh *hbsh.HBSH
61+
reset HBSHCreator
6862
block [blockSize]byte
6963
tweak [tweakSize]byte
7064
}
7165

66+
func (h *hbshFile) Pragma(name string, value string) (string, error) {
67+
var key []byte
68+
switch name {
69+
case "key":
70+
key = []byte(value)
71+
case "hexkey":
72+
key, _ = hex.DecodeString(value)
73+
case "textkey":
74+
key = h.reset.KDF(value)
75+
default:
76+
if f, ok := h.File.(vfs.FilePragma); ok {
77+
return f.Pragma(name, value)
78+
}
79+
return "", sqlite3.NOTFOUND
80+
}
81+
82+
if h.hbsh = h.reset.HBSH(key); h.hbsh != nil {
83+
return "ok", nil
84+
}
85+
return "", sqlite3.CANTOPEN
86+
}
87+
7288
func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) {
89+
if h.hbsh == nil {
90+
// If it's trying to read the header, pretend the file is empty,
91+
// so the key can be specified later.
92+
if off == 0 && len(p) == 100 {
93+
return 0, io.EOF
94+
}
95+
return 0, sqlite3.CANTOPEN
96+
}
97+
7398
min := (off) &^ (blockSize - 1) // round down
7499
max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up
75100

@@ -96,6 +121,10 @@ func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) {
96121
}
97122

98123
func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) {
124+
if h.hbsh == nil {
125+
return 0, sqlite3.READONLY
126+
}
127+
99128
min := (off) &^ (blockSize - 1) // round down
100129
max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up
101130

vfs/vfs.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -347,12 +347,17 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
347347

348348
case _FCNTL_PRAGMA:
349349
if file, ok := file.(FilePragma); ok {
350-
name := util.ReadUint32(mod, pArg+1*ptrlen)
351-
value := util.ReadUint32(mod, pArg+2*ptrlen)
352-
out, err := file.Pragma(
353-
util.ReadString(mod, name, _MAX_SQL_LENGTH),
354-
util.ReadString(mod, value, _MAX_SQL_LENGTH))
355-
if err != nil {
350+
ptr := util.ReadUint32(mod, pArg+1*ptrlen)
351+
name := util.ReadString(mod, ptr, _MAX_SQL_LENGTH)
352+
var value string
353+
if ptr := util.ReadUint32(mod, pArg+2*ptrlen); ptr != 0 {
354+
value = util.ReadString(mod, ptr, _MAX_SQL_LENGTH)
355+
}
356+
357+
out, err := file.Pragma(name, value)
358+
359+
ret := vfsErrorCode(err, _ERROR)
360+
if ret == _ERROR {
356361
out = err.Error()
357362
}
358363
if out != "" {
@@ -363,9 +368,8 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
363368
}
364369
util.WriteUint32(mod, pArg, uint32(stack[0]))
365370
util.WriteString(mod, uint32(stack[0]), out)
366-
return _ERROR
367371
}
368-
return vfsErrorCode(err, _ERROR)
372+
return ret
369373
}
370374
}
371375

0 commit comments

Comments
 (0)