Skip to content

Commit 968d316

Browse files
authored
sqlreplay: report an error immediately when the encryption file key is not set (#751)
1 parent af46321 commit 968d316

File tree

11 files changed

+122
-79
lines changed

11 files changed

+122
-79
lines changed

lib/go.mod

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ require (
99
github.com/tiancaiamao/gp v0.0.0-20230126082955-4f9e4f1ed9b5
1010
go.uber.org/atomic v1.9.0
1111
go.uber.org/zap v1.23.0
12+
golang.org/x/term v0.29.0
1213
)
1314

1415
require (
@@ -17,7 +18,6 @@ require (
1718
github.com/pmezard/go-difflib v1.0.0 // indirect
1819
github.com/spf13/pflag v1.0.5 // indirect
1920
golang.org/x/sys v0.30.0 // indirect
20-
golang.org/x/term v0.29.0 // indirect
2121
gopkg.in/yaml.v3 v3.0.1 // indirect
2222
)
2323

@@ -26,7 +26,6 @@ require (
2626
github.com/kr/pretty v0.3.0 // indirect
2727
github.com/pkg/errors v0.9.1 // indirect
2828
go.uber.org/multierr v1.8.0 // indirect
29-
golang.org/x/crypto v0.33.0
3029
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
3130
gopkg.in/natefinch/lumberjack.v2 v2.0.0
3231
)

lib/go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8=
5050
go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
5151
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
5252
go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY=
53-
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
54-
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
5553
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
5654
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
5755
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=

pkg/sqlreplay/capture/capture.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type CaptureConfig struct {
6161
StartTime time.Time
6262
Duration time.Duration
6363
Compress bool
64+
encryptionKey []byte
6465
cmdLogger io.WriteCloser
6566
bufferCap int
6667
flushThreshold int
@@ -91,6 +92,11 @@ func (cfg *CaptureConfig) Validate() (storage.ExternalStorage, error) {
9192
} else if cfg.StartTime.Add(cfg.Duration).Before(now) {
9293
return storage, errors.New("start time should not be in the past")
9394
}
95+
key, err := store.LoadEncryptionKey(cfg.EncryptionMethod, cfg.KeyFile)
96+
if err != nil {
97+
return storage, errors.Wrapf(err, "failed to load encryption key")
98+
}
99+
cfg.encryptionKey = key
94100
if cfg.bufferCap == 0 {
95101
cfg.bufferCap = bufferCap
96102
}
@@ -244,10 +250,10 @@ func (c *capture) flushBuffer(bufCh <-chan *bytes.Buffer) {
244250
if cmdLogger == nil {
245251
var err error
246252
cmdLogger, err = store.NewWriter(c.lg.Named("writer"), c.storage, store.WriterCfg{
247-
Dir: c.cfg.Output,
248-
EncryptMethod: c.cfg.EncryptionMethod,
249-
KeyFile: c.cfg.KeyFile,
250-
Compress: c.cfg.Compress,
253+
Dir: c.cfg.Output,
254+
EncryptionMethod: c.cfg.EncryptionMethod,
255+
EncryptionKey: c.cfg.encryptionKey,
256+
Compress: c.cfg.Compress,
251257
})
252258
if err != nil {
253259
c.lg.Error("failed to create capture writer", zap.Error(err))

pkg/sqlreplay/capture/capture_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ func TestCaptureCfgError(t *testing.T) {
160160
Output: path,
161161
StartTime: now,
162162
},
163+
{
164+
Duration: 10 * time.Second,
165+
Output: dir,
166+
StartTime: now,
167+
EncryptionMethod: store.EncryptAes,
168+
KeyFile: "",
169+
},
163170
}
164171

165172
for i, cfg := range cfgs {

pkg/sqlreplay/replay/replay.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,10 @@ type ReplayConfig struct {
5656
KeyFile string
5757
// It's specified when executing with the statement `TRAFFIC REPLAY` so that all TiProxy instances
5858
// use the same start time and the time acts as the job ID.
59-
StartTime time.Time
60-
Speed float64
61-
ReadOnly bool
59+
StartTime time.Time
60+
Speed float64
61+
ReadOnly bool
62+
encryptionKey []byte
6263
// the following fields are for testing
6364
reader cmd.LineReader
6465
report report.Report
@@ -154,6 +155,12 @@ func (r *replay) Start(cfg ReplayConfig, backendTLSConfig *tls.Config, hsHandler
154155
r.replayStats.Reset()
155156
r.exceptionCh = make(chan conn.Exception, maxPendingExceptions)
156157
r.closeCh = make(chan uint64, maxPendingExceptions)
158+
key, err := store.LoadEncryptionKey(r.meta.EncryptMethod, cfg.KeyFile)
159+
if err != nil {
160+
return errors.Wrapf(err, "failed to load encryption key")
161+
}
162+
cfg.encryptionKey = key
163+
157164
hsHandler = NewHandshakeHandler(hsHandler)
158165
r.connCreator = cfg.connCreator
159166
if r.connCreator == nil {
@@ -190,9 +197,9 @@ func (r *replay) readCommands(ctx context.Context) {
190197
if reader == nil {
191198
var err error
192199
reader, err = store.NewReader(r.lg.Named("loader"), r.storage, store.ReaderCfg{
193-
Dir: r.cfg.Input,
194-
KeyFile: r.cfg.KeyFile,
195-
EncryptMethod: r.meta.EncryptMethod,
200+
Dir: r.cfg.Input,
201+
EncryptionKey: r.cfg.encryptionKey,
202+
EncryptionMethod: r.meta.EncryptMethod,
196203
})
197204
if err != nil {
198205
r.stop(err)

pkg/sqlreplay/replay/replay_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,22 @@ func TestPendingCmds(t *testing.T) {
273273
require.Contains(t, logs, `"total_wait_time": "150ms"`)
274274
require.Contains(t, logs, "too many pending commands")
275275
}
276+
277+
func TestStartError(t *testing.T) {
278+
replay := NewReplay(zap.NewNop(), id.NewIDManager())
279+
dir := t.TempDir()
280+
meta := store.NewMeta(10*time.Second, 20, 0, store.EncryptAes)
281+
storage, err := store.NewStorage(dir)
282+
require.NoError(t, err)
283+
defer storage.Close()
284+
require.NoError(t, meta.Write(storage))
285+
now := time.Now()
286+
287+
cfg := ReplayConfig{
288+
Input: dir,
289+
Username: "u1",
290+
StartTime: now,
291+
}
292+
err = replay.Start(cfg, nil, nil, &backend.BCConfig{})
293+
require.ErrorContains(t, err, "encryption")
294+
}

pkg/sqlreplay/store/encrypt.go

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,15 @@ type aesCTRWriter struct {
2727
stream cipher.Stream
2828
}
2929

30-
func newWriterWithEncryptOpts(writer io.WriteCloser, encryptMethod string, keyFile string) (io.WriteCloser, error) {
31-
switch strings.ToLower(encryptMethod) {
30+
func newWriterWithEncryptOpts(writer io.WriteCloser, encryptionMethod string, encrytionKey []byte) (io.WriteCloser, error) {
31+
switch strings.ToLower(encryptionMethod) {
3232
case "", EncryptPlain:
3333
return writer, nil
3434
case EncryptAes:
35+
return newAESCTRWriter(writer, encrytionKey)
3536
default:
36-
return nil, fmt.Errorf("unsupported encrypt method: %s", encryptMethod)
37+
return nil, fmt.Errorf("unsupported encrypt method: %s", encryptionMethod)
3738
}
38-
39-
key, err := readAesKey(keyFile)
40-
if err != nil {
41-
return nil, err
42-
}
43-
return newAESCTRWriter(writer, key)
4439
}
4540

4641
func newAESCTRWriter(writer io.WriteCloser, key []byte) (*aesCTRWriter, error) {
@@ -72,20 +67,15 @@ type aesCTRReader struct {
7267
stream cipher.Stream
7368
}
7469

75-
func newReaderWithEncryptOpts(reader io.Reader, encryptMethod string, keyFile string) (io.Reader, error) {
76-
switch strings.ToLower(encryptMethod) {
70+
func newReaderWithEncryptOpts(reader io.Reader, encryptionMethod string, encryptionKey []byte) (io.Reader, error) {
71+
switch strings.ToLower(encryptionMethod) {
7772
case "", EncryptPlain:
7873
return reader, nil
7974
case EncryptAes:
75+
return newAESCTRReader(reader, encryptionKey)
8076
default:
81-
return nil, fmt.Errorf("unsupported encrypt method: %s", encryptMethod)
82-
}
83-
84-
key, err := readAesKey(keyFile)
85-
if err != nil {
86-
return nil, err
77+
return nil, fmt.Errorf("unsupported encrypt method: %s", encryptionMethod)
8778
}
88-
return newAESCTRReader(reader, key)
8979
}
9080

9181
func newAESCTRReader(reader io.Reader, key []byte) (*aesCTRReader, error) {
@@ -118,16 +108,28 @@ func (ctr *aesCTRReader) Read(data []byte) (int, error) {
118108
return n, errors.WithStack(err)
119109
}
120110

111+
func LoadEncryptionKey(encryptionMethod, keyFile string) ([]byte, error) {
112+
switch strings.ToLower(encryptionMethod) {
113+
case "", EncryptPlain:
114+
return nil, nil
115+
case EncryptAes:
116+
return readAesKey(keyFile)
117+
default:
118+
return nil, fmt.Errorf("unsupported encrypt method: %s", encryptionMethod)
119+
}
120+
}
121+
121122
func readAesKey(filename string) ([]byte, error) {
122123
if len(filename) == 0 {
123-
return nil, errors.New("encryption key file name is not set")
124+
return nil, errors.New("security.encryption-key-file is not set")
124125
}
125126
key, err := os.ReadFile(filename)
126127
if err != nil {
127128
return nil, errors.WithStack(err)
128129
}
129-
if len(key) != 32 {
130+
if len(key) < 32 {
130131
return nil, errors.Errorf("invalid aes-256 key length: %d, expecting 32", len(key))
131132
}
132-
return key, nil
133+
// in case it's ended with a new line
134+
return key[:32], nil
133135
}

pkg/sqlreplay/store/encrypt_test.go

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,21 @@ func TestAes256(t *testing.T) {
5050
func TestEncryptOpts(t *testing.T) {
5151
dir := t.TempDir()
5252
path := filepath.Join(dir, "test")
53-
keyFile := filepath.Join(dir, "key")
54-
require.NoError(t, os.WriteFile(keyFile, genAesKey(), 0600))
53+
aesKey := genAesKey()
5554

5655
tests := []struct {
57-
method string
58-
keyFile string
56+
method string
57+
key []byte
5958
}{
60-
{EncryptPlain, ""},
61-
{"", ""},
62-
{EncryptAes, keyFile},
59+
{EncryptPlain, nil},
60+
{"", nil},
61+
{EncryptAes, aesKey},
6362
}
6463
for i, test := range tests {
6564
// write
6665
file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
6766
require.NoError(t, err)
68-
writer, err := newWriterWithEncryptOpts(file, test.method, test.keyFile)
67+
writer, err := newWriterWithEncryptOpts(file, test.method, test.key)
6968
require.NoError(t, err, "case %d", i)
7069
n, err := writer.Write([]byte("test"))
7170
require.NoError(t, err)
@@ -75,7 +74,7 @@ func TestEncryptOpts(t *testing.T) {
7574
// read
7675
file, err = os.OpenFile(path, os.O_RDONLY, 0600)
7776
require.NoError(t, err)
78-
reader, err := newReaderWithEncryptOpts(file, test.method, test.keyFile)
77+
reader, err := newReaderWithEncryptOpts(file, test.method, test.key)
7978
require.NoError(t, err)
8079
data := make([]byte, 100)
8180
n, err = io.ReadFull(reader, data)
@@ -85,32 +84,40 @@ func TestEncryptOpts(t *testing.T) {
8584
}
8685
}
8786

88-
func TestAes256Error(t *testing.T) {
87+
func TestLoadKey(t *testing.T) {
8988
dir := t.TempDir()
90-
path := filepath.Join(dir, "test")
91-
file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
92-
require.NoError(t, err)
93-
defer file.Close()
89+
key := genAesKey()
90+
9491
keyFile := filepath.Join(dir, "key")
95-
require.NoError(t, os.WriteFile(keyFile, genAesKey(), 0600))
92+
require.NoError(t, os.WriteFile(keyFile, key, 0600))
9693
invalidKeyFile := filepath.Join(dir, "invalid")
9794
require.NoError(t, os.WriteFile(invalidKeyFile, []byte("invalid"), 0600))
9895
noKeyFile := filepath.Join(dir, "nonexist")
96+
longKeyFile := filepath.Join(dir, "valid")
97+
longKey := make([]byte, 33)
98+
copy(longKey, key)
99+
longKey[32] = '\n'
100+
require.NoError(t, os.WriteFile(longKeyFile, longKey, 0600))
99101

100102
tests := []struct {
101103
method string
102104
keyFile string
105+
err bool
103106
}{
104-
{"unknown", keyFile},
105-
{EncryptAes, ""},
106-
{EncryptAes, noKeyFile},
107-
{EncryptAes, invalidKeyFile},
107+
{"unknown", keyFile, true},
108+
{EncryptAes, "", true},
109+
{EncryptAes, noKeyFile, true},
110+
{EncryptAes, invalidKeyFile, true},
111+
{EncryptAes, keyFile, false},
112+
{EncryptAes, longKeyFile, false},
108113
}
109114
for i, test := range tests {
110-
_, err = newWriterWithEncryptOpts(file, test.method, test.keyFile)
111-
require.Error(t, err, "case %d", i)
112-
_, err = newReaderWithEncryptOpts(file, test.method, test.keyFile)
113-
require.Error(t, err, "case %d", i)
115+
actual, err := LoadEncryptionKey(test.method, test.keyFile)
116+
if test.err {
117+
require.Error(t, err, "case %d", i)
118+
} else {
119+
require.Equal(t, key, actual, "case %d", i)
120+
}
114121
}
115122
}
116123

pkg/sqlreplay/store/line.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ import (
1414
)
1515

1616
type WriterCfg struct {
17-
Dir string
18-
EncryptMethod string
19-
KeyFile string
20-
FileSize int
21-
Compress bool
17+
Dir string
18+
EncryptionMethod string
19+
EncryptionKey []byte
20+
FileSize int
21+
Compress bool
2222
}
2323

2424
// NewWriter just wraps the rotate writer. It doesn't use a buffer because Capture writes data in a big batch.
@@ -28,9 +28,9 @@ func NewWriter(lg *zap.Logger, externalStorage storage.ExternalStorage, cfg Writ
2828
}
2929

3030
type ReaderCfg struct {
31-
Dir string
32-
EncryptMethod string
33-
KeyFile string
31+
Dir string
32+
EncryptionMethod string
33+
EncryptionKey []byte
3434
}
3535

3636
var _ cmd.LineReader = (*loader)(nil)

pkg/sqlreplay/store/rotate.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func (w *rotateWriter) createFile() error {
7373
if w.cfg.Compress {
7474
w.writer = newCompressWriter(w.lg, w.writer)
7575
}
76-
if w.writer, err = newWriterWithEncryptOpts(w.writer, w.cfg.EncryptMethod, w.cfg.KeyFile); err != nil {
76+
if w.writer, err = newWriterWithEncryptOpts(w.writer, w.cfg.EncryptionMethod, w.cfg.EncryptionKey); err != nil {
7777
return err
7878
}
7979
return nil
@@ -207,7 +207,7 @@ func (r *rotateReader) nextReader() error {
207207
return err
208208
}
209209
}
210-
r.reader, err = newReaderWithEncryptOpts(r.reader, r.cfg.EncryptMethod, r.cfg.KeyFile)
210+
r.reader, err = newReaderWithEncryptOpts(r.reader, r.cfg.EncryptionMethod, r.cfg.EncryptionKey)
211211
if err != nil {
212212
return err
213213
}

0 commit comments

Comments
 (0)