Skip to content

Commit 451e530

Browse files
authored
sqlreplay, api: support encryption method options to traffic capture (#718)
1 parent 56d17de commit 451e530

File tree

17 files changed

+876
-531
lines changed

17 files changed

+876
-531
lines changed

lib/cli/traffic.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ func GetTrafficCaptureCmd(ctx *Context) *cobra.Command {
2929
}
3030
output := captureCmd.PersistentFlags().String("output", "", "output directory for traffic files")
3131
duration := captureCmd.PersistentFlags().String("duration", "", "the duration of traffic capture")
32+
encrypt := captureCmd.PersistentFlags().String("encrypt-method", "", "the encryption method used for encrypting traffic files")
3233
captureCmd.RunE = func(cmd *cobra.Command, args []string) error {
33-
reader := GetFormReader(map[string]string{"output": *output, "duration": *duration})
34+
reader := GetFormReader(map[string]string{"output": *output, "duration": *duration, "encrypt-method": *encrypt})
3435
resp, err := doRequest(cmd.Context(), ctx, http.MethodPost, "/api/traffic/capture", reader)
3536
if err != nil {
3637
return err

lib/config/proxy.go

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -93,33 +93,6 @@ type LogFile struct {
9393
MaxBackups int `yaml:"max-backups,omitempty" toml:"max-backups,omitempty" json:"max-backups,omitempty"`
9494
}
9595

96-
type TLSConfig struct {
97-
Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"`
98-
Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"`
99-
CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"`
100-
MinTLSVersion string `yaml:"min-tls-version,omitempty" toml:"min-tls-version,omitempty" json:"min-tls-version,omitempty"`
101-
AutoCerts bool `yaml:"auto-certs,omitempty" toml:"auto-certs,omitempty" json:"auto-certs,omitempty"`
102-
RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"`
103-
AutoExpireDuration string `yaml:"autocert-expire-duration,omitempty" toml:"autocert-expire-duration,omitempty" json:"autocert-expire-duration,omitempty"`
104-
SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"`
105-
}
106-
107-
func (c TLSConfig) HasCert() bool {
108-
return !(c.Cert == "" && c.Key == "")
109-
}
110-
111-
func (c TLSConfig) HasCA() bool {
112-
return c.CA != ""
113-
}
114-
115-
type Security struct {
116-
ServerSQLTLS TLSConfig `yaml:"server-tls,omitempty" toml:"server-tls,omitempty" json:"server-tls,omitempty"`
117-
ServerHTTPTLS TLSConfig `yaml:"server-http-tls,omitempty" toml:"server-http-tls,omitempty" json:"server-http-tls,omitempty"`
118-
ClusterTLS TLSConfig `yaml:"cluster-tls,omitempty" toml:"cluster-tls,omitempty" json:"cluster-tls,omitempty"`
119-
SQLTLS TLSConfig `yaml:"sql-tls,omitempty" toml:"sql-tls,omitempty" json:"sql-tls,omitempty"`
120-
RequireBackendTLS bool `yaml:"require-backend-tls,omitempty" toml:"require-backend-tls,omitempty" json:"require-backend-tls,omitempty"`
121-
}
122-
12396
type HA struct {
12497
VirtualIP string `yaml:"virtual-ip,omitempty" toml:"virtual-ip,omitempty" json:"virtual-ip,omitempty"`
12598
Interface string `yaml:"interface,omitempty" toml:"interface,omitempty" json:"interface,omitempty"`

lib/config/security.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package config
5+
6+
type TLSConfig struct {
7+
Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"`
8+
Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"`
9+
CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"`
10+
MinTLSVersion string `yaml:"min-tls-version,omitempty" toml:"min-tls-version,omitempty" json:"min-tls-version,omitempty"`
11+
AutoCerts bool `yaml:"auto-certs,omitempty" toml:"auto-certs,omitempty" json:"auto-certs,omitempty"`
12+
RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"`
13+
AutoExpireDuration string `yaml:"autocert-expire-duration,omitempty" toml:"autocert-expire-duration,omitempty" json:"autocert-expire-duration,omitempty"`
14+
SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"`
15+
}
16+
17+
func (c TLSConfig) HasCert() bool {
18+
return !(c.Cert == "" && c.Key == "")
19+
}
20+
21+
func (c TLSConfig) HasCA() bool {
22+
return c.CA != ""
23+
}
24+
25+
type Security struct {
26+
ServerSQLTLS TLSConfig `yaml:"server-tls,omitempty" toml:"server-tls,omitempty" json:"server-tls,omitempty"`
27+
ServerHTTPTLS TLSConfig `yaml:"server-http-tls,omitempty" toml:"server-http-tls,omitempty" json:"server-http-tls,omitempty"`
28+
ClusterTLS TLSConfig `yaml:"cluster-tls,omitempty" toml:"cluster-tls,omitempty" json:"cluster-tls,omitempty"`
29+
SQLTLS TLSConfig `yaml:"sql-tls,omitempty" toml:"sql-tls,omitempty" json:"sql-tls,omitempty"`
30+
Encryption Encryption `yaml:"encryption,omitempty" toml:"encryption,omitempty" json:"encryption,omitempty"`
31+
RequireBackendTLS bool `yaml:"require-backend-tls,omitempty" toml:"require-backend-tls,omitempty" json:"require-backend-tls,omitempty"`
32+
}
33+
34+
type Encryption struct {
35+
KeyPath string `yaml:"key-path,omitempty" toml:"key-path,omitempty" json:"key-path,omitempty"`
36+
}

pkg/server/api/traffic.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ func (h *Server) TrafficCapture(c *gin.Context) {
3232
}
3333
cfg.Duration = duration
3434
}
35+
cfg.EncryptMethod = c.PostForm("encrypt-method")
36+
cfg.KeyFile = h.mgr.CfgMgr.GetConfig().Security.Encryption.KeyPath
3537

3638
if err := h.mgr.ReplayJobMgr.StartCapture(cfg); err != nil {
3739
c.String(http.StatusInternalServerError, err.Error())
@@ -54,6 +56,7 @@ func (h *Server) TrafficReplay(c *gin.Context) {
5456
cfg.Username = c.PostForm("username")
5557
cfg.Password = c.PostForm("password")
5658
cfg.ReadOnly = strings.EqualFold(c.PostForm("readonly"), "true")
59+
cfg.KeyFile = h.mgr.CfgMgr.GetConfig().Security.Encryption.KeyPath
5760

5861
if err := h.mgr.ReplayJobMgr.StartReplay(cfg); err != nil {
5962
c.String(http.StatusInternalServerError, err.Error())

pkg/sqlreplay/capture/capture.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ type Capture interface {
5252

5353
type CaptureConfig struct {
5454
Output string
55+
EncryptMethod string
56+
KeyFile string
5557
Duration time.Duration
5658
cmdLogger store.Writer
5759
bufferCap int
@@ -219,7 +221,16 @@ func (c *capture) flushBuffer(bufCh <-chan *bytes.Buffer) {
219221
// cfg.cmdLogger is set in tests
220222
cmdLogger := c.cfg.cmdLogger
221223
if cmdLogger == nil {
222-
cmdLogger = store.NewWriter(store.WriterCfg{Dir: c.cfg.Output})
224+
var err error
225+
cmdLogger, err = store.NewWriter(store.WriterCfg{
226+
Dir: c.cfg.Output,
227+
EncryptMethod: c.cfg.EncryptMethod,
228+
KeyFile: c.cfg.KeyFile,
229+
})
230+
if err != nil {
231+
c.lg.Error("failed to create capture writer", zap.Error(err))
232+
return
233+
}
223234
}
224235
// Flush all buffers even if the context is timeout.
225236
for buf := range bufCh {
@@ -337,7 +348,7 @@ func (c *capture) putCommand(command *cmd.Command) bool {
337348
}
338349

339350
func (c *capture) writeMeta(duration time.Duration, cmds, filteredCmds uint64) {
340-
meta := store.NewMeta(duration, cmds, filteredCmds)
351+
meta := store.NewMeta(duration, cmds, filteredCmds, c.cfg.EncryptMethod)
341352
if err := meta.Write(c.cfg.Output); err != nil {
342353
c.lg.Error("failed to write meta", zap.Error(err))
343354
}

pkg/sqlreplay/replay/replay.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ type ReplayConfig struct {
5353
Input string
5454
Username string
5555
Password string
56+
KeyFile string
5657
Speed float64
5758
ReadOnly bool
5859
// the following fields are for testing
@@ -174,9 +175,16 @@ func (r *replay) readCommands(ctx context.Context) {
174175
// cfg.reader is set in tests
175176
reader := r.cfg.reader
176177
if reader == nil {
177-
reader = store.NewLoader(r.lg.Named("loader"), store.LoaderCfg{
178-
Dir: r.cfg.Input,
178+
var err error
179+
reader, err = store.NewReader(r.lg.Named("loader"), store.ReaderCfg{
180+
Dir: r.cfg.Input,
181+
KeyFile: r.cfg.KeyFile,
182+
EncryptMethod: r.meta.EncryptMethod,
179183
})
184+
if err != nil {
185+
r.stop(err)
186+
return
187+
}
180188
}
181189
defer reader.Close()
182190

pkg/sqlreplay/replay/replay_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func TestReplaySpeed(t *testing.T) {
154154

155155
func TestProgress(t *testing.T) {
156156
dir := t.TempDir()
157-
meta := store.NewMeta(10*time.Second, 10, 0)
157+
meta := store.NewMeta(10*time.Second, 10, 0, "")
158158
require.NoError(t, meta.Write(dir))
159159
loader := newMockNormalLoader()
160160
now := time.Now()
@@ -202,7 +202,7 @@ func TestProgress(t *testing.T) {
202202

203203
func TestPendingCmds(t *testing.T) {
204204
dir := t.TempDir()
205-
meta := store.NewMeta(10*time.Second, 10, 0)
205+
meta := store.NewMeta(10*time.Second, 10, 0, "")
206206
require.NoError(t, meta.Write(dir))
207207
loader := newMockNormalLoader()
208208
defer loader.Close()

pkg/sqlreplay/store/encrypt.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package store
5+
6+
import (
7+
"crypto/aes"
8+
"crypto/cipher"
9+
"crypto/rand"
10+
"io"
11+
"os"
12+
"reflect"
13+
14+
"github.com/pingcap/tiproxy/lib/util/errors"
15+
)
16+
17+
var _ Writer = (*aesCTRWriter)(nil)
18+
19+
type aesCTRWriter struct {
20+
Writer
21+
stream cipher.Stream
22+
iv []byte
23+
inited bool
24+
}
25+
26+
func newAESCTRWriter(writer Writer, keyFile string) (*aesCTRWriter, error) {
27+
key, err := readAesKey(keyFile)
28+
if err != nil {
29+
return nil, err
30+
}
31+
block, err := aes.NewCipher(key)
32+
if err != nil {
33+
return nil, errors.WithStack(err)
34+
}
35+
iv := make([]byte, aes.BlockSize)
36+
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
37+
return nil, errors.WithStack(err)
38+
}
39+
return &aesCTRWriter{
40+
Writer: writer,
41+
stream: cipher.NewCTR(block, iv),
42+
iv: iv,
43+
}, nil
44+
}
45+
46+
func (ctr *aesCTRWriter) Write(data []byte) error {
47+
if !ctr.inited {
48+
if err := ctr.writeIV(); err != nil {
49+
return err
50+
}
51+
ctr.inited = true
52+
}
53+
ctr.stream.XORKeyStream(data, data)
54+
return ctr.Writer.Write(data)
55+
}
56+
57+
func (ctr *aesCTRWriter) writeIV() error {
58+
return ctr.Writer.Write(ctr.iv)
59+
}
60+
61+
func (ctr *aesCTRWriter) Close() error {
62+
return ctr.Writer.Close()
63+
}
64+
65+
var _ Reader = (*aesCTRReader)(nil)
66+
67+
type aesCTRReader struct {
68+
Reader
69+
stream cipher.Stream
70+
key []byte
71+
}
72+
73+
func newAESCTRReader(reader Reader, keyFile string) (*aesCTRReader, error) {
74+
key, err := readAesKey(keyFile)
75+
if err != nil {
76+
return nil, err
77+
}
78+
return &aesCTRReader{
79+
Reader: reader,
80+
key: key,
81+
}, nil
82+
}
83+
84+
func (ctr *aesCTRReader) Read(data []byte) (int, error) {
85+
if ctr.stream == nil || reflect.ValueOf(ctr.stream).IsNil() {
86+
if err := ctr.init(); err != nil {
87+
return 0, err
88+
}
89+
}
90+
n, err := ctr.Reader.Read(data)
91+
if n > 0 {
92+
ctr.stream.XORKeyStream(data[:n], data[:n])
93+
}
94+
if err != nil {
95+
return n, err
96+
}
97+
return n, nil
98+
}
99+
100+
func (ctr *aesCTRReader) init() error {
101+
block, err := aes.NewCipher(ctr.key)
102+
if err != nil {
103+
return errors.WithStack(err)
104+
}
105+
iv := make([]byte, aes.BlockSize)
106+
for readLen := 0; readLen < len(iv); {
107+
m, err := ctr.Reader.Read(iv[readLen:])
108+
if err != nil {
109+
return err
110+
}
111+
readLen += m
112+
}
113+
ctr.stream = cipher.NewCTR(block, iv)
114+
return nil
115+
}
116+
117+
func (ctr *aesCTRReader) CurFile() string {
118+
return ctr.Reader.CurFile()
119+
}
120+
121+
func (ctr *aesCTRReader) Close() {
122+
ctr.Reader.Close()
123+
}
124+
125+
func readAesKey(filename string) ([]byte, error) {
126+
key, err := os.ReadFile(filename)
127+
if err != nil {
128+
return nil, errors.WithStack(err)
129+
}
130+
if len(key) != 32 {
131+
return nil, errors.Errorf("invalid aes-256 key length: %d, expecting 32", len(key))
132+
}
133+
return key, nil
134+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package store
5+
6+
import (
7+
"io"
8+
"os"
9+
"path/filepath"
10+
"testing"
11+
12+
"github.com/stretchr/testify/require"
13+
"go.uber.org/zap"
14+
)
15+
16+
func TestAes256(t *testing.T) {
17+
dir := t.TempDir()
18+
rotateWriter := newRotateWriter(WriterCfg{
19+
Dir: dir,
20+
FileSize: 1,
21+
})
22+
keyFile := filepath.Join(dir, "key")
23+
genAesKey(t, keyFile)
24+
aesWriter, err := newAESCTRWriter(rotateWriter, keyFile)
25+
require.NoError(t, err)
26+
require.NoError(t, aesWriter.Write([]byte("test")))
27+
require.NoError(t, aesWriter.Close())
28+
29+
rotateReader := newRotateReader(zap.NewNop(), dir)
30+
aesReader, err := newAESCTRReader(rotateReader, keyFile)
31+
require.NoError(t, err)
32+
data := make([]byte, 100)
33+
n, err := io.ReadFull(aesReader, data)
34+
require.Equal(t, len("test"), n)
35+
require.ErrorContains(t, err, "unexpected EOF")
36+
require.Equal(t, []byte("test"), data[:n])
37+
require.Equal(t, fileName, aesReader.CurFile())
38+
aesReader.Close()
39+
}
40+
41+
func TestAes256Error(t *testing.T) {
42+
dir := t.TempDir()
43+
rotateWriter := newRotateWriter(WriterCfg{
44+
Dir: dir,
45+
FileSize: 1,
46+
})
47+
keyFile := filepath.Join(dir, "key")
48+
_, err := newAESCTRWriter(rotateWriter, keyFile)
49+
require.Error(t, err)
50+
51+
rotateReader := newRotateReader(zap.NewNop(), dir)
52+
_, err = newAESCTRReader(rotateReader, keyFile)
53+
require.Error(t, err)
54+
}
55+
56+
func genAesKey(t *testing.T, keyFile string) {
57+
key := make([]byte, 32)
58+
for i := range key {
59+
key[i] = byte(i)
60+
}
61+
require.NoError(t, os.WriteFile(keyFile, key, 0600))
62+
}

0 commit comments

Comments
 (0)