
Commit 7fbd78c

stream_util changes

Author: David Eigen (committed)
1 parent: 9f18f9e

File tree

3 files changed: +103 −70 lines

clarifai/runners/models/model_servicer.py
clarifai/runners/utils/stream_utils.py
clarifai/runners/utils/url_fetcher.py

clarifai/runners/models/model_servicer.py

Lines changed: 3 additions & 2 deletions
@@ -3,7 +3,8 @@
 from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
 from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
 
-from ..utils.url_fetcher import ensure_urls_downloaded, map_stream
+from ..utils.stream_utils import readahead
+from ..utils.url_fetcher import ensure_urls_downloaded
 
 
 class ModelServicer(service_pb2_grpc.V2Servicer):
@@ -68,7 +69,7 @@ def StreamModelOutputs(self,
 
     # Download any urls that are not already bytes.
     def _download_urls_stream(requests):
-      yield from map_stream(ensure_urls_downloaded, requests)
+      return readahead(map(ensure_urls_downloaded, requests))
 
     try:
       return self.model_class.stream(_download_urls_stream(request))
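A minimal sketch of the pattern this change introduces: the builtin map is lazy, so ensure_urls_downloaded runs as items are pulled, and readahead (added to stream_utils in this commit) pulls one item ahead on a background thread, overlapping URL downloads with model execution. ensure_downloaded below is a hypothetical stand-in:

import time

from clarifai.runners.utils.stream_utils import readahead

def ensure_downloaded(request):  # hypothetical stand-in for ensure_urls_downloaded
  time.sleep(0.1)  # simulate URL fetch latency
  return request

requests = iter(['req-1', 'req-2', 'req-3'])
for req in readahead(map(ensure_downloaded, requests)):
  print(req)  # while this item is handled, the next one is already downloading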

clarifai/runners/utils/stream_utils.py

Lines changed: 91 additions & 51 deletions
@@ -1,27 +1,31 @@
 import io
+import queue
 
-import requests
+import threading
+from concurrent.futures import ThreadPoolExecutor
 
 MB = 1024 * 1024
 
 
-class BufferStream(io.RawIOBase):
+class StreamingChunksReader(io.RawIOBase):
   '''
-  A buffer that reads data from a chunked stream and provides a file-like interface for reading.
+  A buffered reader that reads data from an iterator yielding chunks of bytes, used
+  to provide file-like access to a streaming data source.
 
-  :param chunk_iterator: An iterator that yields chunks of data (bytes)
-  '''
+  :param chunk_iterator: An iterator that yields chunks of data (bytes)
+  '''
 
   def __init__(self, chunk_iterator):
+    """
+    Args:
+        chunk_iterator (iterator): An iterator that yields chunks of bytes.
+    """
     self._chunk_iterator = chunk_iterator
     self.response = None
     self.buffer = b''
-    self.file_pos = 0
     self.b_pos = 0
     self._eof = False
 
-  #### read() methods
-
   def readable(self):
     return True
 
@@ -36,7 +40,7 @@ def readinto(self, output_buf):
       self.b_pos = 0
 
     # copy data to output buffer
-    n = min(len(output_buf), len(self.buffer - self.b_pos))
+    n = min(len(output_buf), len(self.buffer) - self.b_pos)
     assert n > 0
 
     output_buf[:n] = self.buffer[self.b_pos:self.b_pos + n]
@@ -52,16 +56,21 @@ def readinto(self, output_buf):
     return 0
 
 
-class SeekableBufferStream(io.RawIOBase):
-  '''
-  EXPERIMENTAL
-  A buffer that reads data from a chunked stream and provides a file-like interface for reading.
+class SeekableStreamingChunksReader(io.RawIOBase):
+  """
+  A buffered reader that reads data from an iterator yielding chunks of bytes, used
+  to provide file-like access to a streaming data source.
 
-  :param chunk_iterator: An iterator that yields chunks of data (bytes)
-  :param buffer_size: The maximum size of the buffer in bytes
-  '''
+  This class supports limited seeking to positions within the stream, by buffering
+  chunks internally and supporting basic seek operations within the buffer.
+  """
 
   def __init__(self, chunk_iterator, buffer_size=100 * MB):
+    """
+    Args:
+        chunk_iterator (iterator): An iterator that yields chunks of bytes.
+        buffer_size (int): Maximum buffer size in bytes before old chunks are discarded.
+    """
     self._chunk_iterator = chunk_iterator
     self.buffer_size = buffer_size
     self.buffer_vec = []
@@ -76,6 +85,15 @@ def readable(self):
     return True
 
   def readinto(self, output_buf):
+    """
+    Read data into the given buffer.
+
+    Args:
+        output_buf (bytearray): Buffer to read data into.
+
+    Returns:
+        int: Number of bytes read.
+    """
     if self._eof:
       return 0
 
@@ -107,7 +125,7 @@ def readinto(self, output_buf):
   def _load_next_chunk(self, check_bounds=True):
     self.buffer_vec.append(next(self._chunk_iterator))
     total = sum(len(chunk) for chunk in self.buffer_vec)
-    while total > self.buffer_size:
+    while total > self.buffer_size and len(self.buffer_vec) > 1:  # keep at least the last chunk
       chunk = self.buffer_vec.pop(0)
       total -= len(chunk)
       self.vec_pos -= 1
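The new "len(self.buffer_vec) > 1" guard matters when a single chunk is larger than buffer_size: without it, the eviction loop would pop the chunk that was just loaded and reads could never make progress. A toy illustration with illustrative sizes (not from the commit):

from clarifai.runners.utils.stream_utils import SeekableStreamingChunksReader

# a 10-byte chunk under a 5-byte buffer_size is still retained
reader = SeekableStreamingChunksReader(iter([b'0123456789']), buffer_size=5)
buf = bytearray(4)
n = reader.readinto(buf)
print(n, bytes(buf[:n]))  # expected: 4 b'0123'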
@@ -123,15 +141,27 @@ def tell(self):
     return self.file_pos
 
   def seek(self, offset, whence=io.SEEK_SET):
-    #printerr(f"seek(offset={offset}, whence={('SET', 'CUR', 'END')[whence]})")
-    # convert to offset from start of file stream
+    """
+    Seek to a new position in the buffered stream.
+
+    Args:
+        offset (int): The offset to seek to.
+        whence (int): The reference position (SEEK_SET, SEEK_CUR).
+            SEEK_END is not supported.
+
+    Returns:
+        int: The new file position.
+
+    Raises:
+        ValueError: If an invalid `whence` value is provided.
+        IOError: If seeking before the start of the buffer.
+    """
     if whence == io.SEEK_SET:
       seek_pos = offset
     elif whence == io.SEEK_CUR:
      seek_pos = self.file_pos + offset
     elif whence == io.SEEK_END:
-      self._seek_to_end()
-      seek_pos = self.file_pos + offset
+      raise ValueError('SEEK_END is not supported')
     else:
       raise ValueError(f"Invalid whence: {whence}")
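A short sketch of the resulting seek() contract on illustrative in-memory data: absolute and relative seeks are served from the buffered chunks, while SEEK_END now raises instead of draining the stream:

import io

from clarifai.runners.utils.stream_utils import SeekableStreamingChunksReader

reader = SeekableStreamingChunksReader(iter([b'hello ', b'world']))
reader.seek(4)                # SEEK_SET: absolute position in the stream
reader.seek(-2, io.SEEK_CUR)  # relative seek within the buffer
print(reader.tell())          # expected: 2
try:
  reader.seek(0, io.SEEK_END)
except ValueError as err:
  print(err)  # SEEK_END is not supported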

@@ -163,37 +193,47 @@ def seek(self, offset, whence=io.SEEK_SET):
 
     return self.file_pos
 
-  def _seek_to_end(self):
-    try:
-      # skip positions to end of the current buffer vec
-      if self.b_pos > 0:
-        self.file_pos += len(self.buffer_vec[self.vec_pos]) - self.b_pos
-        self.vec_pos += 1
-        self.b_pos = 0
-      # keep loading chunks until EOF
-      while True:
-        while self.vec_pos < len(self.buffer_vec):
-          self.file_pos += len(self.buffer_vec[self.vec_pos])
-          self.vec_pos += 1
-        self._load_next_chunk(check_bounds=False)
-    except StopIteration:
-      pass
-    # advance to end of buffer vec
-    while self.vec_pos < len(self.buffer_vec):
-      self.file_pos += len(self.buffer_vec[self.vec_pos])
-      self.vec_pos += 1
 
+def readahead(iterator, n=1, daemon=True):
+  """
+  Iterator wrapper that reads ahead from the underlying iterator, using a background thread.
+
+  Args:
+      iterator (iterator): The iterator to read from.
+      n (int): The maximum number of items to read ahead.
+      daemon (bool): Whether the background thread should be a daemon thread.
+  """
+  q = queue.Queue(maxsize=n)
+  _sentinel = object()
+
+  def _read():
+    for x in iterator:
+      q.put(x)
+    q.put(_sentinel)
 
-class URLStream(BufferStream):
+  t = threading.Thread(target=_read, daemon=daemon)
+  t.start()
+  while True:
+    x = q.get()
+    if x is _sentinel:
+      break
+    yield x
 
-  def __init__(self, url, chunk_size=1 * MB, buffer_size=10 * MB, requests_kwargs={}):
-    self.url = url
-    self.chunk_size = chunk_size
-    self.response = requests.get(self.url, stream=True, **requests_kwargs)
-    self.response.raise_for_status()
-    super().__init__(
-        self.response.iter_content(chunk_size=self.chunk_size), buffer_size=buffer_size)
 
-  def close(self):
-    super().close()
-    self.response.close()
+def map(f, iterator, parallel=1):
+  '''
+  Apply a function to each item in an iterator, optionally using multiple threads.
+  Similar to the built-in `map` function, but with support for parallel execution.
+  '''
+  if parallel < 1:
+    # the builtin `map` is shadowed by this function, so fall back to a plain
+    # generator expression for the sequential case
+    yield from (f(x) for x in iterator)
+    return
+  with ThreadPoolExecutor(max_workers=parallel) as executor:
+    futures = []
+    for i in range(parallel):
+      futures.append(executor.submit(f, next(iterator)))
+    for r in iterator:
+      res = futures.pop(0).result()
+      futures.append(executor.submit(f, r))  # start computing next result before yielding this one
+      yield res
+    for f in futures:
+      yield f.result()
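The two new helpers compose: map keeps up to parallel calls in flight while preserving input order, and readahead decouples the consumer from the producer through a bounded queue. A hedged end-to-end sketch with a stand-in for I/O-bound work:

import time

from clarifai.runners.utils import stream_utils

def slow_double(x):
  time.sleep(0.05)  # stand-in for I/O-bound work such as a URL fetch
  return 2 * x

items = iter(range(8))
# up to 4 calls run concurrently; results still arrive in input order
for y in stream_utils.readahead(stream_utils.map(slow_double, items, parallel=4)):
  print(y)  # 0, 2, 4, ..., 14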

clarifai/runners/utils/url_fetcher.py

Lines changed: 9 additions & 17 deletions
@@ -1,7 +1,9 @@
 import concurrent.futures
+from typing import Iterable
 
 import fsspec
 
+from clarifai.runners.utils import MB
 from clarifai.utils.logging import logger
 
 
@@ -50,20 +52,10 @@ def ensure_urls_downloaded(request, max_threads=128):
   return request
 
 
-def map_stream(f, it, parallel=1):
-  '''
-  Applies f to each element of it, yielding the results in order.
-  If parallel >= 1, uses a ThreadPoolExecutor to apply f in parallel to the current thread.
-  '''
-  if parallel < 1:
-    return map(f, it)
-  with ThreadPoolExecutor(max_workers=parallel) as executor:
-    futures = []
-    for i in range(parallel):
-      futures.append(executor.submit(f, next(it)))
-    for r in it:
-      res = futures.pop(0).result()
-      futures.append(executor.submit(f, r))  # start computing next result before yielding this one
-      yield res
-    for f in futures:
-      yield f.result()
+def stream_url(url: str, chunk_size: int = 1 * MB) -> Iterable[bytes]:
+  """
+  Opens a stream of byte chunks from a URL.
+  """
+  # block_size=0 means that the file is streamed
+  with fsspec.open(url, 'rb', block_size=0) as f:
+    yield from iter(lambda: f.read(chunk_size), b'')
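stream_url pairs with the readers in stream_utils to give file-like access to remote data without a full download; a sketch with an illustrative URL:

from clarifai.runners.utils.stream_utils import SeekableStreamingChunksReader
from clarifai.runners.utils.url_fetcher import stream_url

# stream the remote file in 1 MB chunks; read only the first 64 bytes
reader = SeekableStreamingChunksReader(stream_url('https://example.com/data.bin'))
header = bytearray(64)
n = reader.readinto(header)
print(n, bytes(header[:n]))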
