Skip to content

Commit c6a1569

Browse files
Improve transcription demo
1 parent 1f63b78 commit c6a1569

File tree

8 files changed

+455
-123
lines changed

8 files changed

+455
-123
lines changed

examples/transcription/main.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import asyncio
21
from contextlib import asynccontextmanager
32
from typing import Annotated
43

54
from fastapi import Depends, FastAPI
65
from transcription.notifier import make_notifier
76
from transcription.room import RoomService, fishjam
8-
from transcription.transcription import transcription_session_factory
7+
from transcription.worker import async_worker
98

109
from fishjam import PeerOptions, SubscribeOptions
1110

@@ -20,19 +19,14 @@ def get_room_service():
2019

2120
@asynccontextmanager
2221
async def lifespan(_app: FastAPI):
23-
async with (
24-
transcription_session_factory() as session_factory,
25-
asyncio.TaskGroup() as tg,
26-
):
22+
async with async_worker() as worker:
2723
global _room_service
28-
_room_service = RoomService(session_factory)
24+
_room_service = RoomService(worker)
2925
notifier = make_notifier(_room_service)
30-
notifier_task = tg.create_task(notifier.connect())
26+
worker.run_in_background(notifier.connect())
3127

3228
yield
3329

34-
notifier_task.cancel()
35-
3630

3731
app = FastAPI(lifespan=lifespan)
3832

examples/transcription/pyproject.toml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@ name = "transcription"
33
version = "0.1.0"
44
description = "Add your description here"
55
readme = "README.md"
6-
requires-python = ">=3.12"
7-
dependencies = [
8-
"fastapi[standard]==0.116.0",
9-
"fishjam-server-sdk",
10-
]
6+
requires-python = ">=3.10"
7+
dependencies = ["fastapi[standard]==0.116.0", "fishjam-server-sdk"]
118

129
[tool.uv.sources]
1310
fishjam-server-sdk = { workspace = true }

examples/transcription/transcription/agent.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,55 @@
11
import asyncio
2-
from typing import Callable
32

43
from fishjam.agent import Agent, AgentResponseTrackData
4+
from transcription.worker import BackgroundWorker
55

66
from .transcription import TranscriptionSession
77

88

99
class TranscriptionAgent:
10-
def __init__(
11-
self,
12-
room_id: str,
13-
agent: Agent,
14-
on_text: Callable[[str, str], None],
15-
session_factory: Callable[[str], TranscriptionSession],
16-
):
10+
def __init__(self, room_id: str, agent: Agent, worker: BackgroundWorker):
1711
self._room_id = room_id
1812
self._agent = agent
19-
self._peers: dict[str, TranscriptionSession] = {}
20-
self._on_text = on_text
21-
self._session_factory = session_factory
22-
self._leave_event = asyncio.Event()
23-
self.done = False
13+
self._sessions: dict[str, TranscriptionSession] = {}
14+
self._disconnect = asyncio.Event()
15+
self._worker = worker
2416

2517
@agent.on_track_data
2618
def _(track_data: AgentResponseTrackData):
27-
if track_data.peer_id not in self._peers:
19+
if track_data.peer_id not in self._sessions:
2820
return
29-
self._peers[track_data.peer_id].transcribe(track_data.data)
21+
self._sessions[track_data.peer_id].transcribe(track_data.data)
3022

31-
print(f"Created agent for room {room_id}")
32-
self._task = asyncio.create_task(self._run_agent())
33-
34-
async def _run_agent(self):
35-
print(f"Connecting agent to room {self._room_id}")
23+
async def _start(self):
3624
async with self._agent:
3725
print(f"Agent connected to room {self._room_id}")
38-
await self._leave_event.wait()
26+
await self._disconnect.wait()
27+
self._disconnect.clear()
3928
print(f"Agent disconnected from room {self._room_id}")
4029

30+
def _stop(self):
31+
self._disconnect.set()
32+
33+
def _handle_transcription(self, peer_id: str, text: str):
34+
print(f"Peer {peer_id} in room {self._room_id} said: {text}")
35+
4136
def on_peer_enter(self, peer_id: str):
42-
if peer_id in self._peers:
37+
if peer_id in self._sessions:
4338
return
4439

45-
print(f"Starting transcription session for peer {peer_id}")
40+
if len(self._sessions) == 0:
41+
self._worker.run_in_background(self._start())
4642

47-
self._peers[peer_id] = self._session_factory(peer_id)
43+
session = TranscriptionSession(lambda t: self._handle_transcription(peer_id, t))
44+
self._sessions[peer_id] = session
45+
self._worker.run_in_background(session.start(peer_id))
4846

4947
def on_peer_leave(self, peer_id: str):
50-
if peer_id not in self._peers:
48+
if peer_id not in self._sessions:
5149
return
5250

53-
print(f"Ending transcription session for peer {peer_id}")
54-
55-
session = self._peers[peer_id]
56-
session.end()
57-
self._peers.pop(peer_id)
51+
self._sessions[peer_id].end()
52+
self._sessions.pop(peer_id)
5853

59-
if len(self._peers) == 0:
60-
self._leave_event.set()
61-
self.done = True
54+
if len(self._sessions) == 0:
55+
self._stop()

examples/transcription/transcription/notifier.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ def make_notifier(room_service: RoomService):
1111

1212
@notifier.on_server_notification
1313
def _(notification: AllowedNotification):
14-
print(f"Received notification {notification}")
1514
match notification:
1615
case ServerMessagePeerConnected(peer_id=peer_id, room_id=room_id):
1716
handle_peer_connected(peer_id, room_id)
@@ -20,18 +19,11 @@ def _(notification: AllowedNotification):
2019
handle_peer_disconnected(peer_id, room_id)
2120

2221
def handle_peer_connected(peer_id: str, room_id: str):
23-
if room_id != room_service.room_id:
24-
return
25-
26-
agent = room_service.create_agent()
27-
agent.on_peer_enter(peer_id)
22+
if room_id == room_service.room.id:
23+
room_service.agent.on_peer_enter(peer_id)
2824

2925
def handle_peer_disconnected(peer_id: str, room_id: str):
30-
if room_id != room_service.room_id:
31-
return
32-
33-
agent = room_service.get_agent()
34-
if agent:
35-
agent.on_peer_leave(peer_id)
26+
if room_id == room_service.room.id:
27+
room_service.agent.on_peer_leave(peer_id)
3628

3729
return notifier

examples/transcription/transcription/room.py

Lines changed: 20 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from typing import Callable
2-
3-
from fishjam import FishjamClient, FishjamNotifier, Room
1+
from fishjam import FishjamClient, Room
42
from fishjam.errors import NotFoundError
3+
from transcription.worker import BackgroundWorker
54

65
from .agent import TranscriptionAgent
76
from .config import FISHJAM_ID, FISHJAM_TOKEN, FISHJAM_URL
8-
from .transcription import TranscriptionSession
97

108
fishjam = FishjamClient(
119
FISHJAM_ID,
@@ -15,49 +13,27 @@
1513

1614

1715
class RoomService:
18-
def __init__(
19-
self,
20-
session_factory: Callable[[Callable[[str], None]], TranscriptionSession],
21-
):
22-
self.room_id = fishjam.create_room().id
23-
self._agent: TranscriptionAgent | None = None
24-
self._notifier = FishjamNotifier(
25-
FISHJAM_ID,
26-
FISHJAM_TOKEN,
27-
fishjam_url=FISHJAM_URL,
28-
)
29-
30-
def _make_session(peer_id: str):
31-
return session_factory(lambda t: self._handle_transcription(peer_id, t))
32-
33-
self._session_factory = _make_session
16+
def __init__(self, worker: BackgroundWorker):
17+
self._worker = worker
18+
self._create_room()
3419

3520
def get_room(self) -> Room:
3621
try:
37-
return fishjam.get_room(self.room_id)
22+
self.room = fishjam.get_room(self.room.id)
3823
except NotFoundError:
39-
self.clear()
40-
room = fishjam.create_room()
41-
self.room_id = room.id
42-
return room
43-
44-
def clear(self):
45-
fishjam.delete_room(self.room_id)
46-
47-
def _handle_transcription(self, peer_id: str, text: str):
48-
print(f"Peer {peer_id} in room {self.room_id} said: {text}")
49-
50-
def create_agent(self):
51-
if not self._agent or self._agent.done:
52-
self._agent = TranscriptionAgent(
53-
self.room_id,
54-
fishjam.create_agent(self.room_id),
55-
self._handle_transcription,
56-
self._session_factory,
57-
)
58-
return self._agent
24+
self._create_room()
25+
return self.room
26+
27+
def _create_room(self):
28+
self.room = fishjam.create_room()
29+
self._create_agent()
30+
31+
def _create_agent(self):
32+
self.agent = TranscriptionAgent(
33+
self.room.id,
34+
fishjam.create_agent(self.room.id),
35+
self._worker,
36+
)
5937

6038
def get_agent(self):
61-
if self._agent.done:
62-
self._agent = None
63-
return self._agent
39+
return self.agent

examples/transcription/transcription/transcription.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from asyncio import Event, Queue, Task, TaskGroup
1+
from asyncio import Event, Queue, TaskGroup
22
from contextlib import asynccontextmanager
33
from typing import Callable
44

@@ -17,7 +17,7 @@ def __init__(self, on_text: Callable[[str], None]):
1717
self._model = TRANSCRIPTION_MODEL
1818
self._on_text = on_text
1919

20-
async def start(self):
20+
async def start(self, peer_id: str):
2121
async with self._gemini.aio.live.connect(
2222
model=self._model,
2323
config=TRANSCRIPTION_CONFIG,
@@ -26,12 +26,13 @@ async def start(self):
2626
send_task = tg.create_task(self._send_loop(session))
2727
recv_task = tg.create_task(self._recv_loop(session))
2828

29-
print("Started transcription session")
29+
print(f"Started transcription session for peer {peer_id}")
3030

3131
await self._end_event.wait()
3232

3333
send_task.cancel()
3434
recv_task.cancel()
35+
print(f"Stopped transcription session for peer {peer_id}")
3536

3637
def transcribe(self, audio: bytes):
3738
self._audio_queue.put_nowait(audio)
@@ -48,27 +49,16 @@ async def _send_loop(self, session: AsyncSession):
4849

4950
async def _recv_loop(self, session: AsyncSession):
5051
while True:
52+
acc = ""
5153
async for res in session.receive():
5254
if (
5355
(content := res.server_content)
5456
and (transcription := content.input_transcription)
5557
and (text := transcription.text)
5658
):
57-
self._on_text(text)
59+
acc += text
5860

61+
if acc:
62+
self._on_text(acc)
5963

60-
@asynccontextmanager
61-
async def transcription_session_factory():
62-
async with TaskGroup() as tg:
63-
sessions: list[TranscriptionSession] = []
64-
65-
def make_session(on_text: Callable[[str], None]):
66-
session = TranscriptionSession(on_text)
67-
sessions.append(session)
68-
tg.create_task(session.start())
6964
return session
70-
71-
yield make_session
72-
73-
for session in sessions:
74-
session.end()

examples/transcription/transcription/worker.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from asyncio import Task, TaskGroup
2+
from contextlib import asynccontextmanager
3+
from typing import Any, Coroutine
4+
5+
6+
class BackgroundWorker:
7+
def __init__(self, tg: TaskGroup) -> None:
8+
self._tg = tg
9+
self._tasks: set[Task[None]] = set()
10+
11+
def run_in_background(self, coro: Coroutine[Any, Any, None]):
12+
task = self._tg.create_task(coro)
13+
task.add_done_callback(self._remove_task)
14+
self._tasks.add(task)
15+
16+
def _remove_task(self, task: Task[None]):
17+
self._tasks.discard(task)
18+
19+
def cleanup(self):
20+
for task in self._tasks:
21+
task.cancel()
22+
self._tasks = set()
23+
24+
25+
@asynccontextmanager
async def async_worker():
    """Yield a ``BackgroundWorker`` bound to a fresh ``TaskGroup``.

    The task group stays open for the duration of the ``async with`` body.
    On exit the worker cancels the tasks it still tracks, and the task group
    then waits for them to finish before the context manager returns.

    Note: ``asyncio.TaskGroup`` is only available on Python 3.11 and newer.

    Yields:
        BackgroundWorker: worker whose tasks live inside the task group.
    """
    async with TaskGroup() as tg:
        worker = BackgroundWorker(tg)
        try:
            yield worker
        finally:
            # Run the cancellation on every exit path, including when the
            # managed body raises, instead of only after a clean exit.
            worker.cleanup()

0 commit comments

Comments
 (0)