Skip to content

Commit bff3314

Browse files
committed
fix broken connections in case of high latency
When listening and accepting an incoming connection request, the response might be received by the peer with some delay due to latency. This causes the peer to send a second connection request, that is not detected as duplicate because the first connection request has already been removed from the map that is used to check for duplicates (connReqs), so it is treated as a brand new connection request, breaking the first connection. This patch fixes the issue by replacing the map that is used to check for duplicates (connReqs) with a map that keeps track of both connection requests and accepted connections by their peer ID (connByPeers), allowing to discard handshakes with an existing peer ID.
1 parent d421390 commit bff3314

File tree

5 files changed

+156
-30
lines changed

5 files changed

+156
-30
lines changed

conn_request.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -264,24 +264,24 @@ func newConnRequest(ln *listener, p packet.Packet) *connRequest {
264264
}
265265

266266
ln.lock.Lock()
267-
_, exists := ln.connReqs[cif.SRTSocketId]
268-
if !exists {
269-
ln.connReqs[cif.SRTSocketId] = req
270-
}
271-
ln.lock.Unlock()
272267

273-
// We received a duplicate request: reject silently
268+
// Duplicate request: reject silently
269+
_, exists := ln.connsByPeer[cif.SRTSocketId]
274270
if exists {
271+
ln.lock.Unlock()
275272
return nil
276273
}
277274

275+
// Already fill connsByPeer for this peer socket ID
276+
ln.connsByPeer[cif.SRTSocketId] = nil
277+
278278
// Already reserve a socketId for this connection
279-
ln.lock.Lock()
280279
socketId, err := req.generateSocketId()
281280
if err == nil {
282281
ln.conns[socketId] = nil
283282
req.socketId = socketId
284283
}
284+
285285
ln.lock.Unlock()
286286

287287
// We couldn't create a socketId: reject silently
@@ -350,7 +350,7 @@ func (req *connRequest) Reject(reason RejectionReason) {
350350
req.ln.lock.Lock()
351351
defer req.ln.lock.Unlock()
352352

353-
if _, hasReq := req.ln.connReqs[req.peerSocketId]; !hasReq {
353+
if cr, hasReq := req.ln.connsByPeer[req.peerSocketId]; !hasReq || cr != nil {
354354
return
355355
}
356356

@@ -367,8 +367,8 @@ func (req *connRequest) Reject(reason RejectionReason) {
367367
req.ln.log("handshake:send:cif", func() string { return req.handshake.String() })
368368
req.ln.send(p)
369369

370-
delete(req.ln.connReqs, req.peerSocketId)
371370
delete(req.ln.conns, req.socketId)
371+
delete(req.ln.connsByPeer, req.peerSocketId)
372372
}
373373

374374
// generateSocketId generates an SRT SocketID that can be used for this connection
@@ -397,7 +397,7 @@ func (req *connRequest) Accept() (Conn, error) {
397397
req.ln.lock.Lock()
398398
defer req.ln.lock.Unlock()
399399

400-
if _, hasReq := req.ln.connReqs[req.peerSocketId]; !hasReq {
400+
if cr, hasReq := req.ln.connsByPeer[req.peerSocketId]; !hasReq || cr != nil {
401401
return nil, fmt.Errorf("connection already accepted")
402402
}
403403

@@ -472,7 +472,7 @@ func (req *connRequest) Accept() (Conn, error) {
472472
req.ln.send(p)
473473

474474
req.ln.conns[req.socketId] = conn
475-
delete(req.ln.connReqs, req.peerSocketId)
475+
req.ln.connsByPeer[req.peerSocketId] = conn
476476

477477
return conn, nil
478478
}

connection.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ type srtConn struct {
191191
readBuffer bytes.Buffer
192192

193193
onSend func(p packet.Packet)
194-
onShutdown func(socketId uint32)
194+
onShutdown func(*srtConn)
195195

196196
tick time.Duration
197197

@@ -234,7 +234,7 @@ type srtConnConfig struct {
234234
crypto crypto.Crypto
235235
keyBaseEncryption packet.PacketEncryption
236236
onSend func(p packet.Packet)
237-
onShutdown func(socketId uint32)
237+
onShutdown func(*srtConn)
238238
logger Logger
239239
}
240240

@@ -264,7 +264,7 @@ func newSRTConn(config srtConnConfig) *srtConn {
264264
}
265265

266266
if c.onShutdown == nil {
267-
c.onShutdown = func(socketId uint32) {}
267+
c.onShutdown = func(*srtConn) {}
268268
}
269269

270270
c.nextACKNumber = circular.New(1, packet.MAX_TIMESTAMP)
@@ -1415,7 +1415,7 @@ func (c *srtConn) close() {
14151415
c.log("connection:close", func() string { return "shutdown" })
14161416

14171417
go func() {
1418-
c.onShutdown(c.socketId)
1418+
c.onShutdown(c)
14191419
}()
14201420
})
14211421
}

dial.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ func (dl *dialer) handleHandshake(p packet.Packet) {
513513
crypto: dl.crypto,
514514
keyBaseEncryption: packet.EvenKeyEncrypted,
515515
onSend: dl.send,
516-
onShutdown: func(socketId uint32) { dl.Close() },
516+
onShutdown: func(*srtConn) { dl.Close() },
517517
logger: dl.config.Logger,
518518
})
519519

listen.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,10 @@ type listener struct {
121121

122122
config Config
123123

124-
backlog chan packet.Packet
125-
connReqs map[uint32]*connRequest
126-
conns map[uint32]*srtConn
127-
lock sync.RWMutex
124+
backlog chan packet.Packet
125+
conns map[uint32]*srtConn
126+
connsByPeer map[uint32]*srtConn
127+
lock sync.RWMutex
128128

129129
start time.Time
130130

@@ -190,8 +190,8 @@ func Listen(network, address string, config Config) (Listener, error) {
190190
return nil, fmt.Errorf("listen: no local address")
191191
}
192192

193-
ln.connReqs = make(map[uint32]*connRequest)
194193
ln.conns = make(map[uint32]*srtConn)
194+
ln.connsByPeer = make(map[uint32]*srtConn)
195195

196196
ln.backlog = make(chan packet.Packet, 128)
197197

@@ -326,9 +326,10 @@ func (ln *listener) error() error {
326326
return ln.doneErr
327327
}
328328

329-
func (ln *listener) handleShutdown(socketId uint32) {
329+
func (ln *listener) handleShutdown(c *srtConn) {
330330
ln.lock.Lock()
331-
delete(ln.conns, socketId)
331+
delete(ln.conns, c.socketId)
332+
delete(ln.connsByPeer, c.peerSocketId)
332333
ln.lock.Unlock()
333334
}
334335

listen_test.go

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -597,27 +597,25 @@ func TestListenDiscardRepeatedHandshakes(t *testing.T) {
597597
ln, err := Listen("srt", "127.0.0.1:6003", DefaultConfig())
598598
require.NoError(t, err)
599599

600+
singleReqReceived := make(chan struct{})
601+
600602
listenDone := make(chan struct{})
601603
defer func() { <-listenDone }()
602604

603-
singleReqReceived := make(chan struct{})
605+
defer ln.Close()
604606

605607
go func() {
606608
defer close(listenDone)
607609

608-
var onlyRequest ConnRequest
609-
610610
for {
611611
req, err := ln.Accept2()
612612
if err != nil {
613613
break
614614
}
615615

616616
close(singleReqReceived)
617-
onlyRequest = req
617+
defer req.Reject(REJ_CLOSE)
618618
}
619-
620-
onlyRequest.Reject(REJ_CLOSE)
621619
}()
622620

623621
for i := 0; i < 4; i++ {
@@ -701,5 +699,132 @@ func TestListenDiscardRepeatedHandshakes(t *testing.T) {
701699
}
702700

703701
<-singleReqReceived
704-
ln.Close()
702+
}
703+
704+
func TestListenAcceptAndDiscardRepeatedHandshakes(t *testing.T) {
705+
ln, err := Listen("srt", "127.0.0.1:6003", DefaultConfig())
706+
require.NoError(t, err)
707+
708+
singleReqAccepted := make(chan struct{})
709+
710+
listenDone := make(chan struct{})
711+
defer func() { <-listenDone }()
712+
713+
defer ln.Close()
714+
715+
go func() {
716+
defer close(listenDone)
717+
718+
for {
719+
req, err := ln.Accept2()
720+
if err != nil {
721+
break
722+
}
723+
724+
conn, err := req.Accept()
725+
require.NoError(t, err)
726+
defer conn.Close()
727+
728+
close(singleReqAccepted)
729+
}
730+
}()
731+
732+
// Client sends initial connection request
733+
conn, err := net.Dial("udp", "127.0.0.1:6003")
734+
require.NoError(t, err)
735+
defer conn.Close()
736+
737+
// Send induction request
738+
p := packet.NewPacket(conn.RemoteAddr())
739+
p.Header().IsControlPacket = true
740+
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
741+
p.Header().SubType = 0
742+
p.Header().TypeSpecific = 0
743+
p.Header().Timestamp = 0
744+
p.Header().DestinationSocketId = 0
745+
sendcif := &packet.CIFHandshake{
746+
IsRequest: true,
747+
Version: 4,
748+
EncryptionField: 0,
749+
ExtensionField: 2,
750+
InitialPacketSequenceNumber: circular.New(10000, packet.MAX_SEQUENCENUMBER),
751+
MaxTransmissionUnitSize: MAX_MSS_SIZE,
752+
MaxFlowWindowSize: 25600,
753+
HandshakeType: packet.HSTYPE_INDUCTION,
754+
SRTSocketId: 55555,
755+
SynCookie: 0,
756+
}
757+
sendcif.PeerIP.FromNetAddr(conn.LocalAddr())
758+
p.MarshalCIF(sendcif)
759+
var buf bytes.Buffer
760+
err = p.Marshal(&buf)
761+
require.NoError(t, err)
762+
_, err = conn.Write(buf.Bytes())
763+
require.NoError(t, err)
764+
765+
// Read induction response
766+
inbuf := make([]byte, MAX_MSS_SIZE)
767+
n, err := conn.Read(inbuf)
768+
require.NoError(t, err)
769+
p, err = packet.NewPacketFromData(conn.RemoteAddr(), inbuf[:n])
770+
require.NoError(t, err)
771+
recvcif := &packet.CIFHandshake{}
772+
err = p.UnmarshalCIF(recvcif)
773+
require.NoError(t, err)
774+
775+
// Send conclusion request
776+
p.Header().IsControlPacket = true
777+
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
778+
p.Header().SubType = 0
779+
p.Header().TypeSpecific = 0
780+
p.Header().Timestamp = 0
781+
p.Header().DestinationSocketId = 0
782+
sendcif.Version = 5
783+
sendcif.ExtensionField = recvcif.ExtensionField
784+
sendcif.HandshakeType = packet.HSTYPE_CONCLUSION
785+
sendcif.SynCookie = recvcif.SynCookie
786+
sendcif.HasHS = true
787+
sendcif.SRTHS = &packet.CIFHandshakeExtension{
788+
SRTVersion: SRT_VERSION,
789+
SRTFlags: packet.CIFHandshakeExtensionFlags{
790+
TSBPDSND: true,
791+
TSBPDRCV: true,
792+
CRYPT: true,
793+
TLPKTDROP: true,
794+
PERIODICNAK: true,
795+
REXMITFLG: true,
796+
STREAM: false,
797+
PACKET_FILTER: false,
798+
},
799+
RecvTSBPDDelay: uint16(120),
800+
SendTSBPDDelay: uint16(120),
801+
}
802+
sendcif.HasSID = true
803+
sendcif.StreamId = "foobar"
804+
p.MarshalCIF(sendcif)
805+
buf.Reset()
806+
err = p.Marshal(&buf)
807+
require.NoError(t, err)
808+
_, err = conn.Write(buf.Bytes())
809+
require.NoError(t, err)
810+
811+
// read conclusion response
812+
n, err = conn.Read(inbuf)
813+
require.NoError(t, err)
814+
p, err = packet.NewPacketFromData(conn.RemoteAddr(), inbuf[:n])
815+
require.NoError(t, err)
816+
recvcif = &packet.CIFHandshake{}
817+
err = p.UnmarshalCIF(recvcif)
818+
require.NoError(t, err)
819+
require.Equal(t, packet.HSTYPE_CONCLUSION, recvcif.HandshakeType)
820+
require.False(t, recvcif.IsRequest)
821+
822+
<-singleReqAccepted
823+
824+
// send conclusion request, again
825+
_, err = conn.Write(buf.Bytes())
826+
require.NoError(t, err)
827+
828+
// wait some time to make sure that close(singleReqAccepted) is not triggered
829+
time.Sleep(500 * time.Millisecond)
705830
}

0 commit comments

Comments
 (0)