Skip to content

Commit a4511dd

Browse files
Merge pull request #59 from matrix-org/track-related-improvements
More elegant handling of the RTCPs
2 parents 0895dc1 + 17873b0 commit a4511dd

File tree

8 files changed

+60
-75
lines changed

8 files changed

+60
-75
lines changed

pkg/conference/data_channel_message_processor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func (c *Conference) processSelectDCMessage(participant *Participant, msg event.
1919
if len(tracks) != len(msg.Start) {
2020
for _, expected := range msg.Start {
2121
found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool {
22-
return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID
22+
return track.ID() == expected.TrackID
2323
})
2424

2525
if found == -1 {

pkg/conference/matrix_message_processor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent *
6565
logger: logger,
6666
remoteSessionID: inviteEvent.SenderSessionID,
6767
streamMetadata: inviteEvent.SDPStreamMetadata,
68-
publishedTracks: make(map[event.SFUTrackDescription]PublishedTrack),
68+
publishedTracks: make(map[string]PublishedTrack),
6969
}
7070

7171
c.participants[participantID] = participant

pkg/conference/messsage_processor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe
6363
case peer.DataChannelAvailable:
6464
c.processDataChannelAvailableMessage(participant, msg)
6565
case peer.RTCPReceived:
66-
c.processForwardRTCPMessage(msg)
66+
c.processRTCPPackets(msg)
6767
default:
6868
c.logger.Errorf("Unknown message type: %T", msg)
6969
}

pkg/conference/participant.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ type ParticipantID struct {
2222

2323
type PublishedTrack struct {
2424
track *webrtc.TrackLocalStaticRTP
25-
// The time when we sent the last PLI to the sender. We store this to avoid
26-
// spamming the sender.
27-
lastPLITimestamp time.Time
25+
// The timestamp at which we are allowed to send the FIR or PLI request. We don't want to send them
26+
// too often, so we introduce some trivial rate limiting to not "enforce" too many key frames.
27+
canSendKeyframeAt time.Time
2828
}
2929

3030
// Participant represents a participant in the conference.
@@ -34,7 +34,7 @@ type Participant struct {
3434
peer *peer.Peer[ParticipantID]
3535
remoteSessionID id.SessionID
3636
streamMetadata event.CallSDPStreamMetadata
37-
publishedTracks map[event.SFUTrackDescription]PublishedTrack
37+
publishedTracks map[string]PublishedTrack
3838
}
3939

4040
func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient {

pkg/conference/peer_message_processor.go

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,19 @@ func (c *Conference) processLeftTheCallMessage(participant *Participant, msg pee
2121

2222
func (c *Conference) processNewTrackPublishedMessage(participant *Participant, msg peer.NewTrackPublished) {
2323
participant.logger.Infof("Published new track: %s", msg.Track.ID())
24-
key := event.SFUTrackDescription{
25-
StreamID: msg.Track.StreamID(),
26-
TrackID: msg.Track.ID(),
27-
}
2824

29-
if _, ok := participant.publishedTracks[key]; ok {
30-
c.logger.Errorf("Track already published: %v", key)
25+
if _, ok := participant.publishedTracks[msg.Track.ID()]; ok {
26+
c.logger.Errorf("Track already published: %v", msg.Track.ID())
3127
return
3228
}
3329

34-
participant.publishedTracks[key] = PublishedTrack{track: msg.Track}
30+
participant.publishedTracks[msg.Track.ID()] = PublishedTrack{track: msg.Track}
3531
c.resendMetadataToAllExcept(participant.id)
3632
}
3733

3834
func (c *Conference) processPublishedTrackFailedMessage(participant *Participant, msg peer.PublishedTrackFailed) {
3935
participant.logger.Infof("Failed published track: %s", msg.Track.ID())
40-
delete(participant.publishedTracks, event.SFUTrackDescription{
41-
StreamID: msg.Track.StreamID(),
42-
TrackID: msg.Track.ID(),
43-
})
36+
delete(participant.publishedTracks, msg.Track.ID())
4437

4538
for _, otherParticipant := range c.participants {
4639
if otherParticipant.id == participant.id {
@@ -114,13 +107,14 @@ func (c *Conference) processDataChannelAvailableMessage(participant *Participant
114107
})
115108
}
116109

117-
func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) {
110+
func (c *Conference) processRTCPPackets(msg peer.RTCPReceived) {
111+
const sendKeyFrameInterval = 500 * time.Millisecond
112+
118113
for _, participant := range c.participants {
119-
for _, publishedTrack := range participant.publishedTracks {
120-
if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID {
121-
err := participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp)
122-
if err == nil {
123-
publishedTrack.lastPLITimestamp = time.Now()
114+
if published, ok := participant.publishedTracks[msg.TrackID]; ok {
115+
if published.canSendKeyframeAt.Before(time.Now()) {
116+
if err := participant.peer.WriteRTCP(msg.TrackID, msg.Packets); err == nil {
117+
published.canSendKeyframeAt = time.Now().Add(sendKeyFrameInterval)
124118
}
125119
}
126120
}

pkg/conference/state.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ func (c *Conference) removeParticipant(participantID ParticipantID) {
5959
for _, publishedTrack := range participant.publishedTracks {
6060
obsoleteTracks = append(obsoleteTracks, publishedTrack.track)
6161
}
62+
6263
for _, otherParticipant := range c.participants {
6364
otherParticipant.peer.UnsubscribeFrom(obsoleteTracks)
6465
}
@@ -98,7 +99,7 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt
9899
for _, participant := range c.participants {
99100
// Check if this participant has any of the tracks that we're looking for.
100101
for _, identifier := range identifiers {
101-
if track, ok := participant.publishedTracks[identifier]; ok {
102+
if track, ok := participant.publishedTracks[identifier.TrackID]; ok {
102103
tracks = append(tracks, track.track)
103104
}
104105
}

pkg/peer/messages.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package peer
22

33
import (
4-
"github.com/pion/rtcp"
54
"github.com/pion/webrtc/v3"
65
"maunium.net/go/mautrix/event"
76
)
@@ -41,7 +40,13 @@ type DataChannelMessage struct {
4140
type DataChannelAvailable struct{}
4241

4342
type RTCPReceived struct {
44-
Packets []rtcp.Packet
45-
StreamID string
46-
TrackID string
43+
TrackID string
44+
Packets []RTCPPacketType
4745
}
46+
47+
type RTCPPacketType int
48+
49+
const (
50+
PictureLossIndicator RTCPPacketType = iota + 1
51+
FullIntraRequest
52+
)

pkg/peer/peer.go

Lines changed: 31 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ var (
2323
ErrDataChannelNotAvailable = errors.New("data channel is not available")
2424
ErrDataChannelNotReady = errors.New("data channel is not ready")
2525
ErrCantSubscribeToTrack = errors.New("can't subscribe to track")
26-
ErrCantWriteRTCP = errors.New("can't write RTCP")
26+
ErrTrackNotFound = errors.New("track not found")
2727
)
2828

2929
// A wrapped representation of the peer connection (single peer in the call).
@@ -106,71 +106,56 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error {
106106
packets, _, err := rtpSender.ReadRTCP()
107107
if err != nil {
108108
if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) {
109+
p.logger.WithError(err).Warn("failed to read RTCP on track")
109110
return
110111
}
112+
}
111113

112-
p.logger.WithError(err).Warn("failed to read RTCP on track")
114+
// We only want to inform others about PLIs and FIRs. We skip the rest of the packets for now.
115+
toForward := []RTCPPacketType{}
116+
for _, packet := range packets {
117+
// TODO: Should we also handle NACKs?
118+
switch packet.(type) {
119+
case *rtcp.PictureLossIndication:
120+
toForward = append(toForward, PictureLossIndicator)
121+
case *rtcp.FullIntraRequest:
122+
toForward = append(toForward, FullIntraRequest)
123+
}
113124
}
114125

115-
p.sink.Send(RTCPReceived{Packets: packets, TrackID: track.ID(), StreamID: track.StreamID()})
126+
p.sink.Send(RTCPReceived{Packets: toForward, TrackID: track.ID()})
116127
}
117128
}()
118129

119130
return nil
120131
}
121132

122-
func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp time.Time) error {
123-
const minimalPLIInterval = time.Millisecond * 500
124-
125-
packetsToSend := []rtcp.Packet{}
126-
var mediaSSRC uint32
133+
// Writes the specified packets to the `trackID`.
134+
func (p *Peer[ID]) WriteRTCP(trackID string, packets []RTCPPacketType) error {
135+
// Find the right track.
127136
receivers := p.peerConnection.GetReceivers()
128137
receiverIndex := slices.IndexFunc(receivers, func(receiver *webrtc.RTPReceiver) bool {
129-
return receiver.Track().ID() == trackID && receiver.Track().StreamID() == streamID
138+
return receiver.Track().ID() == trackID
130139
})
131-
132140
if receiverIndex == -1 {
133-
p.logger.Error("failed to find track to write RTCP on")
134-
return ErrCantWriteRTCP
135-
} else {
136-
mediaSSRC = uint32(receivers[receiverIndex].Track().SSRC())
141+
return ErrTrackNotFound
137142
}
138143

139-
for _, packet := range packets {
140-
switch typedPacket := packet.(type) {
141-
// We mung the packets here, so that the SSRCs match what the
142-
// receiver expects:
143-
// The media SSRC is the SSRC of the media about which the packet is
144-
// reporting; therefore, we mung it to be the SSRC of the publishing
145-
// participant's track. Without this, it would be SSRC of the SFU's
146-
// track which isn't right
147-
case *rtcp.PictureLossIndication:
148-
// Since we sometimes spam the sender with PLIs, make sure we don't send
149-
// them way too often
150-
if time.Now().UnixNano()-lastPLITimestamp.UnixNano() < minimalPLIInterval.Nanoseconds() {
151-
continue
152-
}
153-
154-
typedPacket.MediaSSRC = mediaSSRC
155-
packetsToSend = append(packetsToSend, typedPacket)
156-
case *rtcp.FullIntraRequest:
157-
typedPacket.MediaSSRC = mediaSSRC
158-
packetsToSend = append(packetsToSend, typedPacket)
159-
}
160-
161-
packetsToSend = append(packetsToSend, packet)
162-
}
163-
164-
if len(packetsToSend) != 0 {
165-
if err := p.peerConnection.WriteRTCP(packetsToSend); err != nil {
166-
if !errors.Is(err, io.ErrClosedPipe) {
167-
p.logger.WithError(err).Error("failed to write RTCP on track")
168-
return err
169-
}
144+
// The ssrc that we must use when sending the RTCP packet.
145+
// Otherwise the peer won't understand where the packet comes from.
146+
ssrc := uint32(receivers[receiverIndex].Track().SSRC())
147+
148+
toSend := make([]rtcp.Packet, len(packets))
149+
for i, packet := range packets {
150+
switch packet {
151+
case PictureLossIndicator:
152+
toSend[i] = &rtcp.PictureLossIndication{MediaSSRC: ssrc}
153+
case FullIntraRequest:
154+
toSend[i] = &rtcp.FullIntraRequest{MediaSSRC: ssrc}
170155
}
171156
}
172157

173-
return nil
158+
return p.peerConnection.WriteRTCP(toSend)
174159
}
175160

176161
// Unsubscribes from the given list of tracks.

0 commit comments

Comments
 (0)