Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions dtlstransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/srtp/v3"
"github.com/pion/transport/v3"
"github.com/pion/webrtc/v4/internal/mux"
"github.com/pion/webrtc/v4/internal/util"
"github.com/pion/webrtc/v4/pkg/rtcerr"
Expand Down Expand Up @@ -231,13 +232,13 @@ func (t *DTLSTransport) startSRTP() error {
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
}

srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
srtpSession, err := srtp.NewSessionSRTPWithNewSocket(t.srtpEndpoint, srtpConfig)
if err != nil {
// nolint
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
}

srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
srtcpSession, err := srtp.NewSessionSRTCPWithNewSocket(t.srtcpEndpoint, srtpConfig)
if err != nil {
// nolint
return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
Expand Down Expand Up @@ -545,8 +546,15 @@ func (t *DTLSTransport) streamsForSSRC(
&streamInfo,
interceptor.RTPReaderFunc(
func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtpReadStream.Read(in)

attr := transport.NewPacketAttributesWithLen(10)
n, err = rtpReadStream.ReadWithAttributes(in, attr)

if a == nil {
a = interceptor.Attributes{}
if len(attr.GetReadPacketAttributes().Buffer) > 0 {
a["ECN"] = attr.GetReadPacketAttributes().Buffer[0]
}
}
return n, a, err
},
),
Expand Down
11 changes: 11 additions & 0 deletions internal/mux/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/pion/ice/v4"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/packetio"
)

Expand Down Expand Up @@ -53,6 +54,16 @@ func (e *Endpoint) ReadFrom(p []byte) (int, net.Addr, error) {
return i, nil, err
}

func (e *Endpoint) ReadWithAttributes(b []byte, attr *transport.PacketAttributes) (int, error) {
return e.buffer.ReadWithAttributes(b, attr)
}

func (e *Endpoint) ReadFromWithAttributes(b []byte, attr *transport.PacketAttributes) (int, net.Addr, error) {
n, err := e.ReadWithAttributes(b, attr)

return n, nil, err
}

// Write writes len(p) bytes to the underlying conn.
func (e *Endpoint) Write(p []byte) (int, error) {
n, err := e.mux.nextConn.Write(p)
Expand Down
48 changes: 30 additions & 18 deletions internal/mux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ package mux
import (
"errors"
"io"
"net"
"sync"

"github.com/pion/ice/v4"
"github.com/pion/logging"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/packetio"
)

Expand All @@ -26,20 +26,25 @@ const (
// Config collects the arguments to mux.Mux construction into
// a single structure.
type Config struct {
Conn net.Conn
Conn transport.NetConnSocket
BufferSize int
LoggerFactory logging.LoggerFactory
}

type pendingPacket struct {
packet []byte
attr *transport.PacketAttributes
}

// Mux allows multiplexing.
type Mux struct {
nextConn net.Conn
nextConn transport.NetConnSocket
bufferSize int
lock sync.Mutex
endpoints map[*Endpoint]MatchFunc
isClosed bool

pendingPackets [][]byte
pendingPackets []*pendingPacket

closedCh chan struct{}
log logging.LeveledLogger
Expand Down Expand Up @@ -118,8 +123,9 @@ func (m *Mux) readLoop() {
}()

buf := make([]byte, m.bufferSize)
attr := transport.NewPacketAttributesWithLen(transport.MaxAttributesLen)
for {
n, err := m.nextConn.Read(buf)
n, err := m.nextConn.ReadWithAttributes(buf, attr)
switch {
case errors.Is(err, io.EOF), errors.Is(err, ice.ErrClosed):
return
Expand All @@ -133,7 +139,7 @@ func (m *Mux) readLoop() {
return
}

if err = m.dispatch(buf[:n]); err != nil {
if err = m.dispatch(buf[:n], attr.GetReadPacketAttributes()); err != nil {
if errors.Is(err, io.ErrClosedPipe) {
// if the buffer was closed, that's not an error we care to report
return
Expand All @@ -145,8 +151,8 @@ func (m *Mux) readLoop() {
}
}

func (m *Mux) dispatch(buf []byte) error {
if len(buf) == 0 {
func (m *Mux) dispatch(b []byte, attr *transport.PacketAttributes) error {
if len(b) == 0 {
m.log.Warnf("Warning: mux: unable to dispatch zero length packet")

return nil
Expand All @@ -156,7 +162,7 @@ func (m *Mux) dispatch(buf []byte) error {

m.lock.Lock()
for e, f := range m.endpoints {
if f(buf) {
if f(b) {
endpoint = e

break
Expand All @@ -169,24 +175,30 @@ func (m *Mux) dispatch(buf []byte) error {
if len(m.pendingPackets) >= maxPendingPackets {
m.log.Warnf(
"Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)",
buf[0], //nolint:gosec // G602, false positive?
b[0], //nolint:gosec // G602, false positive?
len(m.pendingPackets),
)
} else {
m.log.Warnf(
"Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)",
buf[0], //nolint:gosec // G602, false positive?
b[0], //nolint:gosec // G602, false positive?
len(m.pendingPackets),
)
m.pendingPackets = append(m.pendingPackets, append([]byte{}, buf...))
// copy the packet bytes and clone the PacketAttributes
pp := &pendingPacket{
packet: append([]byte{}, b...),
attr: attr.Clone(),
}

m.pendingPackets = append(m.pendingPackets, pp)
}
}

return nil
}

m.lock.Unlock()
_, err := endpoint.buffer.Write(buf)
_, err := endpoint.buffer.WriteWithAttributes(b, attr)

// Expected when bytes are received faster than the endpoint can process them (#2152, #2180)
if errors.Is(err, packetio.ErrFull) {
Expand All @@ -202,14 +214,14 @@ func (m *Mux) handlePendingPackets(endpoint *Endpoint, matchFunc MatchFunc) {
m.lock.Lock()
defer m.lock.Unlock()

pendingPackets := make([][]byte, len(m.pendingPackets))
for _, buf := range m.pendingPackets {
if matchFunc(buf) {
if _, err := endpoint.buffer.Write(buf); err != nil {
pendingPackets := make([]*pendingPacket, len(m.pendingPackets))
for _, p := range m.pendingPackets {
if matchFunc(p.packet) {
if _, err := endpoint.buffer.WriteWithAttributes(p.packet, p.attr); err != nil {
m.log.Warnf("Warning: mux: error writing packet to endpoint from pending queue: %s", err)
}
} else {
pendingPackets = append(pendingPackets, buf) //nolint:makezero // todo fix
pendingPackets = append(pendingPackets, p) //nolint:makezero // todo fix
}
}
m.pendingPackets = pendingPackets
Expand Down
20 changes: 12 additions & 8 deletions internal/mux/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/pion/logging"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/packetio"
"github.com/pion/transport/v3/test"
"github.com/stretchr/testify/require"
Expand All @@ -22,12 +23,13 @@ func TestNoEndpoints(t *testing.T) {
ca, cb := net.Pipe()
require.NoError(t, cb.Close())

attr := transport.NewPacketAttributesWithLen(transport.MaxAttributesLen)
mux := NewMux(Config{
Conn: ca,
Conn: transport.NewNetConnToNetConnSocket(ca),
BufferSize: testPipeBufferSize,
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
require.NoError(t, mux.dispatch(make([]byte, 1)))
require.NoError(t, mux.dispatch(make([]byte, 1), attr))
require.NoError(t, mux.Close())
require.NoError(t, ca.Close())
}
Expand Down Expand Up @@ -83,7 +85,7 @@ func TestNonFatalRead(t *testing.T) {
}}

mux := NewMux(Config{
Conn: conn,
Conn: transport.NewNetConnToNetConnSocket(conn),
BufferSize: testPipeBufferSize,
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
Expand Down Expand Up @@ -112,7 +114,7 @@ func TestNonFatalDispatch(t *testing.T) {
in, out := net.Pipe()

mux := NewMux(Config{
Conn: out,
Conn: transport.NewNetConnToNetConnSocket(out),
LoggerFactory: logging.NewDefaultLoggerFactory(),
BufferSize: 1500,
})
Expand Down Expand Up @@ -146,7 +148,7 @@ func BenchmarkDispatch(b *testing.B) {
b.StartTimer()

for i := 0; i < b.N; i++ {
err := mux.dispatch(buf)
err := mux.dispatch(buf, transport.NewPacketAttributesWithLen(transport.MaxAttributesLen))
if err != nil {
b.Errorf("dispatch: %v", err)
}
Expand All @@ -165,15 +167,17 @@ func TestPendingQueue(t *testing.T) {
log: factory.NewLogger("mux"),
}

attr := transport.NewPacketAttributesWithLen(transport.MaxAttributesLen)

// Assert empty packets don't end up in queue
require.NoError(t, mux.dispatch([]byte{}))
require.NoError(t, mux.dispatch([]byte{}, attr))
require.Equal(t, len(mux.pendingPackets), 0)

// Test Happy Case
inBuffer := []byte{20, 1, 2, 3, 4}
outBuffer := make([]byte, len(inBuffer))

require.NoError(t, mux.dispatch(inBuffer))
require.NoError(t, mux.dispatch(inBuffer, attr))

endpoint := mux.NewEndpoint(MatchDTLS)
require.NotNil(t, endpoint)
Expand All @@ -185,7 +189,7 @@ func TestPendingQueue(t *testing.T) {

// Assert limit on pendingPackets
for i := 0; i <= 100; i++ {
require.NoError(t, mux.dispatch([]byte{64, 65, 66}))
require.NoError(t, mux.dispatch([]byte{64, 65, 66}, attr))
}
require.Equal(t, len(mux.pendingPackets), maxPendingPackets)
}
8 changes: 3 additions & 5 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"crypto/rand"
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -1681,7 +1680,7 @@ func (pc *PeerConnection) handleNonMediaBandwidthProbe() {
}
}

func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop
func (pc *PeerConnection) handleIncomingSSRC(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop
remoteDescription := pc.RemoteDescription()
if remoteDescription == nil {
return errPeerConnRemoteDescriptionNil
Expand Down Expand Up @@ -1719,8 +1718,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err

// We read the RTP packet to determine the payload type
b := make([]byte, pc.api.settingEngine.getReceiveMTU())

i, err := rtpStream.Read(b)
i, err := rtpStream.ReadWithAttributes(b, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -1908,7 +1906,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { //nolint:cyclop
continue
}

go func(rtpStream io.Reader, ssrc SSRC) {
go func(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) {
if err := pc.handleIncomingSSRC(rtpStream, ssrc); err != nil {
pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err)
}
Expand Down
2 changes: 1 addition & 1 deletion settingengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type SettingEngine struct {
disableSRTPReplayProtection bool
disableSRTCPReplayProtection bool
net transport.Net
BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) *packetio.Buffer
LoggerFactory logging.LoggerFactory
iceTCPMux ice.TCPMux
iceUDPMux ice.UDPMux
Expand Down