Skip to content

Commit a766014

Browse files
authored
Add auto chunking for text stream writer (#361)
1 parent fe1d072 commit a766014

File tree

3 files changed

+21
-39
lines changed

3 files changed

+21
-39
lines changed

livekit-rtc/livekit/rtc/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,6 @@
8181
from .synchronizer import AVSynchronizer
8282
from .data_stream import (
8383
TextStreamInfo,
84-
TextStreamUpdate,
8584
ByteStreamInfo,
8685
TextStreamReader,
8786
TextStreamWriter,
@@ -155,7 +154,6 @@
155154
"EventEmitter",
156155
"combine_audio_frames",
157156
"AVSynchronizer",
158-
"TextStreamUpdate",
159157
"TextStreamInfo",
160158
"ByteStreamInfo",
161159
"TextStreamReader",

livekit-rtc/livekit/rtc/data_stream.py

Lines changed: 19 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from ._proto import ffi_pb2 as proto_ffi
2525
from ._proto import room_pb2 as proto_room
2626
from ._ffi_client import FfiClient
27-
27+
from ._utils import split_utf8
2828
from typing import TYPE_CHECKING
2929

3030
if TYPE_CHECKING:
@@ -47,14 +47,6 @@ class BaseStreamInfo:
4747
@dataclass
4848
class TextStreamInfo(BaseStreamInfo):
4949
attachments: List[str]
50-
pass
51-
52-
53-
@dataclass
54-
class TextStreamUpdate:
55-
current: str
56-
index: int
57-
collected: str
5850

5951

6052
class TextStreamReader:
@@ -81,31 +73,24 @@ async def _on_chunk_update(self, chunk: proto_DataStream.Chunk):
8173
async def _on_stream_close(self, trailer: proto_DataStream.Trailer):
8274
await self._queue.put(None)
8375

84-
def __aiter__(self) -> AsyncIterator[TextStreamUpdate]:
76+
def __aiter__(self) -> AsyncIterator[str]:
8577
return self
8678

87-
async def __anext__(self) -> TextStreamUpdate:
79+
async def __anext__(self) -> str:
8880
item = await self._queue.get()
8981
if item is None:
9082
raise StopAsyncIteration
9183
decodedStr = item.content.decode()
92-
93-
self._chunks[item.chunk_index] = item
94-
chunk_list = list(self._chunks.values())
95-
chunk_list.sort(key=lambda chunk: chunk.chunk_index)
96-
collected: str = "".join(map(lambda chunk: chunk.content.decode(), chunk_list))
97-
return TextStreamUpdate(
98-
current=decodedStr, index=item.chunk_index, collected=collected
99-
)
84+
return decodedStr
10085

10186
@property
10287
def info(self) -> TextStreamInfo:
10388
return self._info
10489

10590
async def read_all(self) -> str:
10691
final_string = ""
107-
async for update in self:
108-
final_string = update.collected
92+
async for chunk in self:
93+
final_string += chunk
10994
return final_string
11095

11196

@@ -286,20 +271,20 @@ def __init__(
286271
attributes=dict(self._header.attributes),
287272
attachments=list(self._header.text_header.attached_stream_ids),
288273
)
274+
self._write_lock = asyncio.Lock()
289275

290-
async def write(self, text: str, chunk_index: int | None = None):
291-
content = text.encode()
292-
if len(content) > STREAM_CHUNK_SIZE:
293-
raise ValueError("maximum chunk size exceeded")
294-
if chunk_index is None:
295-
chunk_index = self._next_chunk_index
296-
self._next_chunk_index += 1
297-
chunk_msg = proto_DataStream.Chunk(
298-
stream_id=self._header.stream_id,
299-
chunk_index=chunk_index,
300-
content=content,
301-
)
302-
await self._send_chunk(chunk_msg)
276+
async def write(self, text: str):
277+
async with self._write_lock:
278+
for chunk in split_utf8(text, STREAM_CHUNK_SIZE):
279+
content = chunk.encode()
280+
chunk_index = self._next_chunk_index
281+
self._next_chunk_index += 1
282+
chunk_msg = proto_DataStream.Chunk(
283+
stream_id=self._header.stream_id,
284+
chunk_index=chunk_index,
285+
content=content,
286+
)
287+
await self._send_chunk(chunk_msg)
303288

304289
@property
305290
def info(self) -> TextStreamInfo:

livekit-rtc/livekit/rtc/participant.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
from ._proto.track_pb2 import (
3535
ParticipantTrackPermission,
3636
)
37-
from ._utils import BroadcastQueue, split_utf8
37+
from ._utils import BroadcastQueue
3838
from .track import LocalTrack
3939
from .track_publication import (
4040
LocalTrackPublication,
@@ -623,8 +623,7 @@ async def send_text(
623623
total_size=total_size,
624624
)
625625

626-
for chunk in split_utf8(text, STREAM_CHUNK_SIZE):
627-
await writer.write(chunk)
626+
await writer.write(text)
628627
await writer.aclose()
629628

630629
return writer.info

0 commit comments

Comments
 (0)