|
1 | 1 | import asyncio |
2 | | -from typing import Callable |
3 | 2 |
|
4 | 3 | from fishjam.agent import Agent, AgentResponseTrackData |
| 4 | +from transcription.worker import BackgroundWorker |
5 | 5 |
|
6 | 6 | from .transcription import TranscriptionSession |
7 | 7 |
|
8 | 8 |
|
9 | 9 | class TranscriptionAgent: |
10 | | - def __init__( |
11 | | - self, |
12 | | - room_id: str, |
13 | | - agent: Agent, |
14 | | - on_text: Callable[[str, str], None], |
15 | | - session_factory: Callable[[str], TranscriptionSession], |
16 | | - ): |
| 10 | + def __init__(self, room_id: str, agent: Agent, worker: BackgroundWorker): |
17 | 11 | self._room_id = room_id |
18 | 12 | self._agent = agent |
19 | | - self._peers: dict[str, TranscriptionSession] = {} |
20 | | - self._on_text = on_text |
21 | | - self._session_factory = session_factory |
22 | | - self._leave_event = asyncio.Event() |
23 | | - self.done = False |
| 13 | + self._sessions: dict[str, TranscriptionSession] = {} |
| 14 | + self._disconnect = asyncio.Event() |
| 15 | + self._worker = worker |
24 | 16 |
|
25 | 17 | @agent.on_track_data |
26 | 18 | def _(track_data: AgentResponseTrackData): |
27 | | - if track_data.peer_id not in self._peers: |
| 19 | + if track_data.peer_id not in self._sessions: |
28 | 20 | return |
29 | | - self._peers[track_data.peer_id].transcribe(track_data.data) |
| 21 | + self._sessions[track_data.peer_id].transcribe(track_data.data) |
30 | 22 |
|
31 | | - print(f"Created agent for room {room_id}") |
32 | | - self._task = asyncio.create_task(self._run_agent()) |
33 | | - |
34 | | - async def _run_agent(self): |
35 | | - print(f"Connecting agent to room {self._room_id}") |
| 23 | + async def _start(self): |
36 | 24 | async with self._agent: |
37 | 25 | print(f"Agent connected to room {self._room_id}") |
38 | | - await self._leave_event.wait() |
| 26 | + await self._disconnect.wait() |
| 27 | + self._disconnect.clear() |
39 | 28 | print(f"Agent disconnected from room {self._room_id}") |
40 | 29 |
|
| 30 | + def _stop(self): |
| 31 | + self._disconnect.set() |
| 32 | + |
| 33 | + def _handle_transcription(self, peer_id: str, text: str): |
| 34 | + print(f"Peer {peer_id} in room {self._room_id} said: {text}") |
| 35 | + |
41 | 36 | def on_peer_enter(self, peer_id: str): |
42 | | - if peer_id in self._peers: |
| 37 | + if peer_id in self._sessions: |
43 | 38 | return |
44 | 39 |
|
45 | | - print(f"Starting transcription session for peer {peer_id}") |
| 40 | + if len(self._sessions) == 0: |
| 41 | + self._worker.run_in_background(self._start()) |
46 | 42 |
|
47 | | - self._peers[peer_id] = self._session_factory(peer_id) |
| 43 | + session = TranscriptionSession(lambda t: self._handle_transcription(peer_id, t)) |
| 44 | + self._sessions[peer_id] = session |
| 45 | + self._worker.run_in_background(session.start(peer_id)) |
48 | 46 |
|
49 | 47 | def on_peer_leave(self, peer_id: str): |
50 | | - if peer_id not in self._peers: |
| 48 | + if peer_id not in self._sessions: |
51 | 49 | return |
52 | 50 |
|
53 | | - print(f"Ending transcription session for peer {peer_id}") |
54 | | - |
55 | | - session = self._peers[peer_id] |
56 | | - session.end() |
57 | | - self._peers.pop(peer_id) |
| 51 | + self._sessions[peer_id].end() |
| 52 | + self._sessions.pop(peer_id) |
58 | 53 |
|
59 | | - if len(self._peers) == 0: |
60 | | - self._leave_event.set() |
61 | | - self.done = True |
| 54 | + if len(self._sessions) == 0: |
| 55 | + self._stop() |
0 commit comments