Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 14 additions & 19 deletions examples/transcription/transcription/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio

from fishjam.agent import Agent, AgentResponseTrackData
from fishjam.agent import Agent
from transcription.worker import BackgroundWorker

from .transcription import TranscriptionSession
Expand All @@ -11,24 +11,19 @@ def __init__(self, room_id: str, agent: Agent, worker: BackgroundWorker):
self._room_id = room_id
self._agent = agent
self._sessions: dict[str, TranscriptionSession] = {}
self._disconnect = asyncio.Event()
self._worker = worker

@agent.on_track_data
def _(track_data: AgentResponseTrackData):
if track_data.peer_id not in self._sessions:
return
self._sessions[track_data.peer_id].transcribe(track_data.data)
self._task: asyncio.Task[None] | None = None

async def _start(self):
async with self._agent:
async with self._agent.connect() as session:
print(f"Agent connected to room {self._room_id}")
await self._disconnect.wait()
self._disconnect.clear()
print(f"Agent disconnected from room {self._room_id}")

def _stop(self):
self._disconnect.set()
async for track_data in session.receive():
if track_data.peer_id not in self._sessions:
return
self._sessions[track_data.peer_id].transcribe(track_data.data)

print(f"Agent disconnected from room {self._room_id}")

def _handle_transcription(self, peer_id: str, text: str):
print(f"Peer {peer_id} in room {self._room_id} said: {text}")
Expand All @@ -38,7 +33,7 @@ def on_peer_enter(self, peer_id: str):
return

if len(self._sessions) == 0:
self._worker.run_in_background(self._start())
self._task = self._worker.run_in_background(self._start())

session = TranscriptionSession(lambda t: self._handle_transcription(peer_id, t))
self._sessions[peer_id] = session
Expand All @@ -48,8 +43,8 @@ def on_peer_leave(self, peer_id: str):
if peer_id not in self._sessions:
return

self._sessions[peer_id].end()
self._sessions.pop(peer_id)
self._sessions.pop(peer_id).end()

if len(self._sessions) == 0:
self._stop()
if len(self._sessions) == 0 and self._task is not None:
self._task.cancel()
self._task = None
18 changes: 15 additions & 3 deletions examples/transcription/transcription/notifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from fishjam import FishjamNotifier
from fishjam.events import ServerMessagePeerConnected, ServerMessagePeerDisconnected
from fishjam.events import (
ServerMessagePeerConnected,
ServerMessagePeerDisconnected,
ServerMessagePeerType,
)
from fishjam.events.allowed_notifications import AllowedNotification

from .config import FISHJAM_ID, FISHJAM_TOKEN, FISHJAM_URL
Expand All @@ -12,10 +16,18 @@ def make_notifier(room_service: RoomService):
@notifier.on_server_notification
def _(notification: AllowedNotification):
match notification:
case ServerMessagePeerConnected(peer_id=peer_id, room_id=room_id):
case ServerMessagePeerConnected(
peer_id=peer_id,
room_id=room_id,
peer_type=ServerMessagePeerType.PEER_TYPE_WEBRTC,
):
handle_peer_connected(peer_id, room_id)

case ServerMessagePeerDisconnected(peer_id=peer_id, room_id=room_id):
case ServerMessagePeerDisconnected(
peer_id=peer_id,
room_id=room_id,
peer_type=ServerMessagePeerType.PEER_TYPE_WEBRTC,
):
handle_peer_disconnected(peer_id, room_id)

def handle_peer_connected(peer_id: str, room_id: str):
Expand Down
1 change: 1 addition & 0 deletions examples/transcription/transcription/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def run_in_background(self, coro: Coroutine[Any, Any, None]):
task = self._tg.create_task(coro)
task.add_done_callback(self._remove_task)
self._tasks.add(task)
return task

def _remove_task(self, task: Task[None]):
self._tasks.discard(task)
Expand Down
3 changes: 2 additions & 1 deletion fishjam/_ws_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import asyncio
import inspect
from collections.abc import Coroutine
from typing import Any, Callable, cast

Expand Down Expand Up @@ -142,7 +143,7 @@ async def _receive_loop(self):

if isinstance(message, ALLOWED_NOTIFICATIONS):
res = self._notification_handler(message)
if asyncio.iscoroutine(res):
if inspect.isawaitable(res):
await res

async def _subscribe_event(self, event: ServerMessageEventType):
Expand Down
14 changes: 11 additions & 3 deletions fishjam/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from .agent import Agent, AgentResponseTrackData, TrackDataHandler
from .agent import (
Agent,
AgentSession,
AudioTrackOptions,
IncomingTrackData,
OutgoingTrack,
)
from .errors import AgentAuthError, AgentError

__all__ = [
"Agent",
"AgentError",
"AgentSession",
"AgentAuthError",
"TrackDataHandler",
"AgentResponseTrackData",
"IncomingTrackData",
"OutgoingTrack",
"AudioTrackOptions",
]
Loading