2424from ._proto import ffi_pb2 as proto_ffi
2525from ._proto import room_pb2 as proto_room
2626from ._ffi_client import FfiClient
27-
27+ from . _utils import split_utf8
2828from typing import TYPE_CHECKING
2929
3030if TYPE_CHECKING :
@@ -47,14 +47,6 @@ class BaseStreamInfo:
4747@dataclass
4848class TextStreamInfo (BaseStreamInfo ):
4949 attachments : List [str ]
50- pass
51-
52-
53- @dataclass
54- class TextStreamUpdate :
55- current : str
56- index : int
57- collected : str
5850
5951
6052class TextStreamReader :
@@ -81,31 +73,24 @@ async def _on_chunk_update(self, chunk: proto_DataStream.Chunk):
8173 async def _on_stream_close (self , trailer : proto_DataStream .Trailer ):
8274 await self ._queue .put (None )
8375
84- def __aiter__ (self ) -> AsyncIterator [TextStreamUpdate ]:
76+ def __aiter__ (self ) -> AsyncIterator [str ]:
8577 return self
8678
87- async def __anext__ (self ) -> TextStreamUpdate :
79+ async def __anext__ (self ) -> str :
8880 item = await self ._queue .get ()
8981 if item is None :
9082 raise StopAsyncIteration
9183 decodedStr = item .content .decode ()
92-
93- self ._chunks [item .chunk_index ] = item
94- chunk_list = list (self ._chunks .values ())
95- chunk_list .sort (key = lambda chunk : chunk .chunk_index )
96- collected : str = "" .join (map (lambda chunk : chunk .content .decode (), chunk_list ))
97- return TextStreamUpdate (
98- current = decodedStr , index = item .chunk_index , collected = collected
99- )
84+ return decodedStr
10085
10186 @property
10287 def info (self ) -> TextStreamInfo :
10388 return self ._info
10489
10590 async def read_all (self ) -> str :
10691 final_string = ""
107- async for update in self :
108- final_string = update . collected
92+ async for chunk in self :
93+ final_string += chunk
10994 return final_string
11095
11196
@@ -286,20 +271,20 @@ def __init__(
286271 attributes = dict (self ._header .attributes ),
287272 attachments = list (self ._header .text_header .attached_stream_ids ),
288273 )
274+ self ._write_lock = asyncio .Lock ()
289275
290- async def write (self , text : str , chunk_index : int | None = None ):
291- content = text .encode ()
292- if len (content ) > STREAM_CHUNK_SIZE :
293- raise ValueError ("maximum chunk size exceeded" )
294- if chunk_index is None :
295- chunk_index = self ._next_chunk_index
296- self ._next_chunk_index += 1
297- chunk_msg = proto_DataStream .Chunk (
298- stream_id = self ._header .stream_id ,
299- chunk_index = chunk_index ,
300- content = content ,
301- )
302- await self ._send_chunk (chunk_msg )
276+ async def write (self , text : str ):
277+ async with self ._write_lock :
278+ for chunk in split_utf8 (text , STREAM_CHUNK_SIZE ):
279+ content = chunk .encode ()
280+ chunk_index = self ._next_chunk_index
281+ self ._next_chunk_index += 1
282+ chunk_msg = proto_DataStream .Chunk (
283+ stream_id = self ._header .stream_id ,
284+ chunk_index = chunk_index ,
285+ content = content ,
286+ )
287+ await self ._send_chunk (chunk_msg )
303288
304289 @property
305290 def info (self ) -> TextStreamInfo :
0 commit comments