Skip to content

Commit d8f51d4

Browse files
committed
Add interceptor tests
1 parent 1961dc7 commit d8f51d4

File tree

6 files changed

+382
-22
lines changed

6 files changed

+382
-22
lines changed

internal/test/mock_stream.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type RTPWithError struct {
4141
// RTCPWithError is used to send a batch of rtcp packets or an error on a channel
4242
type RTCPWithError struct {
4343
Packets []rtcp.Packet
44+
Attr interceptor.Attributes
4445
Err error
4546
}
4647

@@ -107,21 +108,21 @@ func NewMockStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Moc
107108
go func() {
108109
buf := make([]byte, 1500)
109110
for {
110-
i, _, err := s.rtcpReader.Read(buf, interceptor.Attributes{})
111+
i, attr, err := s.rtcpReader.Read(buf, interceptor.Attributes{})
111112
if err != nil {
112113
if !errors.Is(err, io.EOF) {
113-
s.rtcpInModified <- RTCPWithError{Err: err}
114+
s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err}
114115
}
115116
return
116117
}
117118

118119
pkts, err := rtcp.Unmarshal(buf[:i])
119120
if err != nil {
120-
s.rtcpInModified <- RTCPWithError{Err: err}
121+
s.rtcpInModified <- RTCPWithError{Attr: attr, Err: err}
121122
return
122123
}
123124

124-
s.rtcpInModified <- RTCPWithError{Packets: pkts}
125+
s.rtcpInModified <- RTCPWithError{Attr: attr, Packets: pkts}
125126
}
126127
}()
127128
go func() {

pkg/ccfb/history.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ type sentPacket struct {
3131
departure time.Time
3232
}
3333

34-
type history struct {
34+
type historyList struct {
3535
lock sync.Mutex
3636
size int
3737
evictList *list.List
@@ -40,8 +40,8 @@ type history struct {
4040
ackedSeqNr *sequencenumber.Unwrapper
4141
}
4242

43-
func newHistory(size int) *history {
44-
return &history{
43+
func newHistoryList(size int) *historyList {
44+
return &historyList{
4545
lock: sync.Mutex{},
4646
size: size,
4747
evictList: list.New(),
@@ -51,7 +51,7 @@ func newHistory(size int) *history {
5151
}
5252
}
5353

54-
func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
54+
func (h *historyList) add(seqNr uint16, size uint16, departure time.Time) error {
5555
h.lock.Lock()
5656
defer h.lock.Unlock()
5757

@@ -76,7 +76,7 @@ func (h *history) add(seqNr uint16, size uint16, departure time.Time) error {
7676
}
7777

7878
// Must be called while holding the lock
79-
func (h *history) removeOldest() {
79+
func (h *historyList) removeOldest() {
8080
if ent := h.evictList.Front(); ent != nil {
8181
v := h.evictList.Remove(ent)
8282
if sp, ok := v.(sentPacket); ok {
@@ -85,7 +85,7 @@ func (h *history) removeOldest() {
8585
}
8686
}
8787

88-
func (h *history) getReportForAck(al acknowledgementList) PacketReportList {
88+
func (h *historyList) getReportForAck(al acknowledgementList) PacketReportList {
8989
h.lock.Lock()
9090
defer h.lock.Unlock()
9191

pkg/ccfb/history_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010

1111
func TestHistory(t *testing.T) {
1212
t.Run("errorOnDecreasingSeqNr", func(t *testing.T) {
13-
h := newHistory(200)
13+
h := newHistoryList(200)
1414
assert.NoError(t, h.add(10, 1200, time.Now()))
1515
assert.NoError(t, h.add(11, 1200, time.Now()))
1616
assert.Error(t, h.add(9, 1200, time.Now()))
@@ -84,7 +84,7 @@ func TestHistory(t *testing.T) {
8484
}
8585
for i, tc := range cases {
8686
t.Run(fmt.Sprintf("%v", i), func(t *testing.T) {
87-
h := newHistory(200)
87+
h := newHistoryList(200)
8888
for _, op := range tc.outgoing {
8989
assert.NoError(t, h.add(op.seqNr, op.size, op.ts))
9090
}
@@ -97,7 +97,7 @@ func TestHistory(t *testing.T) {
9797
})
9898

9999
t.Run("garbageCollection", func(t *testing.T) {
100-
h := newHistory(200)
100+
h := newHistoryList(200)
101101

102102
for i := uint16(0); i < 300; i++ {
103103
assert.NoError(t, h.add(i, 1200, time.Time{}.Add(time.Duration(i)*time.Millisecond)))

pkg/ccfb/interceptor.go

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,48 @@ type ccfbAttributesKeyType uint32
1515

1616
const CCFBAttributesKey ccfbAttributesKeyType = iota
1717

18+
type history interface {
19+
add(seqNr uint16, size uint16, departure time.Time) error
20+
getReportForAck(al acknowledgementList) PacketReportList
21+
}
22+
1823
type Option func(*Interceptor) error
1924

25+
func HistorySize(size int) Option {
26+
return func(i *Interceptor) error {
27+
i.historySize = size
28+
return nil
29+
}
30+
}
31+
32+
func timeFactory(f func() time.Time) Option {
33+
return func(i *Interceptor) error {
34+
i.timestamp = f
35+
return nil
36+
}
37+
}
38+
39+
func historyFactory(f func(int) history) Option {
40+
return func(i *Interceptor) error {
41+
i.historyFactory = f
42+
return nil
43+
}
44+
}
45+
46+
func ccfbConverterFactory(f func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList)) Option {
47+
return func(i *Interceptor) error {
48+
i.convertCCFB = f
49+
return nil
50+
}
51+
}
52+
53+
func twccConverterFactory(f func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList)) Option {
54+
return func(i *Interceptor) error {
55+
i.convertTWCC = f
56+
return nil
57+
}
58+
}
59+
2060
type InterceptorFactory struct {
2161
opts []Option
2262
}
@@ -30,8 +70,15 @@ func NewInterceptor(opts ...Option) (*InterceptorFactory, error) {
3070
func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
3171
i := &Interceptor{
3272
NoOp: interceptor.NoOp{},
73+
lock: sync.Mutex{},
3374
timestamp: time.Now,
34-
ssrcToHistory: make(map[uint32]*history),
75+
convertCCFB: convertCCFB,
76+
convertTWCC: convertTWCC,
77+
ssrcToHistory: make(map[uint32]history),
78+
historySize: 200,
79+
historyFactory: func(size int) history {
80+
return newHistoryList(size)
81+
},
3582
}
3683
for _, opt := range f.opts {
3784
if err := opt(i); err != nil {
@@ -43,9 +90,13 @@ func (f *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor,
4390

4491
type Interceptor struct {
4592
interceptor.NoOp
46-
lock sync.Mutex
47-
timestamp func() time.Time
48-
ssrcToHistory map[uint32]*history
93+
lock sync.Mutex
94+
timestamp func() time.Time
95+
convertCCFB func(ts time.Time, feedback *rtcp.CCFeedbackReport) (time.Time, map[uint32]acknowledgementList)
96+
convertTWCC func(ts time.Time, feedback *rtcp.TransportLayerCC) (time.Time, map[uint32]acknowledgementList)
97+
ssrcToHistory map[uint32]history
98+
historySize int
99+
historyFactory func(int) history
49100
}
50101

51102
// BindLocalStream implements interceptor.Interceptor.
@@ -67,7 +118,7 @@ func (i *Interceptor) BindLocalStream(info *interceptor.StreamInfo, writer inter
67118
if useTWCC {
68119
ssrc = 0
69120
}
70-
i.ssrcToHistory[ssrc] = newHistory(200)
121+
i.ssrcToHistory[ssrc] = i.historyFactory(i.historySize)
71122

72123
return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
73124
i.lock.Lock()
@@ -109,14 +160,18 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.
109160
pktReportLists := map[uint32]*PacketReportList{}
110161

111162
pkts, err := attr.GetRTCPPackets(buf)
163+
if err != nil {
164+
return n, attr, err
165+
}
112166
for _, pkt := range pkts {
113167
var reportLists map[uint32]acknowledgementList
114168
var reportDeparture time.Time
115169
switch fb := pkt.(type) {
116170
case *rtcp.CCFeedbackReport:
117-
reportDeparture, reportLists = convertCCFB(now, fb)
171+
reportDeparture, reportLists = i.convertCCFB(now, fb)
118172
case *rtcp.TransportLayerCC:
119-
reportDeparture, reportLists = convertTWCC(now, fb)
173+
reportDeparture, reportLists = i.convertTWCC(now, fb)
174+
default:
120175
}
121176
for ssrc, reportList := range reportLists {
122177
prl := i.ssrcToHistory[ssrc].getReportForAck(reportList)

0 commit comments

Comments
 (0)