Skip to content

Commit 5a83a8f

Browse files
committed
Implement RTCP forwarding
Signed-off-by: Šimon Brandner <[email protected]>
1 parent 6f9d002 commit 5a83a8f

File tree

6 files changed

+103
-10
lines changed

6 files changed

+103
-10
lines changed

pkg/conference/matrix.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent *
6565
logger: logger,
6666
remoteSessionID: inviteEvent.SenderSessionID,
6767
streamMetadata: inviteEvent.SDPStreamMetadata,
68-
publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP),
68+
publishedTracks: make(map[event.SFUTrackDescription]PublishedTrack),
6969
}
7070

7171
c.participants[participantID] = participant

pkg/conference/participant.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package conference
22

33
import (
44
"encoding/json"
5+
"sync/atomic"
56

67
"github.com/matrix-org/waterfall/pkg/peer"
78
"github.com/matrix-org/waterfall/pkg/signaling"
@@ -19,14 +20,19 @@ type ParticipantID struct {
1920
CallID string
2021
}
2122

23+
type PublishedTrack struct {
24+
Track *webrtc.TrackLocalStaticRTP
25+
LastPLITimestamp atomic.Int64
26+
}
27+
2228
// Participant represents a participant in the conference.
2329
type Participant struct {
2430
id ParticipantID
2531
logger *logrus.Entry
2632
peer *peer.Peer[ParticipantID]
2733
remoteSessionID id.SessionID
2834
streamMetadata event.CallSDPStreamMetadata
29-
publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP
35+
publishedTracks map[event.SFUTrackDescription]PublishedTrack
3036
}
3137

3238
func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient {

pkg/conference/processor.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe
6767
return
6868
}
6969

70-
participant.publishedTracks[key] = msg.Track
70+
participant.publishedTracks[key] = PublishedTrack{Track: msg.Track}
7171
c.resendMetadataToAllExcept(participant.id)
7272

7373
case peer.PublishedTrackFailed:
@@ -129,6 +129,22 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe
129129
Op: event.SFUOperationMetadata,
130130
Metadata: c.getAvailableStreamsFor(participant.id),
131131
})
132+
case peer.ForwardRTCP:
133+
for _, participant := range c.participants {
134+
for _, publishedTrack := range participant.publishedTracks {
135+
if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID {
136+
participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.LastPLITimestamp.Load())
137+
}
138+
}
139+
}
140+
case peer.PLISent:
141+
for _, participant := range c.participants {
142+
for _, publishedTrack := range participant.publishedTracks {
143+
if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID {
144+
publishedTrack.LastPLITimestamp.Store(msg.Timestamp)
145+
}
146+
}
147+
}
132148

133149
default:
134150
c.logger.Errorf("Unknown message type: %T", msg)

pkg/conference/state.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"github.com/matrix-org/waterfall/pkg/signaling"
77
"github.com/pion/webrtc/v3"
88
"github.com/sirupsen/logrus"
9-
"golang.org/x/exp/maps"
109
"maunium.net/go/mautrix/event"
1110
)
1211

@@ -56,7 +55,10 @@ func (c *Conference) removeParticipant(participantID ParticipantID) {
5655
c.resendMetadataToAllExcept(participantID)
5756

5857
// Remove the participant's tracks from all participants who might have subscribed to them.
59-
obsoleteTracks := maps.Values(participant.publishedTracks)
58+
obsoleteTracks := []*webrtc.TrackLocalStaticRTP{}
59+
for _, publishedTrack := range participant.publishedTracks {
60+
obsoleteTracks = append(obsoleteTracks, publishedTrack.Track)
61+
}
6062
for _, otherParticipant := range c.participants {
6163
otherParticipant.peer.UnsubscribeFrom(obsoleteTracks)
6264
}
@@ -72,7 +74,7 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event.
7274
// Now, find out which of published tracks belong to the streams for which we have metadata
7375
// available and construct a metadata map for a given participant based on that.
7476
for _, track := range participant.publishedTracks {
75-
trackID, streamID := track.ID(), track.StreamID()
77+
trackID, streamID := track.Track.ID(), track.Track.StreamID()
7678

7779
if metadata, ok := streamsMetadata[streamID]; ok {
7880
metadata.Tracks[trackID] = event.CallSDPStreamMetadataTrack{}
@@ -97,7 +99,7 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt
9799
// Check if this participant has any of the tracks that we're looking for.
98100
for _, identifier := range identifiers {
99101
if track, ok := participant.publishedTracks[identifier]; ok {
100-
tracks = append(tracks, track)
102+
tracks = append(tracks, track.Track)
101103
}
102104
}
103105
}

pkg/peer/messages.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package peer
22

33
import (
4+
"github.com/pion/rtcp"
45
"github.com/pion/webrtc/v3"
56
"maunium.net/go/mautrix/event"
67
)
@@ -38,3 +39,15 @@ type DataChannelMessage struct {
3839
}
3940

4041
type DataChannelAvailable struct{}
42+
43+
type ForwardRTCP struct {
44+
Packets []rtcp.Packet
45+
StreamID string
46+
TrackID string
47+
}
48+
49+
type PLISent struct {
50+
Timestamp int64
51+
StreamID string
52+
TrackID string
53+
}

pkg/peer/peer.go

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package peer
22

33
import (
44
"errors"
5+
"io"
56
"sync"
67
"time"
78

89
"github.com/matrix-org/waterfall/pkg/common"
10+
"github.com/pion/rtcp"
911
"github.com/pion/webrtc/v3"
1012
"github.com/sirupsen/logrus"
1113
"maunium.net/go/mautrix/event"
@@ -22,6 +24,8 @@ var (
2224
ErrCantSubscribeToTrack = errors.New("can't subscribe to track")
2325
)
2426

27+
const minimalPLIInterval = time.Millisecond * 500
28+
2529
// A wrapped representation of the peer connection (single peer in the call).
2630
// The peer gets information about the things happening outside via public methods
2731
// and informs the outside world about the things happening inside the peer by posting
@@ -98,17 +102,69 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error {
98102
// Before these packets are returned they are processed by interceptors. For things
99103
// like NACK this needs to be called.
100104
go func() {
101-
rtcpBuf := make([]byte, 1500)
102105
for {
103-
if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil {
104-
return
106+
packets, _, err := rtpSender.ReadRTCP()
107+
if err != nil {
108+
if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) {
109+
return
110+
}
111+
112+
p.logger.WithError(err).Warn("failed to read RTCP on track")
105113
}
114+
115+
p.sink.Send(ForwardRTCP{Packets: packets, TrackID: track.ID(), StreamID: track.StreamID()})
106116
}
107117
}()
108118

109119
return nil
110120
}
111121

122+
func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp int64) {
123+
packetsToSend := []rtcp.Packet{}
124+
var mediaSSRC uint32
125+
for _, receiver := range p.peerConnection.GetReceivers() {
126+
if receiver.Track().ID() == trackID && receiver.Track().StreamID() == streamID {
127+
mediaSSRC = uint32(receiver.Track().SSRC())
128+
break
129+
}
130+
}
131+
132+
for _, packet := range packets {
133+
switch typedPacket := packet.(type) {
134+
// We mung the packets here, so that the SSRCs match what the
135+
// receiver expects:
136+
// The media SSRC is the SSRC of the media about which the packet is
137+
// reporting; therefore, we mung it to be the SSRC of the publishing
138+
// participant's track. Without this, it would be SSRC of the SFU's
139+
// track which isn't right
140+
case *rtcp.PictureLossIndication:
141+
// Since we sometimes spam the sender with PLIs, make sure we don't send
142+
// them way too often
143+
if time.Now().UnixNano()-lastPLITimestamp < minimalPLIInterval.Nanoseconds() {
144+
continue
145+
}
146+
147+
p.sink.Send(PLISent{Timestamp: time.Now().UnixNano(), StreamID: streamID, TrackID: trackID})
148+
149+
typedPacket.MediaSSRC = mediaSSRC
150+
packetsToSend = append(packetsToSend, typedPacket)
151+
case *rtcp.FullIntraRequest:
152+
typedPacket.MediaSSRC = mediaSSRC
153+
packetsToSend = append(packetsToSend, typedPacket)
154+
}
155+
156+
packetsToSend = append(packetsToSend, packet)
157+
}
158+
159+
if len(packetsToSend) != 0 {
160+
if err := p.peerConnection.WriteRTCP(packetsToSend); err != nil {
161+
if !errors.Is(err, io.ErrClosedPipe) {
162+
p.logger.WithError(err).Warn("failed to write RTCP on track")
163+
}
164+
}
165+
}
166+
}
167+
112168
// Unsubscribes from the given list of tracks.
113169
func (p *Peer[ID]) UnsubscribeFrom(tracks []*webrtc.TrackLocalStaticRTP) {
114170
// That's unfortunately an O(m*n) operation, but we don't expect the number of tracks to be big.

0 commit comments

Comments
 (0)