Skip to content

Commit 172649a

Browse files
authored
sqlreplay: remove lumberjack to support large commands (#727)
1 parent ed64d5f commit 172649a

File tree

12 files changed

+461
-350
lines changed

12 files changed

+461
-350
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ require (
3434
go.uber.org/ratelimit v0.2.0
3535
go.uber.org/zap v1.26.0
3636
google.golang.org/grpc v1.54.0
37-
gopkg.in/natefinch/lumberjack.v2 v2.2.1
3837
)
3938

4039
require (
@@ -117,6 +116,7 @@ require (
117116
golang.org/x/time v0.3.0 // indirect
118117
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
119118
google.golang.org/protobuf v1.31.0 // indirect
119+
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
120120
gopkg.in/yaml.v2 v2.4.0 // indirect
121121
gopkg.in/yaml.v3 v3.0.1 // indirect
122122
sigs.k8s.io/yaml v1.3.0 // indirect

pkg/sqlreplay/capture/capture.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package capture
66
import (
77
"bytes"
88
"context"
9+
"io"
910
"os"
1011
"sync"
1112
"time"
@@ -59,7 +60,7 @@ type CaptureConfig struct {
5960
StartTime time.Time
6061
Duration time.Duration
6162
Compress bool
62-
cmdLogger store.Writer
63+
cmdLogger io.WriteCloser
6364
bufferCap int
6465
flushThreshold int
6566
maxBuffers int
@@ -231,12 +232,14 @@ func (c *capture) collectCmds(bufCh chan<- *bytes.Buffer) {
231232
}
232233
}
233234

235+
// Writing commands requires a bytes buffer instead of a simple bufio.Writer,
236+
// so the buffer can not be pushed down to the store package.
234237
func (c *capture) flushBuffer(bufCh <-chan *bytes.Buffer) {
235238
// cfg.cmdLogger is set in tests
236239
cmdLogger := c.cfg.cmdLogger
237240
if cmdLogger == nil {
238241
var err error
239-
cmdLogger, err = store.NewWriter(store.WriterCfg{
242+
cmdLogger, err = store.NewWriter(c.lg.Named("writer"), store.WriterCfg{
240243
Dir: c.cfg.Output,
241244
EncryptMethod: c.cfg.EncryptMethod,
242245
KeyFile: c.cfg.KeyFile,
@@ -249,8 +252,7 @@ func (c *capture) flushBuffer(bufCh <-chan *bytes.Buffer) {
249252
}
250253
// Flush all buffers even if the context is timeout.
251254
for buf := range bufCh {
252-
// TODO: each write size should be less than MaxSize.
253-
if err := cmdLogger.Write(buf.Bytes()); err != nil {
255+
if _, err := cmdLogger.Write(buf.Bytes()); err != nil {
254256
c.stop(errors.Wrapf(err, "failed to flush traffic to disk"))
255257
break
256258
}

pkg/sqlreplay/capture/mock_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@ package capture
55

66
import (
77
"bytes"
8+
"io"
89
"sync"
910

1011
"github.com/pingcap/tiproxy/pkg/sqlreplay/store"
1112
)
1213

13-
var _ store.Writer = (*mockWriter)(nil)
14+
var _ io.WriteCloser = (*mockWriter)(nil)
1415

1516
type mockWriter struct {
1617
sync.Mutex
@@ -21,11 +22,10 @@ func newMockWriter(store.WriterCfg) *mockWriter {
2122
return &mockWriter{}
2223
}
2324

24-
func (w *mockWriter) Write(p []byte) error {
25+
func (w *mockWriter) Write(p []byte) (int, error) {
2526
w.Lock()
2627
defer w.Unlock()
27-
_, err := w.buf.Write(p)
28-
return err
28+
return w.buf.Write(p)
2929
}
3030

3131
func (w *mockWriter) getData() []byte {

pkg/sqlreplay/store/compress.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package store
5+
6+
import (
7+
"compress/gzip"
8+
"io"
9+
10+
"go.uber.org/zap"
11+
)
12+
13+
var _ io.WriteCloser = (*compressWriter)(nil)
14+
15+
type compressWriter struct {
16+
io.WriteCloser
17+
internalWriter io.WriteCloser
18+
lg *zap.Logger
19+
}
20+
21+
func newCompressWriter(lg *zap.Logger, writer io.WriteCloser) *compressWriter {
22+
return &compressWriter{
23+
WriteCloser: gzip.NewWriter(writer),
24+
internalWriter: writer,
25+
lg: lg,
26+
}
27+
}
28+
29+
func (w *compressWriter) Close() error {
30+
if err := w.WriteCloser.Close(); err != nil {
31+
w.lg.Warn("failed to close writer", zap.Error(err))
32+
}
33+
return w.internalWriter.Close()
34+
}
35+
36+
var _ io.Reader = (*compressReader)(nil)
37+
38+
type compressReader struct {
39+
*gzip.Reader
40+
}
41+
42+
func newCompressReader(reader io.Reader) (*compressReader, error) {
43+
gr, err := gzip.NewReader(reader)
44+
if err != nil {
45+
return nil, err
46+
}
47+
return &compressReader{
48+
Reader: gr,
49+
}, nil
50+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package store
5+
6+
import (
7+
"io"
8+
"os"
9+
"path/filepath"
10+
"testing"
11+
12+
"github.com/stretchr/testify/require"
13+
"go.uber.org/zap"
14+
)
15+
16+
func TestCompressReadWrite(t *testing.T) {
17+
dir := t.TempDir()
18+
path := filepath.Join(dir, "test")
19+
20+
file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
21+
require.NoError(t, err)
22+
writer := newCompressWriter(zap.NewNop(), file)
23+
n, err := writer.Write([]byte("test"))
24+
require.NoError(t, err)
25+
require.Equal(t, 4, n)
26+
require.NoError(t, writer.Close())
27+
// file is already closed
28+
require.Error(t, file.Close())
29+
30+
file, err = os.OpenFile(path, os.O_RDONLY, 0600)
31+
require.NoError(t, err)
32+
reader, err := newCompressReader(file)
33+
require.NoError(t, err)
34+
data := make([]byte, 100)
35+
n, err = io.ReadFull(reader, data)
36+
require.Equal(t, 4, n)
37+
require.ErrorContains(t, err, "EOF")
38+
}

pkg/sqlreplay/store/const.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
package store
55

66
const (
7-
fileNamePrefix = "traffic"
7+
fileNamePrefix = "traffic-"
88
fileNameSuffix = ".log"
9-
fileName = fileNamePrefix + fileNameSuffix
10-
fileTsLayout = "2006-01-02T15-04-05.000"
119
fileCompressFormat = ".gz"
12-
fileSize = 300 // 300MB
10+
fileSize = 300 << 20
11+
bufferSize = 1 << 20
1312
)

pkg/sqlreplay/store/encrypt.go

Lines changed: 62 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,43 @@ import (
77
"crypto/aes"
88
"crypto/cipher"
99
"crypto/rand"
10+
"fmt"
1011
"io"
1112
"os"
12-
"reflect"
13+
"strings"
1314

1415
"github.com/pingcap/tiproxy/lib/util/errors"
1516
)
1617

17-
var _ Writer = (*aesCTRWriter)(nil)
18+
const (
19+
EncryptPlain = "plaintext"
20+
EncryptAes = "aes256-ctr"
21+
)
22+
23+
var _ io.WriteCloser = (*aesCTRWriter)(nil)
1824

1925
type aesCTRWriter struct {
20-
Writer
26+
io.WriteCloser
2127
stream cipher.Stream
22-
iv []byte
23-
inited bool
2428
}
2529

26-
func newAESCTRWriter(writer Writer, keyFile string) (*aesCTRWriter, error) {
30+
func newWriterWithEncryptOpts(writer io.WriteCloser, encryptMethod string, keyFile string) (io.WriteCloser, error) {
31+
switch strings.ToLower(encryptMethod) {
32+
case "", EncryptPlain:
33+
return writer, nil
34+
case EncryptAes:
35+
default:
36+
return nil, fmt.Errorf("unsupported encrypt method: %s", encryptMethod)
37+
}
38+
2739
key, err := readAesKey(keyFile)
2840
if err != nil {
2941
return nil, err
3042
}
43+
return newAESCTRWriter(writer, key)
44+
}
45+
46+
func newAESCTRWriter(writer io.WriteCloser, key []byte) (*aesCTRWriter, error) {
3147
block, err := aes.NewCipher(key)
3248
if err != nil {
3349
return nil, errors.WithStack(err)
@@ -36,93 +52,76 @@ func newAESCTRWriter(writer Writer, keyFile string) (*aesCTRWriter, error) {
3652
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
3753
return nil, errors.WithStack(err)
3854
}
39-
return &aesCTRWriter{
40-
Writer: writer,
41-
stream: cipher.NewCTR(block, iv),
42-
iv: iv,
43-
}, nil
44-
}
45-
46-
func (ctr *aesCTRWriter) Write(data []byte) error {
47-
if !ctr.inited {
48-
if err := ctr.writeIV(); err != nil {
49-
return err
50-
}
51-
ctr.inited = true
55+
ctr := &aesCTRWriter{
56+
WriteCloser: writer,
57+
stream: cipher.NewCTR(block, iv),
5258
}
53-
ctr.stream.XORKeyStream(data, data)
54-
return ctr.Writer.Write(data)
59+
_, err = ctr.WriteCloser.Write(iv)
60+
return ctr, err
5561
}
5662

57-
func (ctr *aesCTRWriter) writeIV() error {
58-
return ctr.Writer.Write(ctr.iv)
59-
}
60-
61-
func (ctr *aesCTRWriter) Close() error {
62-
return ctr.Writer.Close()
63+
func (ctr *aesCTRWriter) Write(data []byte) (int, error) {
64+
ctr.stream.XORKeyStream(data, data)
65+
return ctr.WriteCloser.Write(data)
6366
}
6467

65-
var _ Reader = (*aesCTRReader)(nil)
68+
var _ io.Reader = (*aesCTRReader)(nil)
6669

6770
type aesCTRReader struct {
68-
Reader
71+
io.Reader
6972
stream cipher.Stream
70-
key []byte
7173
}
7274

73-
func newAESCTRReader(reader Reader, keyFile string) (*aesCTRReader, error) {
75+
func newReaderWithEncryptOpts(reader io.Reader, encryptMethod string, keyFile string) (io.Reader, error) {
76+
switch strings.ToLower(encryptMethod) {
77+
case "", EncryptPlain:
78+
return reader, nil
79+
case EncryptAes:
80+
default:
81+
return nil, fmt.Errorf("unsupported encrypt method: %s", encryptMethod)
82+
}
83+
7484
key, err := readAesKey(keyFile)
7585
if err != nil {
7686
return nil, err
7787
}
88+
return newAESCTRReader(reader, key)
89+
}
90+
91+
func newAESCTRReader(reader io.Reader, key []byte) (*aesCTRReader, error) {
92+
block, err := aes.NewCipher(key)
93+
if err != nil {
94+
return nil, errors.WithStack(err)
95+
}
96+
iv := make([]byte, aes.BlockSize)
97+
for readLen := 0; readLen < len(iv); {
98+
m, err := reader.Read(iv[readLen:])
99+
if err != nil {
100+
return nil, err
101+
}
102+
readLen += m
103+
}
78104
return &aesCTRReader{
79105
Reader: reader,
80-
key: key,
106+
stream: cipher.NewCTR(block, iv),
81107
}, nil
82108
}
83109

84110
func (ctr *aesCTRReader) Read(data []byte) (int, error) {
85-
if ctr.stream == nil || reflect.ValueOf(ctr.stream).IsNil() {
86-
if err := ctr.init(); err != nil {
87-
return 0, err
88-
}
89-
}
90111
n, err := ctr.Reader.Read(data)
91112
if n > 0 {
92113
ctr.stream.XORKeyStream(data[:n], data[:n])
93114
}
94115
if err != nil {
95-
return n, err
116+
return n, errors.WithStack(err)
96117
}
97118
return n, nil
98119
}
99120

100-
func (ctr *aesCTRReader) init() error {
101-
block, err := aes.NewCipher(ctr.key)
102-
if err != nil {
103-
return errors.WithStack(err)
104-
}
105-
iv := make([]byte, aes.BlockSize)
106-
for readLen := 0; readLen < len(iv); {
107-
m, err := ctr.Reader.Read(iv[readLen:])
108-
if err != nil {
109-
return err
110-
}
111-
readLen += m
112-
}
113-
ctr.stream = cipher.NewCTR(block, iv)
114-
return nil
115-
}
116-
117-
func (ctr *aesCTRReader) CurFile() string {
118-
return ctr.Reader.CurFile()
119-
}
120-
121-
func (ctr *aesCTRReader) Close() {
122-
ctr.Reader.Close()
123-
}
124-
125121
func readAesKey(filename string) ([]byte, error) {
122+
if len(filename) == 0 {
123+
return nil, errors.New("encryption key file name is not set")
124+
}
126125
key, err := os.ReadFile(filename)
127126
if err != nil {
128127
return nil, errors.WithStack(err)

0 commit comments

Comments
 (0)