diff --git a/dtlstransport.go b/dtlstransport.go index e0b575a1188..fd38f95fe4e 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -25,6 +25,7 @@ import ( "github.com/pion/logging" "github.com/pion/rtcp" "github.com/pion/srtp/v3" + "github.com/pion/transport/v3" "github.com/pion/webrtc/v4/internal/mux" "github.com/pion/webrtc/v4/internal/util" "github.com/pion/webrtc/v4/pkg/rtcerr" @@ -231,13 +232,13 @@ func (t *DTLSTransport) startSRTP() error { return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) } - srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig) + srtpSession, err := srtp.NewSessionSRTPWithNewSocket(t.srtpEndpoint, srtpConfig) if err != nil { // nolint return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) } - srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig) + srtcpSession, err := srtp.NewSessionSRTCPWithNewSocket(t.srtcpEndpoint, srtpConfig) if err != nil { // nolint return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err) @@ -545,8 +546,15 @@ func (t *DTLSTransport) streamsForSSRC( &streamInfo, interceptor.RTPReaderFunc( func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { - n, err = rtpReadStream.Read(in) - + attr := transport.NewPacketAttributesWithLen(10) + n, err = rtpReadStream.ReadWithAttributes(in, attr) + + if a == nil { + a = interceptor.Attributes{} + if len(attr.GetReadPacketAttributes().Buffer) > 0 { + a["ECN"] = attr.GetReadPacketAttributes().Buffer[0] + } + } return n, a, err }, ), diff --git a/internal/mux/endpoint.go b/internal/mux/endpoint.go index d1a24c0b678..f594519e6b4 100644 --- a/internal/mux/endpoint.go +++ b/internal/mux/endpoint.go @@ -10,6 +10,7 @@ import ( "time" "github.com/pion/ice/v4" + "github.com/pion/transport/v3" "github.com/pion/transport/v3/packetio" ) @@ -53,6 +54,16 @@ func (e *Endpoint) ReadFrom(p []byte) (int, net.Addr, error) { return i, nil, err } +func (e *Endpoint) ReadWithAttributes(b []byte, attr *transport.PacketAttributes) (int, error) { + return e.buffer.ReadWithAttributes(b, attr) +} + +func (e *Endpoint) ReadFromWithAttributes(b []byte, attr *transport.PacketAttributes) (int, net.Addr, error) { + n, err := e.ReadWithAttributes(b, attr) + + return n, nil, err +} + // Write writes len(p) bytes to the underlying conn. func (e *Endpoint) Write(p []byte) (int, error) { n, err := e.mux.nextConn.Write(p) diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 942476897e2..45d85224904 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -7,11 +7,11 @@ package mux import ( "errors" "io" - "net" "sync" "github.com/pion/ice/v4" "github.com/pion/logging" + "github.com/pion/transport/v3" "github.com/pion/transport/v3/packetio" ) @@ -26,20 +26,25 @@ const ( // Config collects the arguments to mux.Mux construction into // a single structure. type Config struct { - Conn net.Conn + Conn transport.NetConnSocket BufferSize int LoggerFactory logging.LoggerFactory } +type pendingPacket struct { + packet []byte + attr *transport.PacketAttributes +} + // Mux allows multiplexing. type Mux struct { - nextConn net.Conn + nextConn transport.NetConnSocket bufferSize int lock sync.Mutex endpoints map[*Endpoint]MatchFunc isClosed bool - pendingPackets [][]byte + pendingPackets []*pendingPacket closedCh chan struct{} log logging.LeveledLogger @@ -118,8 +123,9 @@ func (m *Mux) readLoop() { }() buf := make([]byte, m.bufferSize) + attr := transport.NewPacketAttributesWithLen(transport.MaxAttributesLen) for { - n, err := m.nextConn.Read(buf) + n, err := m.nextConn.ReadWithAttributes(buf, attr) switch { case errors.Is(err, io.EOF), errors.Is(err, ice.ErrClosed): return @@ -133,7 +139,7 @@ func (m *Mux) readLoop() { return } - if err = m.dispatch(buf[:n]); err != nil { + if err = m.dispatch(buf[:n], attr.GetReadPacketAttributes()); err != nil { if errors.Is(err, io.ErrClosedPipe) { // if the buffer was closed, that's not an error we care to report return @@ -145,8 +151,8 @@ func (m *Mux) readLoop() { } } -func (m *Mux) dispatch(buf []byte) error { - if len(buf) == 0 { +func (m *Mux) dispatch(b []byte, attr *transport.PacketAttributes) error { + if len(b) == 0 { m.log.Warnf("Warning: mux: unable to dispatch zero length packet") return nil @@ -156,7 +162,7 @@ func (m *Mux) dispatch(buf []byte) error { m.lock.Lock() for e, f := range m.endpoints { - if f(buf) { + if f(b) { endpoint = e break @@ -169,16 +175,22 @@ func (m *Mux) dispatch(buf []byte) error { if len(m.pendingPackets) >= maxPendingPackets { m.log.Warnf( "Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)", - buf[0], //nolint:gosec // G602, false positive? + b[0], //nolint:gosec // G602, false positive? len(m.pendingPackets), ) } else { m.log.Warnf( "Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)", - buf[0], //nolint:gosec // G602, false positive? + b[0], //nolint:gosec // G602, false positive? len(m.pendingPackets), ) - m.pendingPackets = append(m.pendingPackets, append([]byte{}, buf...)) + // copy the packet bytes and clone the PacketAttributes + pp := &pendingPacket{ + packet: append([]byte{}, b...), + attr: attr.Clone(), + } + + m.pendingPackets = append(m.pendingPackets, pp) } } @@ -186,7 +198,7 @@ func (m *Mux) dispatch(buf []byte) error { } m.lock.Unlock() - _, err := endpoint.buffer.Write(buf) + _, err := endpoint.buffer.WriteWithAttributes(b, attr) // Expected when bytes are received faster than the endpoint can process them (#2152, #2180) if errors.Is(err, packetio.ErrFull) { @@ -202,14 +214,14 @@ func (m *Mux) handlePendingPackets(endpoint *Endpoint, matchFunc MatchFunc) { m.lock.Lock() defer m.lock.Unlock() - pendingPackets := make([][]byte, len(m.pendingPackets)) - for _, buf := range m.pendingPackets { - if matchFunc(buf) { - if _, err := endpoint.buffer.Write(buf); err != nil { + pendingPackets := make([]*pendingPacket, len(m.pendingPackets)) + for _, p := range m.pendingPackets { + if matchFunc(p.packet) { + if _, err := endpoint.buffer.WriteWithAttributes(p.packet, p.attr); err != nil { m.log.Warnf("Warning: mux: error writing packet to endpoint from pending queue: %s", err) } } else { - pendingPackets = append(pendingPackets, buf) //nolint:makezero // todo fix + pendingPackets = append(pendingPackets, p) //nolint:makezero // todo fix } } m.pendingPackets = pendingPackets diff --git a/internal/mux/mux_test.go b/internal/mux/mux_test.go index 4306184db93..7ee60dcc64a 100644 --- a/internal/mux/mux_test.go +++ b/internal/mux/mux_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/pion/logging" + "github.com/pion/transport/v3" "github.com/pion/transport/v3/packetio" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" @@ -22,12 +23,13 @@ func TestNoEndpoints(t *testing.T) { ca, cb := net.Pipe() require.NoError(t, cb.Close()) + attr := transport.NewPacketAttributesWithLen(transport.MaxAttributesLen) mux := NewMux(Config{ - Conn: ca, + Conn: transport.NewNetConnToNetConnSocket(ca), BufferSize: testPipeBufferSize, LoggerFactory: logging.NewDefaultLoggerFactory(), }) - require.NoError(t, mux.dispatch(make([]byte, 1))) + require.NoError(t, mux.dispatch(make([]byte, 1), attr)) require.NoError(t, mux.Close()) require.NoError(t, ca.Close()) } @@ -83,7 +85,7 @@ func TestNonFatalRead(t *testing.T) { }} mux := NewMux(Config{ - Conn: conn, + Conn: transport.NewNetConnToNetConnSocket(conn), BufferSize: testPipeBufferSize, LoggerFactory: logging.NewDefaultLoggerFactory(), }) @@ -112,7 +114,7 @@ func TestNonFatalDispatch(t *testing.T) { in, out := net.Pipe() mux := NewMux(Config{ - Conn: out, + Conn: transport.NewNetConnToNetConnSocket(out), LoggerFactory: logging.NewDefaultLoggerFactory(), BufferSize: 1500, }) @@ -146,7 +148,7 @@ func BenchmarkDispatch(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - err := mux.dispatch(buf) + err := mux.dispatch(buf, transport.NewPacketAttributesWithLen(transport.MaxAttributesLen)) if err != nil { b.Errorf("dispatch: %v", err) } @@ -165,15 +167,17 @@ func TestPendingQueue(t *testing.T) { log: factory.NewLogger("mux"), } + attr := transport.NewPacketAttributesWithLen(transport.MaxAttributesLen) + // Assert empty packets don't end up in queue - require.NoError(t, mux.dispatch([]byte{})) + require.NoError(t, mux.dispatch([]byte{}, attr)) require.Equal(t, len(mux.pendingPackets), 0) // Test Happy Case inBuffer := []byte{20, 1, 2, 3, 4} outBuffer := make([]byte, len(inBuffer)) - require.NoError(t, mux.dispatch(inBuffer)) + require.NoError(t, mux.dispatch(inBuffer, attr)) endpoint := mux.NewEndpoint(MatchDTLS) require.NotNil(t, endpoint) @@ -185,7 +189,7 @@ func TestPendingQueue(t *testing.T) { // Assert limit on pendingPackets for i := 0; i <= 100; i++ { - require.NoError(t, mux.dispatch([]byte{64, 65, 66})) + require.NoError(t, mux.dispatch([]byte{64, 65, 66}, attr)) } require.Equal(t, len(mux.pendingPackets), maxPendingPackets) } diff --git a/peerconnection.go b/peerconnection.go index 1434e7dc1f8..17636d28163 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -12,7 +12,6 @@ import ( "crypto/rand" "errors" "fmt" - "io" "strconv" "strings" "sync" @@ -1681,7 +1680,7 @@ func (pc *PeerConnection) handleNonMediaBandwidthProbe() { } } -func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop +func (pc *PeerConnection) handleIncomingSSRC(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop remoteDescription := pc.RemoteDescription() if remoteDescription == nil { return errPeerConnRemoteDescriptionNil @@ -1719,8 +1718,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err // We read the RTP packet to determine the payload type b := make([]byte, pc.api.settingEngine.getReceiveMTU()) - - i, err := rtpStream.Read(b) + i, err := rtpStream.ReadWithAttributes(b, nil) if err != nil { return err } @@ -1908,7 +1906,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { //nolint:cyclop continue } - go func(rtpStream io.Reader, ssrc SSRC) { + go func(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) { if err := pc.handleIncomingSSRC(rtpStream, ssrc); err != nil { pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err) } diff --git a/settingengine.go b/settingengine.go index 96b851777f1..1c5c7869163 100644 --- a/settingengine.go +++ b/settingengine.go @@ -94,7 +94,7 @@ type SettingEngine struct { disableSRTPReplayProtection bool disableSRTCPReplayProtection bool net transport.Net - BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser + BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) *packetio.Buffer LoggerFactory logging.LoggerFactory iceTCPMux ice.TCPMux iceUDPMux ice.UDPMux