Skip to content

Commit 4fd4a70

Browse files
refactor: define a package for message sink
This allows to generalize the message sink and get rid of a lot of copy-paste in the handling functions. Also this moves types to the right modules, so that `peer` is now completely matrix-unaware module that contains only plain WebRTC logic.
1 parent 25ba9e2 commit 4fd4a70

File tree

12 files changed

+267
-261
lines changed

12 files changed

+267
-261
lines changed

src/common/message_sink.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package common
2+
3+
// MessageSink is a helper struct that allows to send messages to a message sink.
4+
// The MessageSink abstracts the message sink which has a certain sender, so that
5+
// the sender does not have to be specified every time a message is sent.
6+
// At the same it guarantees that the caller can't alter the `sender`, which means that
7+
// the sender can't impersonate another sender (and we guarantee this on a compile-time).
8+
type MessageSink[SenderType comparable, MessageType any] struct {
9+
// The sender of the messages. This is useful for multiple-producer-single-consumer scenarios.
10+
sender SenderType
11+
// The message sink to which the messages are sent.
12+
messageSink chan<- Message[SenderType, MessageType]
13+
}
14+
15+
// Creates a new MessageSink. The function is generic allowing us to use it for multiple use cases.
16+
func NewMessageSink[S comparable, M any](sender S, messageSink chan<- Message[S, M]) *MessageSink[S, M] {
17+
return &MessageSink[S, M]{
18+
sender: sender,
19+
messageSink: messageSink,
20+
}
21+
}
22+
23+
// Sends a message to the message sink.
24+
func (s *MessageSink[S, M]) Send(message M) {
25+
s.messageSink <- Message[S, M]{
26+
Sender: s.sender,
27+
Content: message,
28+
}
29+
}
30+
31+
// Messages that are sent from the peer to the conference in order to communicate with other peers.
32+
// Since each peer is isolated from others, it can't influence the state of other peers directly.
33+
type Message[SenderType comparable, MessageType any] struct {
34+
// The sender of the message.
35+
Sender SenderType
36+
// The content of the message.
37+
Content MessageType
38+
}

src/conference/conference.go

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package conference
1818

1919
import (
20+
"github.com/matrix-org/waterfall/src/common"
2021
"github.com/matrix-org/waterfall/src/peer"
2122
"github.com/matrix-org/waterfall/src/signaling"
2223
"github.com/pion/webrtc/v3"
@@ -25,22 +26,22 @@ import (
2526
)
2627

2728
type Conference struct {
28-
id string
29-
config Config
30-
signaling signaling.MatrixSignaling
31-
participants map[peer.ID]*Participant
32-
peerEventsStream chan peer.Message
33-
logger *logrus.Entry
29+
id string
30+
config Config
31+
signaling signaling.MatrixSignaling
32+
participants map[ParticipantID]*Participant
33+
peerEvents chan common.Message[ParticipantID, peer.MessageContent]
34+
logger *logrus.Entry
3435
}
3536

3637
func NewConference(confID string, config Config, signaling signaling.MatrixSignaling) *Conference {
3738
conference := &Conference{
38-
id: confID,
39-
config: config,
40-
signaling: signaling,
41-
participants: make(map[peer.ID]*Participant),
42-
peerEventsStream: make(chan peer.Message),
43-
logger: logrus.WithFields(logrus.Fields{"conf_id": confID}),
39+
id: confID,
40+
config: config,
41+
signaling: signaling,
42+
participants: make(map[ParticipantID]*Participant),
43+
peerEvents: make(chan common.Message[ParticipantID, peer.MessageContent]),
44+
logger: logrus.WithFields(logrus.Fields{"conf_id": confID}),
4445
}
4546

4647
// Start conference "main loop".
@@ -49,7 +50,7 @@ func NewConference(confID string, config Config, signaling signaling.MatrixSigna
4950
}
5051

5152
// New participant tries to join the conference.
52-
func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.CallInviteEventContent) {
53+
func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) {
5354
// As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event,
5455
// any existing calls from this device in this call should be terminated.
5556
// TODO: Implement this.
@@ -67,7 +68,16 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.
6768
}
6869
}
6970

70-
peer, sdpOffer, err := peer.NewPeer(participantID, c.id, inviteEvent.Offer.SDP, c.peerEventsStream)
71+
var (
72+
participantlogger = logrus.WithFields(logrus.Fields{
73+
"user_id": participantID.UserID,
74+
"device_id": participantID.DeviceID,
75+
"conf_id": c.id,
76+
})
77+
messageSink = common.NewMessageSink(participantID, c.peerEvents)
78+
)
79+
80+
peer, sdpOffer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, participantlogger)
7181
if err != nil {
7282
c.logger.WithError(err).Errorf("Failed to create new peer")
7383
return
@@ -88,11 +98,11 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.
8898
c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpOffer.SDP)
8999
}
90100

91-
func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCandidatesEventContent) {
92-
if participant := c.getParticipant(peerID, nil); participant != nil {
101+
func (c *Conference) OnCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) {
102+
if participant := c.getParticipant(participantID, nil); participant != nil {
93103
// Convert the candidates to the WebRTC format.
94-
candidates := make([]webrtc.ICECandidateInit, len(candidatesEvent.Candidates))
95-
for i, candidate := range candidatesEvent.Candidates {
104+
candidates := make([]webrtc.ICECandidateInit, len(ev.Candidates))
105+
for i, candidate := range ev.Candidates {
96106
SDPMLineIndex := uint16(candidate.SDPMLineIndex)
97107
candidates[i] = webrtc.ICECandidateInit{
98108
Candidate: candidate.Candidate,
@@ -105,19 +115,19 @@ func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCan
105115
}
106116
}
107117

108-
func (c *Conference) OnSelectAnswer(peerID peer.ID, selectAnswerEvent *event.CallSelectAnswerEventContent) {
109-
if participant := c.getParticipant(peerID, nil); participant != nil {
110-
if selectAnswerEvent.SelectedPartyID != peerID.DeviceID.String() {
118+
func (c *Conference) OnSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) {
119+
if participant := c.getParticipant(participantID, nil); participant != nil {
120+
if ev.SelectedPartyID != participantID.DeviceID.String() {
111121
c.logger.WithFields(logrus.Fields{
112-
"device_id": selectAnswerEvent.SelectedPartyID,
122+
"device_id": ev.SelectedPartyID,
113123
}).Errorf("Call was answered on a different device, kicking this peer")
114124
participant.peer.Terminate()
115125
}
116126
}
117127
}
118128

119-
func (c *Conference) OnHangup(peerID peer.ID, hangupEvent *event.CallHangupEventContent) {
120-
if participant := c.getParticipant(peerID, nil); participant != nil {
129+
func (c *Conference) OnHangup(participantID ParticipantID, ev *event.CallHangupEventContent) {
130+
if participant := c.getParticipant(participantID, nil); participant != nil {
121131
participant.peer.Terminate()
122132
}
123133
}

src/conference/participant.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,23 @@ import (
1414

1515
var ErrInvalidSFUMessage = errors.New("invalid SFU message")
1616

17+
type ParticipantID struct {
18+
UserID id.UserID
19+
DeviceID id.DeviceID
20+
}
21+
1722
type Participant struct {
18-
id peer.ID
19-
peer *peer.Peer
23+
id ParticipantID
24+
peer *peer.Peer[ParticipantID]
2025
remoteSessionID id.SessionID
2126
streamMetadata event.CallSDPStreamMetadata
2227
publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP
2328
}
2429

2530
func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient {
2631
return signaling.MatrixRecipient{
27-
ID: p.id,
32+
UserID: p.id.UserID,
33+
DeviceID: p.id.DeviceID,
2834
RemoteSessionID: p.remoteSessionID,
2935
}
3036
}
@@ -44,12 +50,12 @@ func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) error {
4450
return nil
4551
}
4652

47-
func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error) *Participant {
48-
participant, ok := c.participants[peerID]
53+
func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant {
54+
participant, ok := c.participants[participantID]
4955
if !ok {
5056
logEntry := c.logger.WithFields(logrus.Fields{
51-
"user_id": peerID.UserID,
52-
"device_id": peerID.DeviceID,
57+
"user_id": participantID.UserID,
58+
"device_id": participantID.DeviceID,
5359
})
5460

5561
if optionalErrorMessage != nil {
@@ -64,7 +70,7 @@ func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error)
6470
return participant
6571
}
6672

67-
func (c *Conference) getStreamsMetadata(forParticipant peer.ID) event.CallSDPStreamMetadata {
73+
func (c *Conference) getStreamsMetadata(forParticipant ParticipantID) event.CallSDPStreamMetadata {
6874
streamsMetadata := make(event.CallSDPStreamMetadata)
6975
for id, participant := range c.participants {
7076
if forParticipant != id {

src/conference/messages.go renamed to src/conference/processor.go

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,36 @@ import (
44
"encoding/json"
55
"errors"
66

7+
"github.com/matrix-org/waterfall/src/common"
78
"github.com/matrix-org/waterfall/src/peer"
89
"maunium.net/go/mautrix/event"
910
)
1011

1112
func (c *Conference) processMessages() {
1213
for {
1314
// Read a message from the stream (of type peer.Message) and process it.
14-
message := <-c.peerEventsStream
15+
message := <-c.peerEvents
1516
c.processPeerMessage(message)
1617
}
1718
}
1819

19-
//nolint:funlen
20-
func (c *Conference) processPeerMessage(message peer.Message) {
20+
func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) {
2121
// Since Go does not support ADTs, we have to use a switch statement to
2222
// determine the actual type of the message.
23-
switch msg := message.(type) {
23+
24+
participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant"))
25+
if participant == nil {
26+
return
27+
}
28+
29+
switch msg := message.Content.(type) {
2430
case peer.JoinedTheCall:
2531
case peer.LeftTheCall:
26-
delete(c.participants, msg.Sender)
32+
delete(c.participants, message.Sender)
2733
// TODO: Send new metadata about available streams to all participants.
2834
// TODO: Send the hangup event over the Matrix back to the user.
2935

3036
case peer.NewTrackPublished:
31-
participant := c.getParticipant(msg.Sender, errors.New("New track published from unknown participant"))
32-
if participant == nil {
33-
return
34-
}
35-
3637
key := event.SFUTrackDescription{
3738
StreamID: msg.Track.StreamID(),
3839
TrackID: msg.Track.ID(),
@@ -46,11 +47,6 @@ func (c *Conference) processPeerMessage(message peer.Message) {
4647
participant.publishedTracks[key] = msg.Track
4748

4849
case peer.PublishedTrackFailed:
49-
participant := c.getParticipant(msg.Sender, errors.New("Published track failed from unknown participant"))
50-
if participant == nil {
51-
return
52-
}
53-
5450
delete(participant.publishedTracks, event.SFUTrackDescription{
5551
StreamID: msg.Track.StreamID(),
5652
TrackID: msg.Track.ID(),
@@ -59,11 +55,6 @@ func (c *Conference) processPeerMessage(message peer.Message) {
5955
// TODO: Should we remove the local tracks from every subscriber as well? Or will it happen automatically?
6056

6157
case peer.NewICECandidate:
62-
participant := c.getParticipant(msg.Sender, errors.New("ICE candidate from unknown participant"))
63-
if participant == nil {
64-
return
65-
}
66-
6758
// Convert WebRTC ICE candidate to Matrix ICE candidate.
6859
jsonCandidate := msg.Candidate.ToJSON()
6960
candidates := []event.CallCandidate{{
@@ -74,20 +65,10 @@ func (c *Conference) processPeerMessage(message peer.Message) {
7465
c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates)
7566

7667
case peer.ICEGatheringComplete:
77-
participant := c.getParticipant(msg.Sender, errors.New("Received ICE complete from unknown participant"))
78-
if participant == nil {
79-
return
80-
}
81-
8268
// Send an empty array of candidates to indicate that ICE gathering is complete.
8369
c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient())
8470

8571
case peer.RenegotiationRequired:
86-
participant := c.getParticipant(msg.Sender, errors.New("Renegotiation from unknown participant"))
87-
if participant == nil {
88-
return
89-
}
90-
9172
toSend := event.SFUMessage{
9273
Op: event.SFUOperationOffer,
9374
SDP: msg.Offer.SDP,
@@ -97,11 +78,6 @@ func (c *Conference) processPeerMessage(message peer.Message) {
9778
participant.sendDataChannelMessage(toSend)
9879

9980
case peer.DataChannelMessage:
100-
participant := c.getParticipant(msg.Sender, errors.New("Data channel message from unknown participant"))
101-
if participant == nil {
102-
return
103-
}
104-
10581
var sfuMessage event.SFUMessage
10682
if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil {
10783
c.logger.Errorf("Failed to unmarshal SFU message: %v", err)
@@ -111,11 +87,6 @@ func (c *Conference) processPeerMessage(message peer.Message) {
11187
c.handleDataChannelMessage(participant, sfuMessage)
11288

11389
case peer.DataChannelAvailable:
114-
participant := c.getParticipant(msg.Sender, errors.New("Data channel available from unknown participant"))
115-
if participant == nil {
116-
return
117-
}
118-
11990
toSend := event.SFUMessage{
12091
Op: event.SFUOperationMetadata,
12192
Metadata: c.getStreamsMetadata(participant.id),

src/peer/channel.go

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/peer/id.go

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)