diff --git a/examples/transcription/transcription/agent.py b/examples/transcription/transcription/agent.py index 2334b36..33e908e 100644 --- a/examples/transcription/transcription/agent.py +++ b/examples/transcription/transcription/agent.py @@ -1,6 +1,6 @@ import asyncio -from fishjam.agent import Agent, AgentResponseTrackData +from fishjam.agent import Agent from transcription.worker import BackgroundWorker from .transcription import TranscriptionSession @@ -11,24 +11,19 @@ def __init__(self, room_id: str, agent: Agent, worker: BackgroundWorker): self._room_id = room_id self._agent = agent self._sessions: dict[str, TranscriptionSession] = {} - self._disconnect = asyncio.Event() self._worker = worker - - @agent.on_track_data - def _(track_data: AgentResponseTrackData): - if track_data.peer_id not in self._sessions: - return - self._sessions[track_data.peer_id].transcribe(track_data.data) + self._task: asyncio.Task[None] | None = None async def _start(self): - async with self._agent: + async with self._agent.connect() as session: print(f"Agent connected to room {self._room_id}") - await self._disconnect.wait() - self._disconnect.clear() - print(f"Agent disconnected from room {self._room_id}") - def _stop(self): - self._disconnect.set() + async for track_data in session.receive(): + if track_data.peer_id not in self._sessions: + return + self._sessions[track_data.peer_id].transcribe(track_data.data) + + print(f"Agent disconnected from room {self._room_id}") def _handle_transcription(self, peer_id: str, text: str): print(f"Peer {peer_id} in room {self._room_id} said: {text}") @@ -38,7 +33,7 @@ def on_peer_enter(self, peer_id: str): return if len(self._sessions) == 0: - self._worker.run_in_background(self._start()) + self._task = self._worker.run_in_background(self._start()) session = TranscriptionSession(lambda t: self._handle_transcription(peer_id, t)) self._sessions[peer_id] = session @@ -48,8 +43,8 @@ def on_peer_leave(self, peer_id: str): if peer_id not in self._sessions: return - self._sessions[peer_id].end() - self._sessions.pop(peer_id) + self._sessions.pop(peer_id).end() - if len(self._sessions) == 0: - self._stop() + if len(self._sessions) == 0 and self._task is not None: + self._task.cancel() + self._task = None diff --git a/examples/transcription/transcription/notifier.py b/examples/transcription/transcription/notifier.py index 058b94d..461c7f5 100644 --- a/examples/transcription/transcription/notifier.py +++ b/examples/transcription/transcription/notifier.py @@ -1,5 +1,9 @@ from fishjam import FishjamNotifier -from fishjam.events import ServerMessagePeerConnected, ServerMessagePeerDisconnected +from fishjam.events import ( + ServerMessagePeerConnected, + ServerMessagePeerDisconnected, + ServerMessagePeerType, +) from fishjam.events.allowed_notifications import AllowedNotification from .config import FISHJAM_ID, FISHJAM_TOKEN, FISHJAM_URL @@ -12,10 +16,18 @@ def make_notifier(room_service: RoomService): @notifier.on_server_notification def _(notification: AllowedNotification): match notification: - case ServerMessagePeerConnected(peer_id=peer_id, room_id=room_id): + case ServerMessagePeerConnected( + peer_id=peer_id, + room_id=room_id, + peer_type=ServerMessagePeerType.PEER_TYPE_WEBRTC, + ): handle_peer_connected(peer_id, room_id) - case ServerMessagePeerDisconnected(peer_id=peer_id, room_id=room_id): + case ServerMessagePeerDisconnected( + peer_id=peer_id, + room_id=room_id, + peer_type=ServerMessagePeerType.PEER_TYPE_WEBRTC, + ): handle_peer_disconnected(peer_id, room_id) def handle_peer_connected(peer_id: str, room_id: str): diff --git a/examples/transcription/transcription/worker.py b/examples/transcription/transcription/worker.py index 3c21cd5..3bff365 100644 --- a/examples/transcription/transcription/worker.py +++ b/examples/transcription/transcription/worker.py @@ -12,6 +12,7 @@ def run_in_background(self, coro: Coroutine[Any, Any, None]): task = self._tg.create_task(coro) task.add_done_callback(self._remove_task) self._tasks.add(task) + return task def _remove_task(self, task: Task[None]): self._tasks.discard(task) diff --git a/fishjam/_ws_notifier.py b/fishjam/_ws_notifier.py index d5417f4..ea288ec 100644 --- a/fishjam/_ws_notifier.py +++ b/fishjam/_ws_notifier.py @@ -3,6 +3,7 @@ """ import asyncio +import inspect from collections.abc import Coroutine from typing import Any, Callable, cast @@ -142,7 +143,7 @@ async def _receive_loop(self): if isinstance(message, ALLOWED_NOTIFICATIONS): res = self._notification_handler(message) - if asyncio.iscoroutine(res): + if inspect.isawaitable(res): await res async def _subscribe_event(self, event: ServerMessageEventType): diff --git a/fishjam/agent/__init__.py b/fishjam/agent/__init__.py index ea5be67..2b4601b 100644 --- a/fishjam/agent/__init__.py +++ b/fishjam/agent/__init__.py @@ -1,10 +1,18 @@ -from .agent import Agent, AgentResponseTrackData, TrackDataHandler +from .agent import ( + Agent, + AgentSession, + AudioTrackOptions, + IncomingTrackData, + OutgoingTrack, +) from .errors import AgentAuthError, AgentError __all__ = [ "Agent", "AgentError", + "AgentSession", "AgentAuthError", - "TrackDataHandler", - "AgentResponseTrackData", + "IncomingTrackData", + "OutgoingTrack", + "AudioTrackOptions", ] diff --git a/fishjam/agent/agent.py b/fishjam/agent/agent.py index ba7776b..f0ecd11 100644 --- a/fishjam/agent/agent.py +++ b/fishjam/agent/agent.py @@ -1,118 +1,170 @@ -""" -Class for implementing Fishjam agents -""" +from __future__ import annotations -import asyncio -import functools -from contextlib import suppress -from types import TracebackType -from typing import Any, Callable, TypeAlias, TypeVar +import json +import uuid +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any, AsyncIterator, Literal import betterproto -from websockets import ClientConnection, CloseCode, ConnectionClosed +from websockets import ClientConnection, ConnectionClosed from websockets.asyncio import client from fishjam.agent.errors import AgentAuthError from fishjam.events._protos.fishjam import ( AgentRequest, + AgentRequestAddTrack, + AgentRequestAddTrackCodecParameters, AgentRequestAuthRequest, AgentResponse, - AgentResponseTrackData, ) +from fishjam.events._protos.fishjam import AgentRequestTrackData as OutgoingTrackData +from fishjam.events._protos.fishjam import AgentResponseTrackData as IncomingTrackData +from fishjam.events._protos.fishjam.notifications import Track, TrackEncoding, TrackType -TrackDataHandler: TypeAlias = Callable[[AgentResponseTrackData], None] +IncomingAgentMessage = IncomingTrackData -TrackDataHandlerT = TypeVar("TrackDataHandlerT", bound=TrackDataHandler) +@dataclass +class AudioTrackOptions: + """Parameters of an outgoing audio track.""" -def _close_ok(e: ConnectionClosed): - return e.code == CloseCode.NORMAL_CLOSURE + encoding: TrackEncoding = TrackEncoding.TRACK_ENCODING_UNSPECIFIED + """ + The encoding of the audio source. + Defaults to raw 16-bit PCM. + """ + sample_rate: Literal[16000, 24000] = 16000 + """ + The sample rate of the audio source. + Defaults to 16000. + """ + channels: Literal[1, 2] = 1 + """ + The number of channels in the audio source. + Supported values are 1 (mono) and 2 (stereo). + Defaults to 1 (mono) + """ + metadata: dict[str, Any] | None = None + """ + Custom metadata for the track. + Must be JSON-encodable. + """ -class Agent: + +@dataclass(frozen=True) +class OutgoingTrack: """ - Allows for connecting to a Fishjam room as an agent peer. - Provides callbacks for receiving audio. + Represents an outgoing track of an agent connected to Fishjam, + created by :func:`Agent.add_track`. """ - def __init__(self, id: str, token: str, fishjam_url: str): + id: str + """The global identifier of the track.""" + session: AgentSession + """The agent the track belongs to.""" + options: AudioTrackOptions + """The parameters used to create the track.""" + + async def send_chunk(self, data: bytes): """ - Create FishjamAgent instance, providing the fishjam id and management token. + Send a chunk of audio to Fishjam on this track. + + Peers connected to the room of the agent will receive this data. """ + message = AgentRequest( + track_data=OutgoingTrackData( + track_id=self.id, + data=data, + ) + ) - self.id = id - self._socket_url = f"{fishjam_url}/socket/agent/websocket".replace("http", "ws") - self._token = token - self._msg_loop: asyncio.Task[None] | None = None - self._end_event = asyncio.Event() + await self.session._send(message) - @functools.singledispatch - def _message_handler(content: Any) -> None: - raise TypeError(f"Unexpected message of type #{type(content)}") - @_message_handler.register - def _(_content: AgentResponseTrackData): - return +class AgentSession: + def __init__(self, agent: Agent, websocket: ClientConnection): + self.agent = agent - self._dispatch_handler = _message_handler + self._ws = websocket + self._closed = False - def on_track_data(self, handler: TrackDataHandlerT) -> TrackDataHandlerT: + async def receive(self) -> AsyncIterator[IncomingAgentMessage]: """ - Decorator used for defining a handler for track data messages from Fishjam. + Returns an infinite async iterator over the incoming messages from Fishjam to + the agent. """ - self._dispatch_handler.register(AgentResponseTrackData, handler) - return handler - - async def connect(self): + while message := await self._ws.recv(decode=False): + parsed = AgentResponse().parse(message) + _, msg = betterproto.which_one_of(parsed, "content") + match msg: + case IncomingTrackData() as content: + yield content + + async def add_track(self, options: AudioTrackOptions): """ - Connect the agent to Fishjam to start receiving messages. - - Incoming messages from Fishjam will be routed to handlers - defined with :func:`on_track_data`. + Adds a track to the connected agent, with the specified options and metadata. - :raises AgentAuthError: authentication failed + Returns an instance of :class:`OutgoingTrack`, which can be used to send data + over the added track. """ - await self.disconnect() + track_id = uuid.uuid4().hex + metadata_json = json.dumps(options.metadata) + message = AgentRequest( + add_track=AgentRequestAddTrack( + track=Track( + id=track_id, + type=TrackType.TRACK_TYPE_AUDIO, + metadata=metadata_json, + ), + codec_params=AgentRequestAddTrackCodecParameters( + encoding=options.encoding, + sample_rate=options.sample_rate, + channels=options.channels, + ), + ) + ) + await self._send(message) + return OutgoingTrack(id=track_id, session=self, options=options) + + async def _send(self, message: AgentRequest): + await self._ws.send(bytes(message), text=False) - websocket = await client.connect(self._socket_url) - await self._authenticate(websocket) - task = asyncio.create_task(self._recv_loop(websocket)) - - self._msg_loop = task +class Agent: + """ + Allows for connecting to a Fishjam room as an agent peer. + Provides callbacks for receiving audio. + """ - async def disconnect(self, code: CloseCode = CloseCode.NORMAL_CLOSURE): + def __init__(self, id: str, room_id: str, token: str, fishjam_url: str): """ - Disconnect the agent from Fishjam. + Create Agent instance, providing the fishjam id and management token. - Does nothing if already disconnected. + This constructor should not be called directly. + Instead, you should call :func:`fishjam.FishjamClient.create_agent`. """ - if (task := self._msg_loop) is None: - return - event = self._end_event + self.id = id + self.room_id = room_id - self._end_event = asyncio.Event() - self._msg_loop = None + self._socket_url = f"{fishjam_url}/socket/agent/websocket".replace("http", "ws") + self._token = token - task.add_done_callback(lambda _t: event.set()) - if task.cancel(code): - await event.wait() + @asynccontextmanager + async def connect(self): + """ + Connect the agent to Fishjam to start receiving messages. - async def __aenter__(self): - await self.connect() - return self + Incoming messages from Fishjam will be routed to handlers + defined with :func:`on_track_data`. - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ): - if exc_type is not None: - await self.disconnect(CloseCode.INTERNAL_ERROR) - else: - await self.disconnect() + :raises AgentAuthError: authentication failed + """ + async with client.connect(self._socket_url) as websocket: + await self._authenticate(websocket) + yield AgentSession(self, websocket) async def _authenticate(self, websocket: ClientConnection): req = AgentRequest(auth_request=AgentRequestAuthRequest(token=self._token)) @@ -120,30 +172,5 @@ async def _authenticate(self, websocket: ClientConnection): await websocket.send(bytes(req)) # Fishjam will close the socket if auth fails and send a response on success await websocket.recv(decode=False) - except ConnectionClosed as e: - raise AgentAuthError(e.reason) - - async def _recv_loop(self, websocket: ClientConnection): - close_code = CloseCode.NORMAL_CLOSURE - try: - while True: - message = await websocket.recv(decode=False) - message = AgentResponse().parse(message) - - _which, content = betterproto.which_one_of(message, "content") - self._dispatch_handler(content) - except ConnectionClosed as e: - if not _close_ok(e): - close_code = CloseCode.INTERNAL_ERROR - raise - except asyncio.CancelledError as e: - # NOTE: e.args[0] is the close code supplied by disconnect() - # However cancellation can have other causes, which we treat as normal - with suppress(IndexError): - close_code = e.args[0] - raise - except Exception: - close_code = CloseCode.INTERNAL_ERROR - raise - finally: - await websocket.close(close_code) + except ConnectionClosed: + raise AgentAuthError(websocket.close_reason or "") diff --git a/fishjam/api/_fishjam_client.py b/fishjam/api/_fishjam_client.py index 2c832c0..3e0d7e6 100644 --- a/fishjam/api/_fishjam_client.py +++ b/fishjam/api/_fishjam_client.py @@ -141,7 +141,7 @@ def create_agent(self, room_id: str): self._request(room_add_peer, room_id=room_id, body=body), ) - return Agent(resp.data.peer.id, resp.data.token, self._fishjam_url) + return Agent(resp.data.peer.id, room_id, resp.data.token, self._fishjam_url) def create_room(self, options: RoomOptions | None = None) -> Room: """ diff --git a/fishjam/events/__init__.py b/fishjam/events/__init__.py index 1cdbf06..738af9a 100644 --- a/fishjam/events/__init__.py +++ b/fishjam/events/__init__.py @@ -10,6 +10,7 @@ ServerMessagePeerDeleted, ServerMessagePeerDisconnected, ServerMessagePeerMetadataUpdated, + ServerMessagePeerType, ServerMessageRoomCrashed, ServerMessageRoomCreated, ServerMessageRoomDeleted, @@ -43,4 +44,5 @@ "Track", "TrackEncoding", "TrackType", + "ServerMessagePeerType", ] diff --git a/fishjam/types.py b/fishjam/types.py new file mode 100644 index 0000000..d93fbd8 --- /dev/null +++ b/fishjam/types.py @@ -0,0 +1,23 @@ +import functools +from typing import ( + Any, + Callable, + Coroutine, + ParamSpec, + TypeVar, +) + +R = TypeVar("R") +P = ParamSpec("P") + +AsyncCallable = Callable[P, Coroutine[Any, Any, R]] + +AnyCallable = Callable[P, R] | AsyncCallable[P, R] + + +def to_async(f: Callable[P, R]) -> AsyncCallable[P, R]: + @functools.wraps(f) + async def _wrapper(*args: P.args, **kwargs: P.kwargs): + return f(*args, **kwargs) + + return _wrapper diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index c662cd7..f41f4d5 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -86,52 +86,13 @@ async def wait_event(event: asyncio.Event, timeout: float = 5): class TestAgentConnection: - @pytest.mark.asyncio - async def test_connect_disconnect( - self, - room_api: FishjamClient, - room: Room, - agent: Agent, - notifier: FishjamNotifier, - ): - connect_event = asyncio.Event() - disconnect_event = asyncio.Event() - - @notifier.on_server_notification - def _(notification: AllowedNotification): - print(f"Received notification {notification}") - if ( - isinstance(notification, ServerMessagePeerMetadataUpdated) - and notification.peer_id == agent.id - ): - connect_event.set() - if ( - isinstance(notification, ServerMessagePeerDisconnected) - and notification.peer_id == agent.id - ): - disconnect_event.set() - - await agent.connect() - await wait_event(connect_event) - - room = room_api.get_room(room.id) - assert len(room.peers) == 1 - assert room.peers[0].id == agent.id - assert room.peers[0].status == "connected" - - await agent.disconnect() - await wait_event(disconnect_event) - @pytest.mark.asyncio async def test_invalid_auth(self, room_api: FishjamClient): - agent = Agent("fake-id", "fake-token", room_api._fishjam_url) - - with pytest.raises(AgentAuthError): - await agent.connect() + agent = Agent("fake-id", "room-id", "fake-token", room_api._fishjam_url) with pytest.raises(AgentAuthError): - async with agent: - pass + async with agent.connect(): + raise RuntimeError("Connect should have raised AgentAuthError.") @pytest.mark.asyncio async def test_context_manager( @@ -158,7 +119,7 @@ def _(notification: AllowedNotification): ): disconnect_event.set() - async with agent: + async with agent.connect(): await wait_event(connect_event) room = room_api.get_room(room.id) diff --git a/tests/support/peer_socket.py b/tests/support/peer_socket.py index 589572e..7159329 100644 --- a/tests/support/peer_socket.py +++ b/tests/support/peer_socket.py @@ -1,9 +1,7 @@ -# pylint: disable=locally-disabled, missing-class-docstring, missing-function-docstring, redefined-outer-name, too-few-public-methods, missing-module-docstring - import asyncio import betterproto -from websockets import client +from websockets.asyncio import client from websockets.exceptions import ConnectionClosedOK from tests.support.protos.fishjam import ( @@ -29,7 +27,7 @@ async def connect(self, token): await websocket.send(bytes(msg)) try: - message = await websocket.recv() + message = await websocket.recv(decode=False) except ConnectionClosedOK as exception: raise RuntimeError from exception diff --git a/tests/test_room_api.py b/tests/test_room_api.py index a1390d4..57768cb 100644 --- a/tests/test_room_api.py +++ b/tests/test_room_api.py @@ -74,12 +74,6 @@ def test_no_params(self, room_api: FishjamClient): webhook_url=None, room_type=RoomConfigRoomType(CONFERENCE), ) - config.__setitem__("roomId", room.config.__getitem__("roomId")) - config.__setitem__( - "peerlessPurgeTimeout", room.config.__getitem__("peerlessPurgeTimeout") - ) - config.__setitem__("geoLoc", room.config.__getitem__("geoLoc")) - config.__setitem__("s3UploadConfig", room.config.__getitem__("s3UploadConfig")) assert room == Room( config=config, @@ -103,12 +97,6 @@ def test_valid_params(self, room_api): webhook_url=None, room_type=RoomConfigRoomType(AUDIO_ONLY), ) - config.__setitem__("roomId", room.config.__getitem__("roomId")) - config.__setitem__( - "peerlessPurgeTimeout", room.config.__getitem__("peerlessPurgeTimeout") - ) - config.__setitem__("geoLoc", room.config.__getitem__("geoLoc")) - config.__setitem__("s3UploadConfig", room.config.__getitem__("s3UploadConfig")) assert room == Room( config=config, @@ -164,12 +152,6 @@ def test_valid(self, room_api: FishjamClient): webhook_url=None, room_type=RoomConfigRoomType(CONFERENCE), ) - config.__setitem__("roomId", room.config.__getitem__("roomId")) - config.__setitem__( - "peerlessPurgeTimeout", room.config.__getitem__("peerlessPurgeTimeout") - ) - config.__setitem__("geoLoc", room.config.__getitem__("geoLoc")) - config.__setitem__("s3UploadConfig", room.config.__getitem__("s3UploadConfig")) assert Room( peers=[],