Skip to content

Commit c9c5817

Browse files
authored
auth: reconnect backend (#389)
Signed-off-by: disksing <[email protected]>
1 parent 28d53d3 commit c9c5817

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

pkg/proxy/backend/authenticator.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/go-mysql-org/go-mysql/mysql"
15+
gomysql "github.com/go-mysql-org/go-mysql/mysql"
1516
"github.com/pingcap/tidb/util/hack"
1617
"github.com/pingcap/tiproxy/lib/util/errors"
1718
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
@@ -159,6 +160,8 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
159160
auth.attrs = clientResp.Attrs
160161
auth.zstdLevel = clientResp.ZstdLevel
161162

163+
RECONNECT:
164+
162165
// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
163166
backendIO, err := getBackendIO(cctx, auth, clientResp, 15*time.Second)
164167
if err != nil {
@@ -214,7 +217,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
214217
pktIdx := 0
215218
loop:
216219
for {
217-
serverPkt, err := forwardMsg(backendIO, clientIO)
220+
serverPkt, err := backendIO.ReadPacket()
218221
if err != nil {
219222
// tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence
220223
// tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence
@@ -223,6 +226,23 @@ loop:
223226
}
224227
return err
225228
}
229+
var packetErr error
230+
if serverPkt[0] == pnet.ErrHeader.Byte() {
231+
packetErr = pnet.ParseErrorPacket(serverPkt)
232+
if handshakeHandler.HandleHandshakeErr(cctx, packetErr.(*gomysql.MyError)) {
233+
logger.Warn("handle handshake error, start reconnect", zap.Error(err))
234+
backendIO.Close()
235+
goto RECONNECT
236+
}
237+
}
238+
err = clientIO.WritePacket(serverPkt, true)
239+
if err != nil {
240+
return err
241+
}
242+
if packetErr != nil {
243+
return packetErr
244+
}
245+
226246
pktIdx++
227247
switch serverPkt[0] {
228248
case pnet.OKHeader.Byte():
@@ -233,8 +253,6 @@ loop:
233253
return err
234254
}
235255
return nil
236-
case pnet.ErrHeader.Byte():
237-
return pnet.ParseErrorPacket(serverPkt)
238256
default: // mysql.AuthSwitchRequest, ShaCommand
239257
if serverPkt[0] == pnet.AuthSwitchHeader.Byte() {
240258
pluginName = string(serverPkt[1 : bytes.IndexByte(serverPkt[1:], 0)+1])

pkg/proxy/backend/handshake_handler.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package backend
55

66
import (
7+
gomysql "github.com/go-mysql-org/go-mysql/mysql"
78
"github.com/pingcap/tiproxy/lib/util/errors"
89
"github.com/pingcap/tiproxy/pkg/manager/namespace"
910
"github.com/pingcap/tiproxy/pkg/manager/router"
@@ -70,6 +71,7 @@ type ConnContext interface {
7071

7172
type HandshakeHandler interface {
7273
HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error
74+
HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool // return true means retry connect
7375
GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error)
7476
OnHandshake(ctx ConnContext, to string, err error)
7577
OnConnClose(ctx ConnContext) error
@@ -94,6 +96,10 @@ func (handler *DefaultHandshakeHandler) HandleHandshakeResp(ConnContext, *pnet.H
9496
return nil
9597
}
9698

99+
func (handler *DefaultHandshakeHandler) HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool {
100+
return false
101+
}
102+
97103
func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
98104
ns, ok := handler.nsManager.GetNamespaceByUser(resp.User)
99105
if !ok {
@@ -142,6 +148,7 @@ type CustomHandshakeHandler struct {
142148
onTraffic func(ConnContext)
143149
onConnClose func(ConnContext) error
144150
handleHandshakeResp func(ctx ConnContext, resp *pnet.HandshakeResp) error
151+
handleHandshakeErr func(ctx ConnContext, err *gomysql.MyError) bool
145152
getCapability func() pnet.Capability
146153
getServerVersion func() string
147154
}
@@ -179,6 +186,13 @@ func (h *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet
179186
return nil
180187
}
181188

189+
func (h *CustomHandshakeHandler) HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool {
190+
if h.handleHandshakeErr != nil {
191+
return h.handleHandshakeErr(ctx, err)
192+
}
193+
return false
194+
}
195+
182196
func (h *CustomHandshakeHandler) GetCapability() pnet.Capability {
183197
if h.getCapability != nil {
184198
return h.getCapability()

0 commit comments

Comments
 (0)