diff --git a/engine.go b/engine.go index 3cbdce0e..c92b0776 100644 --- a/engine.go +++ b/engine.go @@ -37,17 +37,22 @@ const ( ) type RTCEngine struct { - log protoLogger.Logger - pclock sync.Mutex - publisher *PCTransport - subscriber *PCTransport - client *SignalClient - dclock sync.RWMutex - reliableDC *webrtc.DataChannel - lossyDC *webrtc.DataChannel - reliableDCSub *webrtc.DataChannel - lossyDCSub *webrtc.DataChannel - trackPublishedChan chan *livekit.TrackPublishedResponse + log protoLogger.Logger + + pclock sync.Mutex + publisher *PCTransport + subscriber *PCTransport + client *SignalClient + + dclock sync.RWMutex + reliableDC *webrtc.DataChannel + lossyDC *webrtc.DataChannel + reliableDCSub *webrtc.DataChannel + lossyDCSub *webrtc.DataChannel + + trackPublishedListenersLock sync.Mutex + trackPublishedListeners map[string]chan *livekit.TrackPublishedResponse + subscriberPrimary bool hasConnected atomic.Bool hasPublish atomic.Bool @@ -79,10 +84,10 @@ type RTCEngine struct { func NewRTCEngine() *RTCEngine { e := &RTCEngine{ - log: logger, - client: NewSignalClient(), - trackPublishedChan: make(chan *livekit.TrackPublishedResponse, 1), - JoinTimeout: 15 * time.Second, + log: logger, + client: NewSignalClient(), + trackPublishedListeners: make(map[string]chan *livekit.TrackPublishedResponse), + JoinTimeout: 15 * time.Second, } e.client.OnParticipantUpdate = func(info []*livekit.ParticipantInfo) { @@ -207,10 +212,6 @@ func (e *RTCEngine) Subscriber() (*PCTransport, bool) { return e.subscriber, e.subscriber != nil } -func (e *RTCEngine) TrackPublishedChan() <-chan *livekit.TrackPublishedResponse { - return e.trackPublishedChan -} - func (e *RTCEngine) setRTT(rtt uint32) { if subscriber, ok := e.Subscriber(); ok { subscriber.SetRTT(rtt) @@ -472,8 +473,26 @@ func (e *RTCEngine) dataPubChannelReady() bool { return e.reliableDC.ReadyState() == webrtc.DataChannelStateOpen && e.lossyDC.ReadyState() == webrtc.DataChannelStateOpen } +func (e *RTCEngine) RegisterTrackPublishedListener(cid string, c chan *livekit.TrackPublishedResponse) { + e.trackPublishedListenersLock.Lock() + e.trackPublishedListeners[cid] = c + e.trackPublishedListenersLock.Unlock() +} + +func (e *RTCEngine) UnregisterTrackPublishedListener(cid string) { + e.trackPublishedListenersLock.Lock() + delete(e.trackPublishedListeners, cid) + e.trackPublishedListenersLock.Unlock() +} + func (e *RTCEngine) handleLocalTrackPublished(res *livekit.TrackPublishedResponse) { - e.trackPublishedChan <- res + e.trackPublishedListenersLock.Lock() + listener, ok := e.trackPublishedListeners[res.Cid] + e.trackPublishedListenersLock.Unlock() + + if ok { + listener <- res + } } func (e *RTCEngine) handleLocalTrackUnpublished(res *livekit.TrackUnpublishedResponse) { diff --git a/localparticipant.go b/localparticipant.go index c4c4721f..dfd23958 100644 --- a/localparticipant.go +++ b/localparticipant.go @@ -107,16 +107,19 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl publisher.Negotiate() - pubChan := p.engine.TrackPublishedChan() - var pubRes *livekit.TrackPublishedResponse + pubChan := make(chan *livekit.TrackPublishedResponse, 1) + p.engine.RegisterTrackPublishedListener(track.ID(), pubChan) + var pubRes *livekit.TrackPublishedResponse select { case pubRes = <-pubChan: break case <-time.After(trackPublishTimeout): + p.engine.UnregisterTrackPublishedListener(track.ID()) return nil, ErrTrackPublishTimeout } + p.engine.UnregisterTrackPublishedListener(track.ID()) pub.updateInfo(pubRes.Track) p.addPublication(pub) @@ -124,7 +127,6 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl p.roomCallback.OnLocalTrackPublished(pub, p) p.engine.log.Infow("published track", "name", opts.Name, "source", opts.Source.String(), "trackID", pubRes.Track.Sid) - return pub, nil } @@ -191,16 +193,19 @@ func (p *LocalParticipant) PublishSimulcastTrack(tracks []*LocalTrack, opts *Tra return nil, err } - pubChan := p.engine.TrackPublishedChan() - var pubRes *livekit.TrackPublishedResponse + pubChan := make(chan *livekit.TrackPublishedResponse, 1) + p.engine.RegisterTrackPublishedListener(mainTrack.ID(), pubChan) + var pubRes *livekit.TrackPublishedResponse select { case pubRes = <-pubChan: break case <-time.After(trackPublishTimeout): + p.engine.UnregisterTrackPublishedListener(mainTrack.ID()) return nil, ErrTrackPublishTimeout } + p.engine.UnregisterTrackPublishedListener(mainTrack.ID()) publisher, ok := p.engine.Publisher() if !ok { return nil, ErrNoPeerConnection @@ -425,7 +430,7 @@ func (p *LocalParticipant) UnpublishTrack(sid string) error { p.Callback.OnLocalTrackUnpublished(pub, p) p.roomCallback.OnLocalTrackUnpublished(pub, p) - p.engine.log.Infow("unpublished track", "name", pub.Name(), "sid", sid) + p.engine.log.Infow("unpublished track", "name", pub.Name(), "trackID", sid) return err }