Skip to content

Commit ff4a028

Browse files
committed
Split RTCP compound packet before forwarding
When a compound RTCP packet is received, it is forwarded as is to the read streams. Downstream components have to check for the SSRC in the packets matching the track(s) it is handling. This leads to situations where downstream components could see duplicate RTCP reports. Happens when there is a downstream handler for all SSRCs. It receives the compound packet, unmarshals it and invokes the handlers. As it is fielding all SSRCs, it will get the same compound packet `n` times and invoke the handlers `n` times. This PR splits up the compound packet and forwards individual packets to avoid the end handlers from seeing duplicates. API wise, it is compatible as it still emits an encoded/marshaled packet. But, this does add a marshaling step. Add an unit test. Before this change, the test fails at the point where the test tries to read the CNAME packet and sees an incorrect type.
1 parent 1156dbf commit ff4a028

File tree

2 files changed

+152
-24
lines changed

2 files changed

+152
-24
lines changed

session_srtcp.go

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,11 @@ func (s *SessionSRTCP) setWriteDeadline(t time.Time) error {
138138
return s.session.nextConn.SetWriteDeadline(t)
139139
}
140140

141-
// create a list of Destination SSRCs
142-
// that's a superset of all Destinations in the slice.
143-
func destinationSSRC(pkts []rtcp.Packet) []uint32 {
141+
// create a list of Destination SSRCs for the packet.
142+
func destinationSSRC(pkt rtcp.Packet) []uint32 {
144143
ssrcSet := make(map[uint32]struct{})
145-
for _, p := range pkts {
146-
for _, ssrc := range p.DestinationSSRC() {
147-
ssrcSet[ssrc] = struct{}{}
148-
}
144+
for _, ssrc := range pkt.DestinationSSRC() {
145+
ssrcSet[ssrc] = struct{}{}
149146
}
150147

151148
out := make([]uint32, 0, len(ssrcSet))
@@ -156,36 +153,44 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 {
156153
return out
157154
}
158155

156+
//nolint:cyclop
159157
func (s *SessionSRTCP) decrypt(buf []byte) error {
160158
decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil)
161159
if err != nil {
162160
return err
163161
}
164162

165-
pkt, err := rtcp.Unmarshal(decrypted)
163+
pkts, err := rtcp.Unmarshal(decrypted)
166164
if err != nil {
167165
return err
168166
}
169167

170-
for _, ssrc := range destinationSSRC(pkt) {
171-
r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
172-
if r == nil {
173-
return nil // Session has been closed
174-
} else if isNew {
175-
if !s.session.acceptStreamTimeout.IsZero() {
176-
_ = s.session.nextConn.SetReadDeadline(time.Time{})
177-
}
178-
s.session.newStream <- r // Notify AcceptStream
168+
for _, pkt := range pkts {
169+
marshaled, err := pkt.Marshal()
170+
if err != nil {
171+
return err
179172
}
180173

181-
readStream, ok := r.(*ReadStreamSRTCP)
182-
if !ok {
183-
return errFailedTypeAssertion
184-
}
174+
for _, ssrc := range destinationSSRC(pkt) {
175+
r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
176+
if r == nil {
177+
return nil // Session has been closed
178+
} else if isNew {
179+
if !s.session.acceptStreamTimeout.IsZero() {
180+
_ = s.session.nextConn.SetReadDeadline(time.Time{})
181+
}
182+
s.session.newStream <- r // Notify AcceptStream
183+
}
185184

186-
_, err = readStream.write(decrypted)
187-
if err != nil {
188-
return err
185+
readStream, ok := r.(*ReadStreamSRTCP)
186+
if !ok {
187+
return errFailedTypeAssertion
188+
}
189+
190+
_, err = readStream.write(marshaled)
191+
if err != nil {
192+
return err
193+
}
189194
}
190195
}
191196

session_srtcp_test.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,129 @@ func TestSessionSRTCPReplayProtection(t *testing.T) {
245245
expectedSSRC, receivedSSRC)
246246
}
247247

248+
func TestSessionSRTCPCompoundPacket(t *testing.T) {
249+
lim := test.TimeOut(time.Second * 5)
250+
defer lim.Stop()
251+
252+
report := test.CheckRoutines(t)
253+
defer report()
254+
255+
testSSRCSenderReport1SR := uint32(0x902f9e2e)
256+
testSSRCSenderReport1RR := uint32(0xbc5e9a40)
257+
testSSRCCNAME := uint32(1234)
258+
testSSRCSenderReport2SR := uint32(0x12345678)
259+
aSession, bSession := buildSessionSRTCPPair(t)
260+
bReadStreamSR1SR, err := bSession.OpenReadStream(testSSRCSenderReport1SR)
261+
assert.NoError(t, err)
262+
bReadStreamSR1RR, err := bSession.OpenReadStream(testSSRCSenderReport1RR)
263+
assert.NoError(t, err)
264+
bReadStreamCNAME, err := bSession.OpenReadStream(testSSRCCNAME)
265+
assert.NoError(t, err)
266+
bReadStreamSR2SR, err := bSession.OpenReadStream(testSSRCSenderReport2SR)
267+
assert.NoError(t, err)
268+
269+
// Compound packet
270+
// first packet - Sender Report with a Receiver Report
271+
// seconde packet - Sender Report without a Receiver Report
272+
cp := &rtcp.CompoundPacket{
273+
&rtcp.SenderReport{
274+
SSRC: testSSRCSenderReport1SR,
275+
NTPTime: 0xda8bd1fcdddda05a,
276+
RTPTime: 0xaaf4edd5,
277+
PacketCount: 1,
278+
OctetCount: 2,
279+
Reports: []rtcp.ReceptionReport{{
280+
SSRC: testSSRCSenderReport1RR,
281+
FractionLost: 0,
282+
TotalLost: 0,
283+
LastSequenceNumber: 0x46e1,
284+
Jitter: 273,
285+
LastSenderReport: 0x9f36432,
286+
Delay: 150137,
287+
}},
288+
ProfileExtensions: []byte{
289+
0x81, 0xca, 0x0, 0x6,
290+
0x2b, 0x7e, 0xc0, 0xc5,
291+
0x1, 0x10, 0x4c, 0x63,
292+
0x49, 0x66, 0x7a, 0x58,
293+
0x6f, 0x6e, 0x44, 0x6f,
294+
0x72, 0x64, 0x53, 0x65,
295+
0x57, 0x36, 0x0, 0x0,
296+
},
297+
},
298+
rtcp.NewCNAMESourceDescription(testSSRCCNAME, "cname"), // to make it a valid compound packet
299+
&rtcp.SenderReport{
300+
SSRC: testSSRCSenderReport2SR,
301+
NTPTime: 0xda8bd1fcdddda05a,
302+
RTPTime: 0xaaf4edd5,
303+
PacketCount: 1,
304+
OctetCount: 2,
305+
},
306+
}
307+
308+
done := make(chan struct{})
309+
go func() {
310+
readBuffer := make([]byte, 200)
311+
312+
senderReport := &rtcp.SenderReport{}
313+
n, _, rerr := bReadStreamSR1SR.ReadRTCP(readBuffer)
314+
assert.NoError(t, rerr)
315+
rerr = senderReport.Unmarshal(readBuffer[:n])
316+
assert.NoError(t, rerr)
317+
assert.Equal(t, uint32(0x902f9e2e), senderReport.SSRC)
318+
assert.Len(t, senderReport.Reports, 1)
319+
assert.Equal(t, uint32(0xbc5e9a40), senderReport.Reports[0].SSRC)
320+
assert.Len(t, senderReport.DestinationSSRC(), 2)
321+
assert.ElementsMatch(t, []uint32{0x902f9e2e, 0xbc5e9a40}, senderReport.DestinationSSRC())
322+
323+
// should read via receiver report embedded in sender report
324+
senderReport = &rtcp.SenderReport{}
325+
n, _, rerr = bReadStreamSR1RR.ReadRTCP(readBuffer)
326+
assert.NoError(t, rerr)
327+
rerr = senderReport.Unmarshal(readBuffer[:n])
328+
assert.NoError(t, rerr)
329+
assert.Equal(t, uint32(0x902f9e2e), senderReport.SSRC)
330+
assert.Len(t, senderReport.Reports, 1)
331+
assert.Equal(t, uint32(0xbc5e9a40), senderReport.Reports[0].SSRC)
332+
assert.Len(t, senderReport.DestinationSSRC(), 2)
333+
assert.ElementsMatch(t, []uint32{0x902f9e2e, 0xbc5e9a40}, senderReport.DestinationSSRC())
334+
335+
cname := &rtcp.SourceDescription{}
336+
n, _, rerr = bReadStreamCNAME.ReadRTCP(readBuffer)
337+
assert.NoError(t, rerr)
338+
rerr = cname.Unmarshal(readBuffer[:n])
339+
assert.NoError(t, rerr)
340+
assert.Len(t, cname.DestinationSSRC(), 1)
341+
assert.Equal(t, uint32(1234), cname.DestinationSSRC()[0])
342+
343+
senderReport = &rtcp.SenderReport{}
344+
n, _, rerr = bReadStreamSR2SR.ReadRTCP(readBuffer)
345+
assert.NoError(t, rerr)
346+
rerr = senderReport.Unmarshal(readBuffer[:n])
347+
assert.NoError(t, rerr)
348+
assert.Equal(t, uint32(0x12345678), senderReport.SSRC)
349+
assert.Len(t, senderReport.Reports, 0)
350+
assert.Len(t, senderReport.DestinationSSRC(), 1)
351+
assert.ElementsMatch(t, []uint32{0x12345678}, senderReport.DestinationSSRC())
352+
353+
close(done)
354+
}()
355+
356+
encrypted, err := encryptSRTCP(aSession.session.localContext, cp)
357+
assert.NoError(t, err)
358+
_, err = aSession.session.nextConn.Write(encrypted)
359+
assert.NoError(t, err)
360+
361+
<-done
362+
363+
assert.NoError(t, aSession.Close())
364+
assert.NoError(t, bSession.Close())
365+
assert.NoError(t, bReadStreamSR1SR.Close())
366+
assert.NoError(t, bReadStreamSR1RR.Close())
367+
assert.NoError(t, bReadStreamCNAME.Close())
368+
assert.NoError(t, bReadStreamSR2SR.Close())
369+
}
370+
248371
// nolint: dupl
249372
func TestSessionSRTCPAcceptStreamTimeout(t *testing.T) {
250373
lim := test.TimeOut(time.Second * 5)

0 commit comments

Comments
 (0)