Skip to content

Commit 419e26d

Browse files
djshow832xhebox
andauthored
backend, net: Support compression protocol (#373)
Co-authored-by: xhe <[email protected]>
1 parent d3cc47c commit 419e26d

19 files changed

+1321
-180
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ require (
1111
github.com/gin-gonic/gin v1.8.1
1212
github.com/go-mysql-org/go-mysql v1.6.0
1313
github.com/go-sql-driver/mysql v1.7.0
14+
github.com/klauspost/compress v1.16.6
1415
github.com/pingcap/tidb v1.1.0-beta.0.20230103132820-3ccff46aa3bc
1516
github.com/pingcap/tidb/parser v0.0.0-20230103132820-3ccff46aa3bc
1617
github.com/pingcap/tiproxy/lib v0.0.0-00010101000000-000000000000

go.sum

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
415415
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
416416
github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
417417
github.com/klauspost/compress v1.9.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
418-
github.com/klauspost/compress v1.15.13 h1:NFn1Wr8cfnenSJSA46lLq4wHCcBzKTSjnBIexDMMOV0=
418+
github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk=
419+
github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
419420
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
420421
github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s=
421422
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=

pkg/proxy/backend/authenticator.go

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ const defRequiredBackendCaps = pnet.ClientDeprecateEOF
2929

3030
// SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported.
3131
// TiDB supports ClientDeprecateEOF since v6.3.0.
32+
// TiDB supports ClientCompress and ClientZstdCompressionAlgorithm since v7.2.0.
3233
const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRows | pnet.ClientConnectWithDB |
3334
pnet.ClientODBC | pnet.ClientLocalFiles | pnet.ClientInteractive | pnet.ClientLongFlag | pnet.ClientSSL |
3435
pnet.ClientTransactions | pnet.ClientReserved | pnet.ClientSecureConnection | pnet.ClientMultiStatements |
3536
pnet.ClientMultiResults | pnet.ClientPluginAuth | pnet.ClientConnectAttrs | pnet.ClientPluginAuthLenencClientData |
36-
requiredFrontendCaps | defRequiredBackendCaps
37+
pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm | requiredFrontendCaps | defRequiredBackendCaps
3738

3839
// Authenticator handshakes with the client and the backend.
3940
type Authenticator struct {
@@ -42,6 +43,7 @@ type Authenticator struct {
4243
attrs map[string]string
4344
salt []byte
4445
capability pnet.Capability
46+
zstdLevel int
4547
collation uint8
4648
proxyProtocol bool
4749
requireBackendTLS bool
@@ -64,9 +66,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO
6466
}
6567
// either from another proxy or directly from clients, we are acting as a proxy
6668
proxy.Command = proxyprotocol.ProxyCommandProxy
67-
if err := backendIO.WriteProxyV2(proxy); err != nil {
68-
return err
69-
}
69+
backendIO.EnableProxyClient(proxy)
7070
}
7171
return nil
7272
}
@@ -157,6 +157,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
157157
auth.dbname = clientResp.DB
158158
auth.collation = clientResp.Collation
159159
auth.attrs = clientResp.Attrs
160+
auth.zstdLevel = clientResp.ZstdLevel
160161

161162
// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
162163
backendIO, err := getBackendIO(cctx, auth, clientResp, 15*time.Second)
@@ -225,6 +226,12 @@ loop:
225226
pktIdx++
226227
switch serverPkt[0] {
227228
case pnet.OKHeader.Byte():
229+
if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil {
230+
return err
231+
}
232+
if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
233+
return err
234+
}
228235
return nil
229236
case pnet.ErrHeader.Byte():
230237
return pnet.ParseErrorPacket(serverPkt)
@@ -277,7 +284,10 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
277284
return err
278285
}
279286

280-
return auth.handleSecondAuthResult(backendIO)
287+
if err = auth.handleSecondAuthResult(backendIO); err == nil {
288+
return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel)
289+
}
290+
return err
281291
}
282292

283293
func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) {
@@ -307,8 +317,9 @@ func (auth *Authenticator) writeAuthHandshake(
307317
Attrs: auth.attrs,
308318
Collation: auth.collation,
309319
AuthData: authData,
310-
Capability: auth.capability | authCap,
320+
Capability: auth.capability&backendCapability | authCap,
311321
AuthPlugin: authPlugin,
322+
ZstdLevel: auth.zstdLevel,
312323
}
313324

314325
if len(resp.Attrs) > 0 {
@@ -382,3 +393,13 @@ func (auth *Authenticator) changeUser(req *pnet.ChangeUserReq) {
382393
func (auth *Authenticator) updateCurrentDB(db string) {
383394
auth.dbname = db
384395
}
396+
397+
func setCompress(packetIO *pnet.PacketIO, capability pnet.Capability, zstdLevel int) error {
398+
algorithm := pnet.CompressionNone
399+
if capability&pnet.ClientCompress > 0 {
400+
algorithm = pnet.CompressionZlib
401+
} else if capability&pnet.ClientZstdCompressionAlgorithm > 0 {
402+
algorithm = pnet.CompressionZstd
403+
}
404+
return packetIO.SetCompressionAlgorithm(algorithm, zstdLevel)
405+
}

pkg/proxy/backend/authenticator_test.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,30 @@ func TestCapability(t *testing.T) {
164164
cfg.clientConfig.capability |= pnet.ClientSecureConnection
165165
},
166166
},
167+
{
168+
func(cfg *testConfig) {
169+
cfg.backendConfig.capability &= ^pnet.ClientCompress
170+
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
171+
},
172+
func(cfg *testConfig) {
173+
cfg.backendConfig.capability |= pnet.ClientCompress
174+
cfg.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
175+
},
176+
},
177+
{
178+
func(cfg *testConfig) {
179+
cfg.clientConfig.capability &= ^pnet.ClientCompress
180+
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
181+
},
182+
func(cfg *testConfig) {
183+
cfg.clientConfig.capability |= pnet.ClientCompress
184+
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
185+
},
186+
func(cfg *testConfig) {
187+
cfg.clientConfig.capability |= pnet.ClientCompress
188+
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
189+
},
190+
},
167191
}
168192

169193
tc := newTCPConnSuite(t)
@@ -387,3 +411,138 @@ func TestProxyProtocol(t *testing.T) {
387411
clean()
388412
}
389413
}
414+
415+
func TestCompressProtocol(t *testing.T) {
416+
cfgs := [][]cfgOverrider{
417+
{
418+
func(cfg *testConfig) {
419+
cfg.backendConfig.capability &= ^pnet.ClientCompress
420+
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
421+
},
422+
func(cfg *testConfig) {
423+
cfg.backendConfig.capability |= pnet.ClientCompress
424+
cfg.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
425+
},
426+
},
427+
{
428+
func(cfg *testConfig) {
429+
cfg.clientConfig.capability &= ^pnet.ClientCompress
430+
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
431+
},
432+
func(cfg *testConfig) {
433+
cfg.clientConfig.capability |= pnet.ClientCompress
434+
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
435+
cfg.clientConfig.zstdLevel = 3
436+
},
437+
func(cfg *testConfig) {
438+
cfg.clientConfig.capability |= pnet.ClientCompress
439+
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
440+
cfg.clientConfig.zstdLevel = 9
441+
},
442+
func(cfg *testConfig) {
443+
cfg.clientConfig.capability |= pnet.ClientCompress
444+
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
445+
},
446+
},
447+
}
448+
449+
checker := func(t *testing.T, ts *testSuite, referCfg *testConfig) {
450+
// If the client enables compression, client <-> proxy enables compression.
451+
if referCfg.clientConfig.capability&pnet.ClientCompress > 0 {
452+
require.Greater(t, ts.mp.authenticator.capability&pnet.ClientCompress, pnet.Capability(0))
453+
require.Greater(t, ts.mc.capability&pnet.ClientCompress, pnet.Capability(0))
454+
} else {
455+
require.Equal(t, pnet.Capability(0), ts.mp.authenticator.capability&pnet.ClientCompress)
456+
require.Equal(t, pnet.Capability(0), ts.mc.capability&pnet.ClientCompress)
457+
}
458+
// If both the client and the backend enables compression, proxy <-> backend enables compression.
459+
if referCfg.clientConfig.capability&referCfg.backendConfig.capability&pnet.ClientCompress > 0 {
460+
require.Greater(t, ts.mb.capability&pnet.ClientCompress, pnet.Capability(0))
461+
} else {
462+
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress)
463+
}
464+
// If the client enables zstd compression, client <-> proxy enables zstd compression.
465+
zstdCap := pnet.ClientCompress | pnet.ClientZstdCompressionAlgorithm
466+
if referCfg.clientConfig.capability&zstdCap == zstdCap {
467+
require.Greater(t, ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
468+
require.Greater(t, ts.mc.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
469+
require.Equal(t, referCfg.clientConfig.zstdLevel, ts.mp.authenticator.zstdLevel)
470+
} else {
471+
require.Equal(t, pnet.Capability(0), ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm)
472+
require.Equal(t, pnet.Capability(0), ts.mc.capability&pnet.ClientZstdCompressionAlgorithm)
473+
}
474+
// If both the client and the backend enables zstd compression, proxy <-> backend enables zstd compression.
475+
if referCfg.clientConfig.capability&referCfg.backendConfig.capability&zstdCap == zstdCap {
476+
require.Greater(t, ts.mb.capability&pnet.ClientZstdCompressionAlgorithm, pnet.Capability(0))
477+
require.Equal(t, referCfg.clientConfig.zstdLevel, ts.mb.zstdLevel)
478+
} else {
479+
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientZstdCompressionAlgorithm)
480+
}
481+
}
482+
483+
tc := newTCPConnSuite(t)
484+
cfgOverriders := getCfgCombinations(cfgs)
485+
for _, cfgs := range cfgOverriders {
486+
referCfg := newTestConfig(cfgs...)
487+
ts, clean := newTestSuite(t, tc, cfgs...)
488+
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
489+
checker(t, ts, referCfg)
490+
})
491+
ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) {
492+
checker(t, ts, referCfg)
493+
})
494+
clean()
495+
}
496+
}
497+
498+
// After upgrading the backend, the backend capability may change.
499+
func TestUpgradeBackendCap(t *testing.T) {
500+
cfgs := [][]cfgOverrider{
501+
{
502+
func(cfg *testConfig) {
503+
cfg.clientConfig.capability &= ^pnet.ClientCompress
504+
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
505+
},
506+
func(cfg *testConfig) {
507+
cfg.clientConfig.capability |= pnet.ClientCompress
508+
cfg.clientConfig.capability |= pnet.ClientZstdCompressionAlgorithm
509+
cfg.clientConfig.zstdLevel = 3
510+
},
511+
func(cfg *testConfig) {
512+
cfg.clientConfig.capability |= pnet.ClientCompress
513+
cfg.clientConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
514+
},
515+
},
516+
{
517+
func(cfg *testConfig) {
518+
cfg.backendConfig.capability &= ^pnet.ClientCompress
519+
cfg.backendConfig.capability &= ^pnet.ClientZstdCompressionAlgorithm
520+
},
521+
},
522+
}
523+
524+
tc := newTCPConnSuite(t)
525+
cfgOverriders := getCfgCombinations(cfgs)
526+
for _, cfgs := range cfgOverriders {
527+
referCfg := newTestConfig(cfgs...)
528+
ts, clean := newTestSuite(t, tc, cfgs...)
529+
// Before upgrade, the backend doesn't support compression.
530+
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
531+
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress)
532+
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress)
533+
require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress)
534+
})
535+
// After upgrade, the backend also supports compression.
536+
ts.mb.backendConfig.capability |= pnet.ClientCompress
537+
ts.mb.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm
538+
ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) {
539+
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress)
540+
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress)
541+
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mb.capability&pnet.ClientCompress)
542+
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mc.capability&pnet.ClientZstdCompressionAlgorithm)
543+
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mp.authenticator.capability&pnet.ClientZstdCompressionAlgorithm)
544+
require.Equal(t, referCfg.clientConfig.capability&pnet.ClientZstdCompressionAlgorithm, ts.mb.capability&pnet.ClientZstdCompressionAlgorithm)
545+
})
546+
clean()
547+
}
548+
}

pkg/proxy/backend/backend_conn_mgr_test.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,8 @@ func TestSpecialCmds(t *testing.T) {
527527
require.NoError(t, ts.redirectSucceed4Backend(packetIO))
528528
require.Equal(t, "another_user", ts.mb.username)
529529
require.Equal(t, "session_db", ts.mb.db)
530-
expectCap := pnet.Capability(ts.mp.handshakeHandler.GetCapability() &^ (pnet.ClientMultiStatements | pnet.ClientPluginAuthLenencClientData))
531-
gotCap := pnet.Capability(ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData)
530+
expectCap := ts.mp.handshakeHandler.GetCapability() & defaultTestClientCapability &^ (pnet.ClientMultiStatements | pnet.ClientPluginAuthLenencClientData)
531+
gotCap := ts.mb.capability &^ pnet.ClientPluginAuthLenencClientData
532532
require.Equal(t, expectCap, gotCap, "expected=%s,got=%s", expectCap, gotCap)
533533
return nil
534534
},
@@ -793,18 +793,16 @@ func TestHandlerReturnError(t *testing.T) {
793793
}
794794

795795
func TestOnTraffic(t *testing.T) {
796-
i := 0
797-
inbytes, outbytes := []int{
798-
0x99,
799-
}, []int{
800-
0xce,
801-
}
796+
var inBytes, outBytes uint64
802797
ts := newBackendMgrTester(t, func(config *testConfig) {
803798
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
804799
config.proxyConfig.handler.onTraffic = func(cc ConnContext) {
805-
require.Equal(t, uint64(inbytes[i]), cc.ClientInBytes())
806-
require.Equal(t, uint64(outbytes[i]), cc.ClientOutBytes())
807-
i++
800+
require.Greater(t, cc.ClientInBytes(), uint64(0))
801+
require.GreaterOrEqual(t, cc.ClientInBytes(), inBytes)
802+
inBytes = cc.ClientInBytes()
803+
require.Greater(t, cc.ClientOutBytes(), uint64(0))
804+
require.GreaterOrEqual(t, cc.ClientOutBytes(), outBytes)
805+
outBytes = cc.ClientOutBytes()
808806
}
809807
})
810808
runners := []runner{

pkg/proxy/backend/common_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,24 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() {
8484
}
8585
}
8686

87+
func (tc *tcpConnSuite) reconnectBackend(t *testing.T) {
88+
lg, _ := logger.CreateLoggerForTest(t)
89+
var wg waitgroup.WaitGroup
90+
wg.Run(func() {
91+
_ = tc.backendIO.Close()
92+
conn, err := tc.backendListener.Accept()
93+
require.NoError(t, err)
94+
tc.backendIO = pnet.NewPacketIO(conn, lg)
95+
})
96+
wg.Run(func() {
97+
_ = tc.proxyBIO.Close()
98+
backendConn, err := net.Dial("tcp", tc.backendListener.Addr().String())
99+
require.NoError(t, err)
100+
tc.proxyBIO = pnet.NewPacketIO(backendConn, lg)
101+
})
102+
wg.Wait()
103+
}
104+
87105
func (tc *tcpConnSuite) run(clientRunner, backendRunner func(*pnet.PacketIO) error, proxyRunner func(*pnet.PacketIO, *pnet.PacketIO) error) (cerr, berr, perr error) {
88106
var wg waitgroup.WaitGroup
89107
if clientRunner != nil {

pkg/proxy/backend/mock_backend_test.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ type mockBackend struct {
4747
// Inputs that assigned by the test and will be sent to the client.
4848
*backendConfig
4949
// Outputs that received from the client and will be checked by the test.
50-
username string
51-
db string
52-
attrs map[string]string
53-
authData []byte
50+
username string
51+
db string
52+
attrs map[string]string
53+
authData []byte
54+
zstdLevel int
5455
}
5556

5657
func newMockBackend(cfg *backendConfig) *mockBackend {
@@ -98,6 +99,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
9899
mb.authData = resp.AuthData
99100
mb.attrs = resp.Attrs
100101
mb.capability = resp.Capability
102+
mb.zstdLevel = resp.ZstdLevel
101103
// verify password
102104
return mb.verifyPassword(packetIO, resp)
103105
}
@@ -125,6 +127,9 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh
125127
if err := packetIO.WriteOKPacket(mb.status, pnet.OKHeader); err != nil {
126128
return err
127129
}
130+
if err := setCompress(packetIO, mb.capability, mb.zstdLevel); err != nil {
131+
return err
132+
}
128133
} else {
129134
if err := packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR)); err != nil {
130135
return err

pkg/proxy/backend/mock_client_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type clientConfig struct {
2626
capability pnet.Capability
2727
collation uint8
2828
cmd pnet.Command
29+
zstdLevel int
2930
// for both auth and cmd
3031
abnormalExit bool
3132
}
@@ -82,6 +83,7 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error {
8283
AuthData: mc.authData,
8384
Capability: mc.capability,
8485
Collation: mc.collation,
86+
ZstdLevel: mc.zstdLevel,
8587
}
8688
pkt = pnet.MakeHandshakeResponse(resp)
8789
if mc.capability&pnet.ClientSSL > 0 {

0 commit comments

Comments
 (0)