Skip to content

Commit 4e41b5d

Browse files
Allow publishing tracks and simplify agent code (#40)
Co-authored-by: p1003 <[email protected]>
1 parent 6caec1e commit 4e41b5d

File tree

12 files changed

+201
-191
lines changed

12 files changed

+201
-191
lines changed
Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22

3-
from fishjam.agent import Agent, AgentResponseTrackData
3+
from fishjam.agent import Agent
44
from transcription.worker import BackgroundWorker
55

66
from .transcription import TranscriptionSession
@@ -11,24 +11,19 @@ def __init__(self, room_id: str, agent: Agent, worker: BackgroundWorker):
1111
self._room_id = room_id
1212
self._agent = agent
1313
self._sessions: dict[str, TranscriptionSession] = {}
14-
self._disconnect = asyncio.Event()
1514
self._worker = worker
16-
17-
@agent.on_track_data
18-
def _(track_data: AgentResponseTrackData):
19-
if track_data.peer_id not in self._sessions:
20-
return
21-
self._sessions[track_data.peer_id].transcribe(track_data.data)
15+
self._task: asyncio.Task[None] | None = None
2216

2317
async def _start(self):
24-
async with self._agent:
18+
async with self._agent.connect() as session:
2519
print(f"Agent connected to room {self._room_id}")
26-
await self._disconnect.wait()
27-
self._disconnect.clear()
28-
print(f"Agent disconnected from room {self._room_id}")
2920

30-
def _stop(self):
31-
self._disconnect.set()
21+
async for track_data in session.receive():
22+
if track_data.peer_id not in self._sessions:
23+
return
24+
self._sessions[track_data.peer_id].transcribe(track_data.data)
25+
26+
print(f"Agent disconnected from room {self._room_id}")
3227

3328
def _handle_transcription(self, peer_id: str, text: str):
3429
print(f"Peer {peer_id} in room {self._room_id} said: {text}")
@@ -38,7 +33,7 @@ def on_peer_enter(self, peer_id: str):
3833
return
3934

4035
if len(self._sessions) == 0:
41-
self._worker.run_in_background(self._start())
36+
self._task = self._worker.run_in_background(self._start())
4237

4338
session = TranscriptionSession(lambda t: self._handle_transcription(peer_id, t))
4439
self._sessions[peer_id] = session
@@ -48,8 +43,8 @@ def on_peer_leave(self, peer_id: str):
4843
if peer_id not in self._sessions:
4944
return
5045

51-
self._sessions[peer_id].end()
52-
self._sessions.pop(peer_id)
46+
self._sessions.pop(peer_id).end()
5347

54-
if len(self._sessions) == 0:
55-
self._stop()
48+
if len(self._sessions) == 0 and self._task is not None:
49+
self._task.cancel()
50+
self._task = None

examples/transcription/transcription/notifier.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from fishjam import FishjamNotifier
2-
from fishjam.events import ServerMessagePeerConnected, ServerMessagePeerDisconnected
2+
from fishjam.events import (
3+
ServerMessagePeerConnected,
4+
ServerMessagePeerDisconnected,
5+
ServerMessagePeerType,
6+
)
37
from fishjam.events.allowed_notifications import AllowedNotification
48

59
from .config import FISHJAM_ID, FISHJAM_TOKEN, FISHJAM_URL
@@ -12,10 +16,18 @@ def make_notifier(room_service: RoomService):
1216
@notifier.on_server_notification
1317
def _(notification: AllowedNotification):
1418
match notification:
15-
case ServerMessagePeerConnected(peer_id=peer_id, room_id=room_id):
19+
case ServerMessagePeerConnected(
20+
peer_id=peer_id,
21+
room_id=room_id,
22+
peer_type=ServerMessagePeerType.PEER_TYPE_WEBRTC,
23+
):
1624
handle_peer_connected(peer_id, room_id)
1725

18-
case ServerMessagePeerDisconnected(peer_id=peer_id, room_id=room_id):
26+
case ServerMessagePeerDisconnected(
27+
peer_id=peer_id,
28+
room_id=room_id,
29+
peer_type=ServerMessagePeerType.PEER_TYPE_WEBRTC,
30+
):
1931
handle_peer_disconnected(peer_id, room_id)
2032

2133
def handle_peer_connected(peer_id: str, room_id: str):

examples/transcription/transcription/worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def run_in_background(self, coro: Coroutine[Any, Any, None]):
1212
task = self._tg.create_task(coro)
1313
task.add_done_callback(self._remove_task)
1414
self._tasks.add(task)
15+
return task
1516

1617
def _remove_task(self, task: Task[None]):
1718
self._tasks.discard(task)

fishjam/_ws_notifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import asyncio
6+
import inspect
67
from collections.abc import Coroutine
78
from typing import Any, Callable, cast
89

@@ -142,7 +143,7 @@ async def _receive_loop(self):
142143

143144
if isinstance(message, ALLOWED_NOTIFICATIONS):
144145
res = self._notification_handler(message)
145-
if asyncio.iscoroutine(res):
146+
if inspect.isawaitable(res):
146147
await res
147148

148149
async def _subscribe_event(self, event: ServerMessageEventType):

fishjam/agent/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
from .agent import Agent, AgentResponseTrackData, TrackDataHandler
1+
from .agent import (
2+
Agent,
3+
AgentSession,
4+
AudioTrackOptions,
5+
IncomingTrackData,
6+
OutgoingTrack,
7+
)
28
from .errors import AgentAuthError, AgentError
39

410
__all__ = [
511
"Agent",
612
"AgentError",
13+
"AgentSession",
714
"AgentAuthError",
8-
"TrackDataHandler",
9-
"AgentResponseTrackData",
15+
"IncomingTrackData",
16+
"OutgoingTrack",
17+
"AudioTrackOptions",
1018
]

0 commit comments

Comments
 (0)