2525package net
2626
2727import (
28- "bufio"
2928 "crypto/tls"
3029 "io"
3130 "net"
@@ -37,6 +36,7 @@ import (
3736 "github.com/pingcap/tiproxy/lib/util/errors"
3837 "github.com/pingcap/tiproxy/pkg/proxy/keepalive"
3938 "github.com/pingcap/tiproxy/pkg/proxy/proxyprotocol"
39+ "github.com/pingcap/tiproxy/pkg/util/bufio"
4040 "go.uber.org/zap"
4141)
4242
@@ -64,6 +64,7 @@ type packetReadWriter interface {
6464 Discard (n int ) (int , error )
6565 Flush () error
6666 DirectWrite (p []byte ) (int , error )
67+ ReadFrom (r io.Reader ) (int64 , error )
6768 Proxy () * proxyprotocol.Proxy
6869 TLSConnectionState () tls.ConnectionState
6970 InBytes () uint64
@@ -110,6 +111,12 @@ func (brw *basicReadWriter) Write(p []byte) (int, error) {
110111 return n , errors .WithStack (err )
111112}
112113
114+ func (brw * basicReadWriter ) ReadFrom (r io.Reader ) (int64 , error ) {
115+ n , err := brw .ReadWriter .ReadFrom (r )
116+ brw .outBytes += uint64 (n )
117+ return n , errors .WithStack (err )
118+ }
119+
113120func (brw * basicReadWriter ) DirectWrite (p []byte ) (int , error ) {
114121 n , err := brw .Conn .Write (p )
115122 brw .outBytes += uint64 (n )
@@ -187,7 +194,8 @@ type PacketIO struct {
187194 lastKeepAlive config.KeepAlive
188195 rawConn net.Conn
189196 readWriter packetReadWriter
190- header []byte
197+ limitReader io.LimitedReader // reuse memory to reduce allocation
198+ header []byte // reuse memory to reduce allocation
191199 logger * zap.Logger
192200 remoteAddr net.Addr
193201 wrap error
@@ -317,6 +325,42 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) {
317325 return nil
318326}
319327
328+ func (p * PacketIO ) ForwardUntil (dest * PacketIO , isEnd func (firstByte byte , firstPktLen int ) bool , process func (response []byte ) error ) error {
329+ p .readWriter .BeginRW (rwRead )
330+ dest .readWriter .BeginRW (rwWrite )
331+ p .limitReader .R = p .readWriter
332+ for {
333+ header , err := p .readWriter .Peek (5 )
334+ if err != nil {
335+ return errors .Wrap (ErrReadConn , err )
336+ }
337+ length := int (header [0 ]) | int (header [1 ])<< 8 | int (header [2 ])<< 16
338+ if isEnd (header [4 ], length ) {
339+ // TODO: allocate a buffer from pool and return the buffer after `process`.
340+ data , err := p .ReadPacket ()
341+ if err != nil {
342+ return errors .Wrap (ErrReadConn , err )
343+ }
344+ if err := dest .WritePacket (data , false ); err != nil {
345+ return errors .Wrap (ErrWriteConn , err )
346+ }
347+ return process (data )
348+ } else {
349+ sequence , pktSequence := header [3 ], p .readWriter .Sequence ()
350+ if sequence != pktSequence {
351+ return ErrInvalidSequence .GenWithStack ("invalid sequence, expected %d, actual %d" , pktSequence , sequence )
352+ }
353+ p .readWriter .SetSequence (sequence + 1 )
354+ // Sequence may be different (e.g. with compression) so we can't just copy the data to the destination.
355+ dest .readWriter .SetSequence (dest .readWriter .Sequence () + 1 )
356+ p .limitReader .N = int64 (length + 4 )
357+ if _ , err := dest .readWriter .ReadFrom (& p .limitReader ); err != nil {
358+ return errors .Wrap (ErrRelayConn , err )
359+ }
360+ }
361+ }
362+ }
363+
320364func (p * PacketIO ) InBytes () uint64 {
321365 return p .readWriter .InBytes ()
322366}
0 commit comments