@@ -7,27 +7,43 @@ import (
77 "crypto/aes"
88 "crypto/cipher"
99 "crypto/rand"
10+ "fmt"
1011 "io"
1112 "os"
12- "reflect "
13+ "strings "
1314
1415 "github.com/pingcap/tiproxy/lib/util/errors"
1516)
1617
17- var _ Writer = (* aesCTRWriter )(nil )
18+ const (
19+ EncryptPlain = "plaintext"
20+ EncryptAes = "aes256-ctr"
21+ )
22+
23+ var _ io.WriteCloser = (* aesCTRWriter )(nil )
1824
1925type aesCTRWriter struct {
20- Writer
26+ io. WriteCloser
2127 stream cipher.Stream
22- iv []byte
23- inited bool
2428}
2529
26- func newAESCTRWriter (writer Writer , keyFile string ) (* aesCTRWriter , error ) {
30+ func newWriterWithEncryptOpts (writer io.WriteCloser , encryptMethod string , keyFile string ) (io.WriteCloser , error ) {
31+ switch strings .ToLower (encryptMethod ) {
32+ case "" , EncryptPlain :
33+ return writer , nil
34+ case EncryptAes :
35+ default :
36+ return nil , fmt .Errorf ("unsupported encrypt method: %s" , encryptMethod )
37+ }
38+
2739 key , err := readAesKey (keyFile )
2840 if err != nil {
2941 return nil , err
3042 }
43+ return newAESCTRWriter (writer , key )
44+ }
45+
46+ func newAESCTRWriter (writer io.WriteCloser , key []byte ) (* aesCTRWriter , error ) {
3147 block , err := aes .NewCipher (key )
3248 if err != nil {
3349 return nil , errors .WithStack (err )
@@ -36,93 +52,76 @@ func newAESCTRWriter(writer Writer, keyFile string) (*aesCTRWriter, error) {
3652 if _ , err := io .ReadFull (rand .Reader , iv ); err != nil {
3753 return nil , errors .WithStack (err )
3854 }
39- return & aesCTRWriter {
40- Writer : writer ,
41- stream : cipher .NewCTR (block , iv ),
42- iv : iv ,
43- }, nil
44- }
45-
46- func (ctr * aesCTRWriter ) Write (data []byte ) error {
47- if ! ctr .inited {
48- if err := ctr .writeIV (); err != nil {
49- return err
50- }
51- ctr .inited = true
55+ ctr := & aesCTRWriter {
56+ WriteCloser : writer ,
57+ stream : cipher .NewCTR (block , iv ),
5258 }
53- ctr .stream . XORKeyStream ( data , data )
54- return ctr . Writer . Write ( data )
59+ _ , err = ctr .WriteCloser . Write ( iv )
60+ return ctr , err
5561}
5662
57- func (ctr * aesCTRWriter ) writeIV () error {
58- return ctr .Writer .Write (ctr .iv )
59- }
60-
61- func (ctr * aesCTRWriter ) Close () error {
62- return ctr .Writer .Close ()
63+ func (ctr * aesCTRWriter ) Write (data []byte ) (int , error ) {
64+ ctr .stream .XORKeyStream (data , data )
65+ return ctr .WriteCloser .Write (data )
6366}
6467
65- var _ Reader = (* aesCTRReader )(nil )
68+ var _ io. Reader = (* aesCTRReader )(nil )
6669
6770type aesCTRReader struct {
68- Reader
71+ io. Reader
6972 stream cipher.Stream
70- key []byte
7173}
7274
73- func newAESCTRReader (reader Reader , keyFile string ) (* aesCTRReader , error ) {
75+ func newReaderWithEncryptOpts (reader io.Reader , encryptMethod string , keyFile string ) (io.Reader , error ) {
76+ switch strings .ToLower (encryptMethod ) {
77+ case "" , EncryptPlain :
78+ return reader , nil
79+ case EncryptAes :
80+ default :
81+ return nil , fmt .Errorf ("unsupported encrypt method: %s" , encryptMethod )
82+ }
83+
7484 key , err := readAesKey (keyFile )
7585 if err != nil {
7686 return nil , err
7787 }
88+ return newAESCTRReader (reader , key )
89+ }
90+
91+ func newAESCTRReader (reader io.Reader , key []byte ) (* aesCTRReader , error ) {
92+ block , err := aes .NewCipher (key )
93+ if err != nil {
94+ return nil , errors .WithStack (err )
95+ }
96+ iv := make ([]byte , aes .BlockSize )
97+ for readLen := 0 ; readLen < len (iv ); {
98+ m , err := reader .Read (iv [readLen :])
99+ if err != nil {
100+ return nil , err
101+ }
102+ readLen += m
103+ }
78104 return & aesCTRReader {
79105 Reader : reader ,
80- key : key ,
106+ stream : cipher . NewCTR ( block , iv ) ,
81107 }, nil
82108}
83109
84110func (ctr * aesCTRReader ) Read (data []byte ) (int , error ) {
85- if ctr .stream == nil || reflect .ValueOf (ctr .stream ).IsNil () {
86- if err := ctr .init (); err != nil {
87- return 0 , err
88- }
89- }
90111 n , err := ctr .Reader .Read (data )
91112 if n > 0 {
92113 ctr .stream .XORKeyStream (data [:n ], data [:n ])
93114 }
94115 if err != nil {
95- return n , err
116+ return n , errors . WithStack ( err )
96117 }
97118 return n , nil
98119}
99120
100- func (ctr * aesCTRReader ) init () error {
101- block , err := aes .NewCipher (ctr .key )
102- if err != nil {
103- return errors .WithStack (err )
104- }
105- iv := make ([]byte , aes .BlockSize )
106- for readLen := 0 ; readLen < len (iv ); {
107- m , err := ctr .Reader .Read (iv [readLen :])
108- if err != nil {
109- return err
110- }
111- readLen += m
112- }
113- ctr .stream = cipher .NewCTR (block , iv )
114- return nil
115- }
116-
117- func (ctr * aesCTRReader ) CurFile () string {
118- return ctr .Reader .CurFile ()
119- }
120-
121- func (ctr * aesCTRReader ) Close () {
122- ctr .Reader .Close ()
123- }
124-
125121func readAesKey (filename string ) ([]byte , error ) {
122+ if len (filename ) == 0 {
123+ return nil , errors .New ("encryption key file name is not set" )
124+ }
126125 key , err := os .ReadFile (filename )
127126 if err != nil {
128127 return nil , errors .WithStack (err )
0 commit comments