Skip to content

Commit 8dddb51

Browse files
committed
Implement pacing interceptor
1 parent 40d68d9 commit 8dddb51

File tree

5 files changed

+432
-1
lines changed

5 files changed

+432
-1
lines changed

go.mod

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
module github.com/pion/interceptor
22

3-
go 1.21
3+
go 1.21.0
44

55
require (
66
github.com/pion/logging v0.2.4
77
github.com/pion/rtcp v1.2.16
88
github.com/pion/rtp v1.8.25
99
github.com/pion/transport/v3 v3.1.1
1010
github.com/stretchr/testify v1.11.1
11+
golang.org/x/time v0.10.0
1112
)
1213

1314
require (

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
1616
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
1717
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
1818
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
19+
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
20+
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
1921
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
2022
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
2123
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

pkg/pacing/interceptor.go

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
// SPDX-FileCopyrightText: 2025 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
// Package pacing implements a pacing interceptor.
5+
package pacing
6+
7+
import (
8+
"errors"
9+
"log/slog"
10+
"maps"
11+
"sync"
12+
"time"
13+
14+
"github.com/pion/interceptor"
15+
"github.com/pion/logging"
16+
"github.com/pion/rtp"
17+
)
18+
19+
var (
20+
errPacerClosed = errors.New("pacer closed")
21+
errPacerOverflow = errors.New("pacer queue overflow")
22+
)
23+
24+
type pacerFactory func(initialRate, burst int) pacer
25+
26+
type pacer interface {
27+
SetRate(rate, burst int)
28+
Budget(time.Time) float64
29+
AllowN(time.Time, int) bool
30+
}
31+
32+
// Option is a configuration option for pacing interceptors.
33+
type Option func(*Interceptor) error
34+
35+
// InitialRate configures the initial pacing rate for interceptors created by
36+
// the interceptor factory.
37+
func InitialRate(rate int) Option {
38+
return func(i *Interceptor) error {
39+
i.initialRate = rate
40+
41+
return nil
42+
}
43+
}
44+
45+
// Interval configures the pacing interval for interceptors created by the
46+
// interceptor factory.
47+
func Interval(interval time.Duration) Option {
48+
return func(i *Interceptor) error {
49+
i.interval = interval
50+
51+
return nil
52+
}
53+
}
54+
55+
func setPacerFactory(f pacerFactory) Option {
56+
return func(i *Interceptor) error {
57+
i.pacerFactory = f
58+
59+
return nil
60+
}
61+
}
62+
63+
// InterceptorFactory is a factory for pacing interceptors. It also keeps a map
64+
// of interceptors created in the past by ID.
65+
type InterceptorFactory struct {
66+
lock sync.Mutex
67+
opts []Option
68+
interceptors map[string]*Interceptor
69+
}
70+
71+
// NewInterceptor returns a new InterceptorFactory.
72+
func NewInterceptor(opts ...Option) *InterceptorFactory {
73+
return &InterceptorFactory{
74+
lock: sync.Mutex{},
75+
opts: opts,
76+
interceptors: map[string]*Interceptor{},
77+
}
78+
}
79+
80+
// SetRate updates the pacing rate of the pacing interceptor with the given ID.
81+
func (f *InterceptorFactory) SetRate(id string, r int) {
82+
f.lock.Lock()
83+
defer f.lock.Unlock()
84+
85+
i, ok := f.interceptors[id]
86+
if !ok {
87+
return
88+
}
89+
i.setRate(r)
90+
}
91+
92+
func (f *InterceptorFactory) remove(id string) {
93+
f.lock.Lock()
94+
defer f.lock.Unlock()
95+
delete(f.interceptors, id)
96+
}
97+
98+
// NewInterceptor creates a new pacing interceptor.
99+
func (f *InterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) {
100+
f.lock.Lock()
101+
defer f.lock.Unlock()
102+
103+
interceptor := &Interceptor{
104+
NoOp: interceptor.NoOp{},
105+
log: logging.NewDefaultLoggerFactory().NewLogger("pacer_interceptor"),
106+
initialRate: 1_000_000,
107+
interval: 5 * time.Millisecond,
108+
queueSize: 1_000_000,
109+
pacerFactory: func(initialRate, burst int) pacer {
110+
return newRateLimitPacer(initialRate, burst)
111+
},
112+
limit: nil,
113+
queue: nil,
114+
closed: make(chan struct{}),
115+
wg: sync.WaitGroup{},
116+
id: id,
117+
onClose: f.remove,
118+
}
119+
for _, opt := range f.opts {
120+
if err := opt(interceptor); err != nil {
121+
return nil, err
122+
}
123+
}
124+
interceptor.limit = interceptor.pacerFactory(
125+
interceptor.initialRate,
126+
burst(interceptor.initialRate, interceptor.interval),
127+
)
128+
interceptor.queue = make(chan packet, interceptor.queueSize)
129+
130+
f.interceptors[id] = interceptor
131+
132+
interceptor.wg.Add(1)
133+
go func() {
134+
defer interceptor.wg.Done()
135+
interceptor.loop()
136+
}()
137+
138+
return interceptor, nil
139+
}
140+
141+
// Interceptor implements packet pacing using a token bucket filter and sends
142+
// packets at a fixed interval.
143+
type Interceptor struct {
144+
interceptor.NoOp
145+
log logging.LeveledLogger
146+
147+
// config
148+
initialRate int
149+
interval time.Duration
150+
queueSize int
151+
pacerFactory pacerFactory
152+
153+
// limiter and queue
154+
limit pacer
155+
queue chan packet
156+
157+
// shutdown
158+
closed chan struct{}
159+
wg sync.WaitGroup
160+
id string
161+
onClose func(string)
162+
}
163+
164+
// burst calculates the minimal burst size required to reach the given rate and
165+
// pacing interval.
166+
func burst(rate int, interval time.Duration) int {
167+
if interval == 0 {
168+
interval = time.Millisecond
169+
}
170+
f := float64(time.Second.Milliseconds() / interval.Milliseconds())
171+
172+
return 8 * int(float64(rate)/f)
173+
}
174+
175+
// setRate updates the pacing rate and burst of the rate limiter.
176+
func (i *Interceptor) setRate(r int) {
177+
i.limit.SetRate(r, burst(r, i.interval))
178+
}
179+
180+
// BindLocalStream implements interceptor.Interceptor.
181+
func (i *Interceptor) BindLocalStream(
182+
info *interceptor.StreamInfo,
183+
writer interceptor.RTPWriter,
184+
) interceptor.RTPWriter {
185+
return interceptor.RTPWriterFunc(func(
186+
header *rtp.Header,
187+
payload []byte,
188+
attributes interceptor.Attributes,
189+
) (int, error) {
190+
hdr := header.Clone()
191+
pay := make([]byte, len(payload))
192+
copy(pay, payload)
193+
attr := maps.Clone(attributes)
194+
select {
195+
case i.queue <- packet{
196+
writer: writer,
197+
header: &hdr,
198+
payload: pay,
199+
attributes: attr,
200+
}:
201+
case <-i.closed:
202+
return 0, errPacerClosed
203+
default:
204+
return 0, errPacerOverflow
205+
}
206+
207+
return header.MarshalSize() + len(payload), nil
208+
})
209+
}
210+
211+
// Close implements interceptor.Interceptor.
212+
func (i *Interceptor) Close() error {
213+
defer i.wg.Wait()
214+
close(i.closed)
215+
i.onClose(i.id)
216+
217+
return nil
218+
}
219+
220+
func (i *Interceptor) loop() {
221+
ticker := time.NewTicker(i.interval)
222+
defer ticker.Stop()
223+
queue := make([]packet, 0)
224+
for {
225+
select {
226+
case now := <-ticker.C:
227+
for len(queue) > 0 && i.limit.Budget(now) > 8*float64(queue[0].len()) {
228+
i.limit.AllowN(now, 8*queue[0].len())
229+
var next packet
230+
next, queue = queue[0], queue[1:]
231+
if _, err := next.writer.Write(next.header, next.payload, next.attributes); err != nil {
232+
slog.Warn("error on writing RTP packet", "error", err)
233+
}
234+
}
235+
case pkt := <-i.queue:
236+
queue = append(queue, pkt)
237+
case <-i.closed:
238+
return
239+
}
240+
}
241+
}
242+
243+
type packet struct {
244+
writer interceptor.RTPWriter
245+
header *rtp.Header
246+
payload []byte
247+
attributes interceptor.Attributes
248+
}
249+
250+
func (p *packet) len() int {
251+
return p.header.MarshalSize() + len(p.payload)
252+
}

0 commit comments

Comments
 (0)