@@ -12,6 +12,7 @@ import (
1212 "time"
1313
1414 "github.com/go-mysql-org/go-mysql/mysql"
15+ gomysql "github.com/go-mysql-org/go-mysql/mysql"
1516 "github.com/pingcap/tidb/util/hack"
1617 "github.com/pingcap/tiproxy/lib/util/errors"
1718 pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
@@ -159,6 +160,8 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
159160 auth .attrs = clientResp .Attrs
160161 auth .zstdLevel = clientResp .ZstdLevel
161162
163+ RECONNECT:
164+
162165 // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
163166 backendIO , err := getBackendIO (cctx , auth , clientResp , 15 * time .Second )
164167 if err != nil {
@@ -214,7 +217,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
214217 pktIdx := 0
215218loop:
216219 for {
217- serverPkt , err := forwardMsg ( backendIO , clientIO )
220+ serverPkt , err := backendIO . ReadPacket ( )
218221 if err != nil {
219222 // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence
220223 // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence
@@ -223,6 +226,23 @@ loop:
223226 }
224227 return err
225228 }
229+ var packetErr error
230+ if serverPkt [0 ] == pnet .ErrHeader .Byte () {
231+ packetErr = pnet .ParseErrorPacket (serverPkt )
232+ if handshakeHandler .HandleHandshakeErr (cctx , packetErr .(* gomysql.MyError )) {
233+ logger .Warn ("handle handshake error, start reconnect" , zap .Error (err ))
234+ backendIO .Close ()
235+ goto RECONNECT
236+ }
237+ }
238+ err = clientIO .WritePacket (serverPkt , true )
239+ if err != nil {
240+ return err
241+ }
242+ if packetErr != nil {
243+ return packetErr
244+ }
245+
226246 pktIdx ++
227247 switch serverPkt [0 ] {
228248 case pnet .OKHeader .Byte ():
@@ -233,8 +253,6 @@ loop:
233253 return err
234254 }
235255 return nil
236- case pnet .ErrHeader .Byte ():
237- return pnet .ParseErrorPacket (serverPkt )
238256 default : // mysql.AuthSwitchRequest, ShaCommand
239257 if serverPkt [0 ] == pnet .AuthSwitchHeader .Byte () {
240258 pluginName = string (serverPkt [1 : bytes .IndexByte (serverPkt [1 :], 0 )+ 1 ])
0 commit comments