Skip to content

Commit 1b6743e

Browse files
Merge pull request #54 from matrix-org/SimonBrandner/feat/rtcp-forward
Implement RTCP forwarding in the refactored version of the SFU
2 parents 78e9e99 + 5f90f21 commit 1b6743e

File tree

10 files changed

+384
-250
lines changed

10 files changed

+384
-250
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package conference
2+
3+
import (
4+
"github.com/pion/webrtc/v3"
5+
"golang.org/x/exp/slices"
6+
"maunium.net/go/mautrix/event"
7+
)
8+
9+
// Handle the `SFUMessage` event from the DataChannel message.
10+
func (c *Conference) processSelectDCMessage(participant *Participant, msg event.SFUMessage) {
11+
participant.logger.Info("Received select request over DC")
12+
13+
// Find tracks based on what we were asked for.
14+
tracks := c.getTracks(msg.Start)
15+
16+
// Let's check if we have all the tracks that we were asked for are there.
17+
// If not, we will list which are not available (later on we must inform participant
18+
// about it unless the participant retries it).
19+
if len(tracks) != len(msg.Start) {
20+
for _, expected := range msg.Start {
21+
found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool {
22+
return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID
23+
})
24+
25+
if found == -1 {
26+
c.logger.Warnf("Track not found: %s", expected.TrackID)
27+
}
28+
}
29+
}
30+
31+
// Subscribe to the found tracks.
32+
for _, track := range tracks {
33+
if err := participant.peer.SubscribeTo(track); err != nil {
34+
participant.logger.Errorf("Failed to subscribe to track: %v", err)
35+
return
36+
}
37+
}
38+
}
39+
40+
func (c *Conference) processAnswerDCMessage(participant *Participant, msg event.SFUMessage) {
41+
participant.logger.Info("Received SDP answer over DC")
42+
43+
if err := participant.peer.ProcessSDPAnswer(msg.SDP); err != nil {
44+
participant.logger.Errorf("Failed to set SDP answer: %v", err)
45+
return
46+
}
47+
}
48+
49+
func (c *Conference) processPublishDCMessage(participant *Participant, msg event.SFUMessage) {
50+
participant.logger.Info("Received SDP offer over DC")
51+
52+
answer, err := participant.peer.ProcessSDPOffer(msg.SDP)
53+
if err != nil {
54+
participant.logger.Errorf("Failed to set SDP offer: %v", err)
55+
return
56+
}
57+
58+
participant.streamMetadata = msg.Metadata
59+
60+
participant.sendDataChannelMessage(event.SFUMessage{
61+
Op: event.SFUOperationAnswer,
62+
SDP: answer.SDP,
63+
Metadata: c.getAvailableStreamsFor(participant.id),
64+
})
65+
}
66+
67+
func (c *Conference) processUnpublishDCMessage(participant *Participant) {
68+
participant.logger.Info("Received unpublish over DC")
69+
}
70+
71+
func (c *Conference) processAliveDCMessage(participant *Participant) {
72+
participant.peer.ProcessHeartbeat()
73+
}
74+
75+
func (c *Conference) processMetadataDCMessage(participant *Participant, msg event.SFUMessage) {
76+
participant.streamMetadata = msg.Metadata
77+
c.resendMetadataToAllExcept(participant.id)
78+
}

pkg/conference/matrix.go renamed to pkg/conference/matrix_message_processor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent *
6565
logger: logger,
6666
remoteSessionID: inviteEvent.SenderSessionID,
6767
streamMetadata: inviteEvent.SDPStreamMetadata,
68-
publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP),
68+
publishedTracks: make(map[event.SFUTrackDescription]PublishedTrack),
6969
}
7070

7171
c.participants[participantID] = participant
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package conference
2+
3+
import (
4+
"errors"
5+
6+
"github.com/matrix-org/waterfall/pkg/common"
7+
"github.com/matrix-org/waterfall/pkg/peer"
8+
"maunium.net/go/mautrix/event"
9+
)
10+
11+
// Listen on messages from incoming channels and process them.
12+
// This is essentially the main loop of the conference.
13+
// If this function returns, the conference is over.
14+
func (c *Conference) processMessages() {
15+
for {
16+
select {
17+
case msg := <-c.peerMessages:
18+
c.processPeerMessage(msg)
19+
case msg := <-c.matrixMessages.Channel:
20+
c.processMatrixMessage(msg)
21+
}
22+
23+
// If there are no more participants, stop the conference.
24+
if len(c.participants) == 0 {
25+
c.logger.Info("No more participants, stopping the conference")
26+
// Close the channel so that the sender can't push any messages.
27+
unreadMessages := c.matrixMessages.Close()
28+
29+
// Send the information that we ended to the owner and pass the message
30+
// that we did not process (so that we don't drop it silently).
31+
c.endNotifier.Notify(unreadMessages)
32+
return
33+
}
34+
}
35+
}
36+
37+
// Process a message from a local peer.
38+
func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) {
39+
participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant"))
40+
if participant == nil {
41+
return
42+
}
43+
44+
// Since Go does not support ADTs, we have to use a switch statement to
45+
// determine the actual type of the message.
46+
switch msg := message.Content.(type) {
47+
case peer.JoinedTheCall:
48+
c.processJoinedTheCallMessage(participant, msg)
49+
case peer.LeftTheCall:
50+
c.processLeftTheCallMessage(participant, msg)
51+
case peer.NewTrackPublished:
52+
c.processNewTrackPublishedMessage(participant, msg)
53+
case peer.PublishedTrackFailed:
54+
c.processPublishedTrackFailedMessage(participant, msg)
55+
case peer.NewICECandidate:
56+
c.processNewICECandidateMessage(participant, msg)
57+
case peer.ICEGatheringComplete:
58+
c.processICEGatheringCompleteMessage(participant, msg)
59+
case peer.RenegotiationRequired:
60+
c.processRenegotiationRequiredMessage(participant, msg)
61+
case peer.DataChannelMessage:
62+
c.processDataChannelMessage(participant, msg)
63+
case peer.DataChannelAvailable:
64+
c.processDataChannelAvailableMessage(participant, msg)
65+
case peer.RTCPReceived:
66+
c.processForwardRTCPMessage(msg)
67+
default:
68+
c.logger.Errorf("Unknown message type: %T", msg)
69+
}
70+
}
71+
72+
func (c *Conference) processMatrixMessage(msg MatrixMessage) {
73+
switch ev := msg.Content.(type) {
74+
case *event.CallInviteEventContent:
75+
c.onNewParticipant(msg.Sender, ev)
76+
case *event.CallCandidatesEventContent:
77+
c.onCandidates(msg.Sender, ev)
78+
case *event.CallSelectAnswerEventContent:
79+
c.onSelectAnswer(msg.Sender, ev)
80+
case *event.CallHangupEventContent:
81+
c.onHangup(msg.Sender, ev)
82+
default:
83+
c.logger.Errorf("Unexpected event type: %T", ev)
84+
}
85+
}

pkg/conference/participant.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package conference
22

33
import (
44
"encoding/json"
5+
"time"
56

67
"github.com/matrix-org/waterfall/pkg/peer"
78
"github.com/matrix-org/waterfall/pkg/signaling"
@@ -19,14 +20,21 @@ type ParticipantID struct {
1920
CallID string
2021
}
2122

23+
type PublishedTrack struct {
24+
track *webrtc.TrackLocalStaticRTP
25+
// The time when we sent the last PLI to the sender. We store this to avoid
26+
// spamming the sender.
27+
lastPLITimestamp time.Time
28+
}
29+
2230
// Participant represents a participant in the conference.
2331
type Participant struct {
2432
id ParticipantID
2533
logger *logrus.Entry
2634
peer *peer.Peer[ParticipantID]
2735
remoteSessionID id.SessionID
2836
streamMetadata event.CallSDPStreamMetadata
29-
publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP
37+
publishedTracks map[event.SFUTrackDescription]PublishedTrack
3038
}
3139

3240
func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient {
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package conference
2+
3+
import (
4+
"encoding/json"
5+
"time"
6+
7+
"github.com/matrix-org/waterfall/pkg/peer"
8+
"github.com/pion/webrtc/v3"
9+
"maunium.net/go/mautrix/event"
10+
)
11+
12+
func (c *Conference) processJoinedTheCallMessage(participant *Participant, message peer.JoinedTheCall) {
13+
participant.logger.Info("Joined the call")
14+
}
15+
16+
func (c *Conference) processLeftTheCallMessage(participant *Participant, msg peer.LeftTheCall) {
17+
participant.logger.Info("Left the call: %s", msg.Reason)
18+
c.removeParticipant(participant.id)
19+
c.signaling.SendHangup(participant.asMatrixRecipient(), msg.Reason)
20+
}
21+
22+
func (c *Conference) processNewTrackPublishedMessage(participant *Participant, msg peer.NewTrackPublished) {
23+
participant.logger.Infof("Published new track: %s", msg.Track.ID())
24+
key := event.SFUTrackDescription{
25+
StreamID: msg.Track.StreamID(),
26+
TrackID: msg.Track.ID(),
27+
}
28+
29+
if _, ok := participant.publishedTracks[key]; ok {
30+
c.logger.Errorf("Track already published: %v", key)
31+
return
32+
}
33+
34+
participant.publishedTracks[key] = PublishedTrack{track: msg.Track}
35+
c.resendMetadataToAllExcept(participant.id)
36+
}
37+
38+
func (c *Conference) processPublishedTrackFailedMessage(participant *Participant, msg peer.PublishedTrackFailed) {
39+
participant.logger.Infof("Failed published track: %s", msg.Track.ID())
40+
delete(participant.publishedTracks, event.SFUTrackDescription{
41+
StreamID: msg.Track.StreamID(),
42+
TrackID: msg.Track.ID(),
43+
})
44+
45+
for _, otherParticipant := range c.participants {
46+
if otherParticipant.id == participant.id {
47+
continue
48+
}
49+
50+
otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track})
51+
}
52+
53+
c.resendMetadataToAllExcept(participant.id)
54+
}
55+
56+
func (c *Conference) processNewICECandidateMessage(participant *Participant, msg peer.NewICECandidate) {
57+
participant.logger.Debug("Received a new local ICE candidate")
58+
59+
// Convert WebRTC ICE candidate to Matrix ICE candidate.
60+
jsonCandidate := msg.Candidate.ToJSON()
61+
candidates := []event.CallCandidate{{
62+
Candidate: jsonCandidate.Candidate,
63+
SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex),
64+
SDPMID: *jsonCandidate.SDPMid,
65+
}}
66+
c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates)
67+
}
68+
69+
func (c *Conference) processICEGatheringCompleteMessage(participant *Participant, msg peer.ICEGatheringComplete) {
70+
participant.logger.Info("Completed local ICE gathering")
71+
72+
// Send an empty array of candidates to indicate that ICE gathering is complete.
73+
c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient())
74+
}
75+
76+
func (c *Conference) processRenegotiationRequiredMessage(participant *Participant, msg peer.RenegotiationRequired) {
77+
participant.logger.Info("Started renegotiation")
78+
participant.sendDataChannelMessage(event.SFUMessage{
79+
Op: event.SFUOperationOffer,
80+
SDP: msg.Offer.SDP,
81+
Metadata: c.getAvailableStreamsFor(participant.id),
82+
})
83+
}
84+
85+
func (c *Conference) processDataChannelMessage(participant *Participant, msg peer.DataChannelMessage) {
86+
participant.logger.Debug("Received data channel message")
87+
var sfuMessage event.SFUMessage
88+
if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil {
89+
c.logger.Errorf("Failed to unmarshal SFU message: %v", err)
90+
return
91+
}
92+
93+
switch sfuMessage.Op {
94+
case event.SFUOperationSelect:
95+
c.processSelectDCMessage(participant, sfuMessage)
96+
case event.SFUOperationAnswer:
97+
c.processAnswerDCMessage(participant, sfuMessage)
98+
case event.SFUOperationPublish:
99+
c.processPublishDCMessage(participant, sfuMessage)
100+
case event.SFUOperationUnpublish:
101+
c.processUnpublishDCMessage(participant)
102+
case event.SFUOperationAlive:
103+
c.processAliveDCMessage(participant)
104+
case event.SFUOperationMetadata:
105+
c.processMetadataDCMessage(participant, sfuMessage)
106+
}
107+
}
108+
109+
func (c *Conference) processDataChannelAvailableMessage(participant *Participant, msg peer.DataChannelAvailable) {
110+
participant.logger.Info("Connected data channel")
111+
participant.sendDataChannelMessage(event.SFUMessage{
112+
Op: event.SFUOperationMetadata,
113+
Metadata: c.getAvailableStreamsFor(participant.id),
114+
})
115+
}
116+
117+
func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) {
118+
for _, participant := range c.participants {
119+
for _, publishedTrack := range participant.publishedTracks {
120+
if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID {
121+
err := participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp)
122+
if err == nil {
123+
publishedTrack.lastPLITimestamp = time.Now()
124+
}
125+
}
126+
}
127+
}
128+
}

0 commit comments

Comments
 (0)