diff --git a/go.mod b/go.mod index adaf0b3..96e2c6a 100644 --- a/go.mod +++ b/go.mod @@ -39,4 +39,4 @@ require ( golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect ) -replace maunium.net/go/mautrix => github.com/matrix-org/mautrix-go v0.0.0-20220817142816-160ea900a20b +replace maunium.net/go/mautrix => github.com/matrix-org/mautrix-go v0.0.0-20221210135932-bd593dd0204b diff --git a/go.sum b/go.sum index e02475c..2578ca8 100644 --- a/go.sum +++ b/go.sum @@ -25,8 +25,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/matrix-org/mautrix-go v0.0.0-20220817142816-160ea900a20b h1:qKvyphdDykNjyF1vJLaVuWCPfNJWNzP7wHvMV5mw+Ss= -github.com/matrix-org/mautrix-go v0.0.0-20220817142816-160ea900a20b/go.mod h1:hHvNi5iKVAiI2MAdAeXHtP4g9BvNEX2rsQpSF/x6Kx4= +github.com/matrix-org/mautrix-go v0.0.0-20221210135932-bd593dd0204b h1:yMsRQmsBWm7wJurYwnyd7H7wZWawhp52ca62W3MqDA8= +github.com/matrix-org/mautrix-go v0.0.0-20221210135932-bd593dd0204b/go.mod h1:hHvNi5iKVAiI2MAdAeXHtP4g9BvNEX2rsQpSF/x6Kx4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= diff --git a/pkg/conference/config.go b/pkg/conference/config.go index a239d60..e1f6b6e 100644 --- a/pkg/conference/config.go +++ b/pkg/conference/config.go @@ -2,7 +2,11 @@ package conference // Configuration for the group conferences (calls). type Config struct { - // Keep-alive timeout for WebRTC connections. If no keep-alive has been received - // from the client for this duration, the connection is considered dead (in seconds). + // Keep-alive timeout for WebRTC connections. If the client doesn't respond + // to an `m.call.ping` with an `m.call.pong` for this amount of time, the + // connection is considered dead. (in seconds, no greater then 30) KeepAliveTimeout int `yaml:"timeout"` + // The time after which we should send another m.call.ping event to the + // client. (in seconds, greater then 30) + PingInterval int `yaml:"pingInterval"` } diff --git a/pkg/conference/data_channel_message_processor.go b/pkg/conference/data_channel_message_processor.go index ec6d22c..72a8996 100644 --- a/pkg/conference/data_channel_message_processor.go +++ b/pkg/conference/data_channel_message_processor.go @@ -2,22 +2,32 @@ package conference import ( "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" "golang.org/x/exp/slices" "maunium.net/go/mautrix/event" ) -// Handle the `SFUMessage` event from the DataChannel message. -func (c *Conference) processSelectDCMessage(participant *Participant, msg event.SFUMessage) { - participant.logger.Info("Received select request over DC") +// Handle the `FocusEvent` from the DataChannel message. +func (c *Conference) processTrackSubscriptionDCMessage( + participant *Participant, msg event.FocusCallTrackSubscriptionEventContent, +) { + participant.logger.Info("Received track subscription request over DC") + + // TODO: Handle unsubscribe // Find tracks based on what we were asked for. - tracks := c.getTracks(msg.Start) + tracks := c.getTracks(msg.Subscribe) + + participant.logger.WithFields(logrus.Fields{ + "tracks_we_got": tracks, + "tracks_we_want": msg, + }).Debug("Tracks to subscribe to") // Let's check if we have all the tracks that we were asked for are there. // If not, we will list which are not available (later on we must inform participant // about it unless the participant retries it). - if len(tracks) != len(msg.Start) { - for _, expected := range msg.Start { + if len(tracks) != len(msg.Subscribe) { + for _, expected := range msg.Subscribe { found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool { return track.ID() == expected.TrackID }) @@ -30,6 +40,7 @@ func (c *Conference) processSelectDCMessage(participant *Participant, msg event. // Subscribe to the found tracks. for _, track := range tracks { + participant.logger.WithField("track_id", track.ID()).Debug("Subscribing to track") if err := participant.peer.SubscribeTo(track); err != nil { participant.logger.Errorf("Failed to subscribe to track: %v", err) return @@ -37,38 +48,50 @@ func (c *Conference) processSelectDCMessage(participant *Participant, msg event. } } -func (c *Conference) processAnswerDCMessage(participant *Participant, msg event.SFUMessage) { - participant.logger.Info("Received SDP answer over DC") - - if err := participant.peer.ProcessSDPAnswer(msg.SDP); err != nil { - participant.logger.Errorf("Failed to set SDP answer: %v", err) - return - } -} +func (c *Conference) processNegotiateDCMessage(participant *Participant, msg event.FocusCallNegotiateEventContent) { + participant.streamMetadata = msg.SDPStreamMetadata -func (c *Conference) processPublishDCMessage(participant *Participant, msg event.SFUMessage) { - participant.logger.Info("Received SDP offer over DC") + switch msg.Description.Type { + case event.CallDataTypeOffer: + participant.logger.WithField("SDP", msg.Description.SDP).Trace("Received SDP offer over DC") - answer, err := participant.peer.ProcessSDPOffer(msg.SDP) - if err != nil { - participant.logger.Errorf("Failed to set SDP offer: %v", err) - return - } + answer, err := participant.peer.ProcessSDPOffer(msg.Description.SDP) + if err != nil { + participant.logger.Errorf("Failed to set SDP offer: %v", err) + return + } - participant.streamMetadata = msg.Metadata + participant.sendDataChannelMessage(event.Event{ + Type: event.FocusCallNegotiate, + Content: event.Content{ + Parsed: event.FocusCallNegotiateEventContent{ + Description: event.CallData{ + Type: event.CallDataType(answer.Type.String()), + SDP: answer.SDP, + }, + SDPStreamMetadata: c.getAvailableStreamsFor(participant.id), + }, + }, + }) + case event.CallDataTypeAnswer: + participant.logger.WithField("SDP", msg.Description.SDP).Trace("Received SDP answer over DC") - participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationAnswer, - SDP: answer.SDP, - Metadata: c.getAvailableStreamsFor(participant.id), - }) + if err := participant.peer.ProcessSDPAnswer(msg.Description.SDP); err != nil { + participant.logger.Errorf("Failed to set SDP answer: %v", err) + return + } + default: + participant.logger.Errorf("Unknown SDP description type") + } } -func (c *Conference) processAliveDCMessage(participant *Participant) { - participant.peer.ProcessHeartbeat() +func (c *Conference) processPongDCMessage(participant *Participant) { + participant.peer.ProcessPong() } -func (c *Conference) processMetadataDCMessage(participant *Participant, msg event.SFUMessage) { - participant.streamMetadata = msg.Metadata +func (c *Conference) processMetadataDCMessage( + participant *Participant, msg event.FocusCallSDPStreamMetadataChangedEventContent, +) { + participant.streamMetadata = msg.SDPStreamMetadata c.resendMetadataToAllExcept(participant.id) } diff --git a/pkg/conference/matrix_message_processor.go b/pkg/conference/matrix_message_processor.go index 20ae788..298dc93 100644 --- a/pkg/conference/matrix_message_processor.go +++ b/pkg/conference/matrix_message_processor.go @@ -52,8 +52,25 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * } else { messageSink := common.NewMessageSink(participantID, c.peerMessages) - keepAliveDeadline := time.Duration(c.config.KeepAliveTimeout) * time.Second - peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger, keepAliveDeadline) + peer, answer, err := peer.NewPeer( + inviteEvent.Offer.SDP, + messageSink, + logger, + peer.PingPongConfig{ + Interval: time.Duration(c.config.PingInterval) * time.Second, + Deadline: time.Duration(c.config.KeepAliveTimeout) * time.Second, + PongChannel: make(chan peer.Pong, common.UnboundedChannelSize), + SendPing: func() { + participant.sendDataChannelMessage(event.Event{ + Type: event.FocusCallPing, + Content: event.Content{}, + }) + }, + OnDeadLine: func() { + messageSink.Send(peer.LeftTheCall{Reason: event.CallHangupKeepAliveTimeout}) + }, + }, + ) if err != nil { logger.WithError(err).Errorf("Failed to process SDP offer") return err diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index f40cf42..79e7d26 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -1,7 +1,6 @@ package conference import ( - "encoding/json" "time" "github.com/matrix-org/waterfall/pkg/peer" @@ -46,8 +45,8 @@ func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { } } -func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) { - jsonToSend, err := json.Marshal(toSend) +func (p *Participant) sendDataChannelMessage(toSend event.Event) { + jsonToSend, err := toSend.MarshalJSON() if err != nil { p.logger.Error("Failed to marshal data channel message") return diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go index 39328ed..6e56485 100644 --- a/pkg/conference/peer_message_processor.go +++ b/pkg/conference/peer_message_processor.go @@ -1,7 +1,6 @@ package conference import ( - "encoding/json" "time" "github.com/matrix-org/waterfall/pkg/peer" @@ -68,40 +67,57 @@ func (c *Conference) processICEGatheringCompleteMessage(participant *Participant func (c *Conference) processRenegotiationRequiredMessage(participant *Participant, msg peer.RenegotiationRequired) { participant.logger.Info("Started renegotiation") - participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationOffer, - SDP: msg.Offer.SDP, - Metadata: c.getAvailableStreamsFor(participant.id), + participant.sendDataChannelMessage(event.Event{ + Type: event.FocusCallNegotiate, + Content: event.Content{ + Parsed: event.FocusCallNegotiateEventContent{ + Description: event.CallData{ + Type: event.CallDataType(msg.Offer.Type.String()), + SDP: msg.Offer.SDP, + }, + SDPStreamMetadata: c.getAvailableStreamsFor(participant.id), + }, + }, }) } func (c *Conference) processDataChannelMessage(participant *Participant, msg peer.DataChannelMessage) { participant.logger.Debug("Received data channel message") - var sfuMessage event.SFUMessage - if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { + var focusEvent event.Event + if err := focusEvent.UnmarshalJSON([]byte(msg.Message)); err != nil { c.logger.Errorf("Failed to unmarshal SFU message: %v", err) return } - switch sfuMessage.Op { - case event.SFUOperationSelect: - c.processSelectDCMessage(participant, sfuMessage) - case event.SFUOperationAnswer: - c.processAnswerDCMessage(participant, sfuMessage) - case event.SFUOperationPublish, event.SFUOperationUnpublish: - c.processPublishDCMessage(participant, sfuMessage) - case event.SFUOperationAlive: - c.processAliveDCMessage(participant) - case event.SFUOperationMetadata: - c.processMetadataDCMessage(participant, sfuMessage) + // FIXME: We should be able to do + // focusEvent.Content.ParseRaw(focusEvent.Type) but it throws an error. + switch focusEvent.Type.Type { + case event.FocusCallTrackSubscription.Type: + focusEvent.Content.ParseRaw(event.FocusCallTrackSubscription) + c.processTrackSubscriptionDCMessage(participant, *focusEvent.Content.AsFocusCallTrackSubscription()) + case event.FocusCallNegotiate.Type: + focusEvent.Content.ParseRaw(event.FocusCallNegotiate) + c.processNegotiateDCMessage(participant, *focusEvent.Content.AsFocusCallNegotiate()) + case event.FocusCallPong.Type: + focusEvent.Content.ParseRaw(event.FocusCallPong) + c.processPongDCMessage(participant) + case event.FocusCallSDPStreamMetadataChanged.Type: + focusEvent.Content.ParseRaw(event.FocusCallSDPStreamMetadataChanged) + c.processMetadataDCMessage(participant, *focusEvent.Content.AsFocusCallSDPStreamMetadataChanged()) + default: + participant.logger.WithField("type", focusEvent.Type.Type).Warn("Received data channel message of unknown type") } } func (c *Conference) processDataChannelAvailableMessage(participant *Participant, msg peer.DataChannelAvailable) { participant.logger.Info("Connected data channel") - participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationMetadata, - Metadata: c.getAvailableStreamsFor(participant.id), + participant.sendDataChannelMessage(event.Event{ + Type: event.FocusCallSDPStreamMetadataChanged, + Content: event.Content{ + Parsed: event.FocusCallSDPStreamMetadataChangedEventContent{ + SDPStreamMetadata: c.getAvailableStreamsFor(participant.id), + }, + }, }) } diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 838479f..2a11ea0 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -94,7 +94,7 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event. } // Helper that returns the list of streams inside this conference that match the given stream IDs and track IDs. -func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { +func (c *Conference) getTracks(identifiers []event.FocusTrackDescription) []*webrtc.TrackLocalStaticRTP { tracks := make([]*webrtc.TrackLocalStaticRTP, 0) for _, participant := range c.participants { // Check if this participant has any of the tracks that we're looking for. @@ -112,9 +112,13 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt func (c *Conference) resendMetadataToAllExcept(exceptMe ParticipantID) { for participantID, participant := range c.participants { if participantID != exceptMe { - participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationMetadata, - Metadata: c.getAvailableStreamsFor(participantID), + participant.sendDataChannelMessage(event.Event{ + Type: event.FocusCallSDPStreamMetadataChanged, + Content: event.Content{ + Parsed: event.FocusCallSDPStreamMetadataChangedEventContent{ + SDPStreamMetadata: c.getAvailableStreamsFor(participantID), + }, + }, }) } } diff --git a/pkg/config/config.go b/pkg/config/config.go index f9fde9e..87c04b0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -74,10 +74,13 @@ func LoadConfigFromString(configString string) (*Config, error) { return nil, fmt.Errorf("failed to unmarshal YAML file: %w", err) } + // TODO: We should split these up and add error messages if config.Matrix.UserID == "" || config.Matrix.HomeserverURL == "" || config.Matrix.AccessToken == "" || - config.Conference.KeepAliveTimeout == 0 { + config.Conference.KeepAliveTimeout == 0 || + config.Conference.KeepAliveTimeout > 30 || + config.Conference.PingInterval < 30 { return nil, errors.New("invalid config values") } diff --git a/pkg/peer/keepalive.go b/pkg/peer/keepalive.go deleted file mode 100644 index b66f2ed..0000000 --- a/pkg/peer/keepalive.go +++ /dev/null @@ -1,23 +0,0 @@ -package peer - -import "time" - -type HeartBeat struct{} - -// Starts a goroutine that will execute `onDeadLine` closure in case nothing has been published -// to the `heartBeat` channel for `deadline` duration. The goroutine stops once the channel is closed. -func startKeepAlive(deadline time.Duration, heartBeat <-chan HeartBeat, onDeadLine func()) { - go func() { - for { - select { - case <-time.After(deadline): - onDeadLine() - return - case _, ok := <-heartBeat: - if !ok { - return - } - } - } - }() -} diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 5568cc5..1a18404 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -4,14 +4,12 @@ import ( "errors" "io" "sync" - "time" "github.com/matrix-org/waterfall/pkg/common" "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" "golang.org/x/exp/slices" - "maunium.net/go/mautrix/event" ) var ( @@ -34,7 +32,7 @@ type Peer[ID comparable] struct { logger *logrus.Entry peerConnection *webrtc.PeerConnection sink *common.MessageSink[ID, MessageContent] - heartbeat chan HeartBeat + pingPongConfig PingPongConfig dataChannelMutex sync.Mutex dataChannel *webrtc.DataChannel @@ -45,7 +43,7 @@ func NewPeer[ID comparable]( sdpOffer string, sink *common.MessageSink[ID, MessageContent], logger *logrus.Entry, - keepAliveDeadline time.Duration, + pingPongConfig PingPongConfig, ) (*Peer[ID], *webrtc.SessionDescription, error) { peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{}) if err != nil { @@ -57,7 +55,7 @@ func NewPeer[ID comparable]( logger: logger, peerConnection: peerConnection, sink: sink, - heartbeat: make(chan HeartBeat, common.UnboundedChannelSize), + pingPongConfig: pingPongConfig, } peerConnection.OnTrack(peer.onRtpTrackReceived) @@ -72,8 +70,7 @@ func NewPeer[ID comparable]( if sdpAnswer, err := peer.ProcessSDPOffer(sdpOffer); err != nil { return nil, nil, err } else { - onDeadline := func() { peer.sink.Send(LeftTheCall{event.CallHangupKeepAliveTimeout}) } - startKeepAlive(keepAliveDeadline, peer.heartbeat, onDeadline) + startPingPong(pingPongConfig) return peer, sdpAnswer, nil } } @@ -252,6 +249,6 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, // New heartbeat received (keep-alive message that is periodically sent by the remote peer). // We need to update the last heartbeat time. If the peer is not active for too long, we will // consider peer's connection as stalled and will close it. -func (p *Peer[ID]) ProcessHeartbeat() { - p.heartbeat <- HeartBeat{} +func (p *Peer[ID]) ProcessPong() { + p.pingPongConfig.PongChannel <- Pong{} } diff --git a/pkg/peer/ping_pong.go b/pkg/peer/ping_pong.go new file mode 100644 index 0000000..bc21732 --- /dev/null +++ b/pkg/peer/ping_pong.go @@ -0,0 +1,33 @@ +package peer + +import "time" + +type Pong struct{} + +type PingPongConfig struct { + Interval time.Duration + Deadline time.Duration + PongChannel chan Pong + SendPing func() + OnDeadLine func() +} + +// Starts a goroutine that will execute `onDeadLine` closure in case nothing has been published +// to the `heartBeat` channel for `deadline` duration. The goroutine stops once the channel is closed. +func startPingPong(config PingPongConfig) { + go func() { + for range time.Tick(config.Interval) { + config.SendPing() + + select { + case <-time.After(config.Deadline): + config.OnDeadLine() + return + case _, ok := <-config.PongChannel: + if !ok { + return + } + } + } + }() +}