given as a folder on local disk
"""

11+ import importlib
1112import os
13+ import threading
1214import time
15+ from concurrent .futures import FIRST_COMPLETED
16+ from concurrent .futures import ThreadPoolExecutor
17+ from concurrent .futures import wait
1318from typing import Dict
1419from typing import List
1520
@@ -79,6 +84,9 @@ def __init__(self, **kwargs):
7984 self .blob_list = {}
8085 self .rows_seen = 0
8186 self .blobs_seen = 0
87+ self ._stats_lock = threading .Lock ()
88+ cpu_count = os .cpu_count () or 1
89+ self ._max_workers = max (1 , min (8 , (cpu_count + 1 ) // 2 ))
8290
8391 def read_blob (
8492 self , * , blob_name : str , decoder , just_schema = False , projection = None , selection = None
@@ -113,7 +121,8 @@ def read_blob(
113121 OSError:
114122 If an I/O error occurs while reading the file.
115123 """
116- from opteryx .compiled .io .disk_reader import read_file_mmap
124+ disk_reader = importlib .import_module ("opteryx.compiled.io.disk_reader" )
125+ read_file_mmap = getattr (disk_reader , "read_file_mmap" )
117126
118127 # from opteryx.compiled.io.disk_reader import unmap_memory
119128 # Read using mmap for maximum speed
@@ -131,14 +140,17 @@ def read_blob(
131140 use_threads = True ,
132141 )
133142
134- self .statistics .bytes_read += len (mv )
143+ with self ._stats_lock :
144+ self .statistics .bytes_read += len (mv )
135145
136146 if not just_schema :
137147 stats = self .read_blob_statistics (
138148 blob_name = blob_name , blob_bytes = mv , decoder = decoder
139149 )
140- if self .relation_statistics is None :
141- self .relation_statistics = stats
150+ if stats is not None :
151+ with self ._stats_lock :
152+ if self .relation_statistics is None :
153+ self .relation_statistics = stats
142154
143155 return result
144156 finally :
@@ -200,54 +212,150 @@ def read_dataset(
200212 )
201213 self .statistics .time_pruning_blobs += time .monotonic_ns () - start
202214
203- remaining_rows = limit if limit is not None else float ("inf" )
204-
205- for blob_name in blob_names :
206- decoder = get_decoder (blob_name )
207- try :
208- if not just_schema :
209- num_rows , _ , raw_size , decoded = self .read_blob (
210- blob_name = blob_name ,
211- decoder = decoder ,
212- just_schema = False ,
213- projection = columns ,
214- selection = predicates ,
215- )
216-
217- # push limits to the reader
218- if decoded .num_rows > remaining_rows :
219- decoded = decoded .slice (0 , remaining_rows )
220- remaining_rows -= decoded .num_rows
221-
222- self .statistics .rows_seen += num_rows
223- self .rows_seen += num_rows
224- self .blobs_seen += 1
225- self .statistics .bytes_raw += raw_size
226- yield decoded
227-
228- # if we have read all the rows we need to stop
229- if remaining_rows <= 0 :
230- break
231- else :
215+ if just_schema :
216+ for blob_name in blob_names :
217+ try :
218+ decoder = get_decoder (blob_name )
232219 schema = self .read_blob (
233220 blob_name = blob_name ,
234221 decoder = decoder ,
235222 just_schema = True ,
236223 )
237- # if we have more than one blob we need to estimate the row count
238224 blob_count = len (blob_names )
239225 if schema .row_count_metric and blob_count > 1 :
240226 schema .row_count_estimate = schema .row_count_metric * blob_count
241227 schema .row_count_metric = None
242228 self .statistics .estimated_row_count += schema .row_count_estimate
243229 yield schema
230+ except UnsupportedFileTypeError :
231+ continue
232+ except pyarrow .ArrowInvalid :
233+ with self ._stats_lock :
234+ self .statistics .unreadable_data_blobs += 1
235+ except Exception as err :
236+ raise DataError (f"Unable to read file { blob_name } ({ err } )" ) from err
237+ return
244238
245- except UnsupportedFileTypeError :
246- pass # Skip unsupported file types
247- except pyarrow .ArrowInvalid :
248- self .statistics .unreadable_data_blobs += 1
249- except Exception as err :
250- raise DataError (f"Unable to read file { blob_name } ({ err } )" ) from err
239+ remaining_rows = limit if limit is not None else float ("inf" )
240+ last_morsel = None
241+
242+ def process_result (num_rows , raw_size , decoded ):
243+ nonlocal remaining_rows , last_morsel
244+ if decoded .num_rows > remaining_rows :
245+ decoded = decoded .slice (0 , remaining_rows )
246+ remaining_rows -= decoded .num_rows
247+
248+ self .statistics .rows_seen += num_rows
249+ self .rows_seen += num_rows
250+ self .blobs_seen += 1
251+ self .statistics .bytes_raw += raw_size
252+ last_morsel = decoded
253+ return decoded
254+
255+ max_workers = min (self ._max_workers , len (blob_names )) or 1
256+
257+ if max_workers <= 1 :
258+ for blob_name in blob_names :
259+ try :
260+ num_rows , _ , raw_size , decoded = self ._read_blob_task (
261+ blob_name ,
262+ columns ,
263+ predicates ,
264+ )
265+ except UnsupportedFileTypeError :
266+ continue
267+ except pyarrow .ArrowInvalid :
268+ with self ._stats_lock :
269+ self .statistics .unreadable_data_blobs += 1
270+ continue
271+ except Exception as err :
272+ raise DataError (f"Unable to read file { blob_name } ({ err } )" ) from err
273+
274+ if remaining_rows <= 0 :
275+ break
276+
277+ decoded = process_result (num_rows , raw_size , decoded )
278+ yield decoded
279+
280+ if remaining_rows <= 0 :
281+ break
282+ else :
283+ blob_iter = iter (blob_names )
284+ pending = {}
285+
286+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
287+ for _ in range (max_workers ):
288+ try :
289+ blob_name = next (blob_iter )
290+ except StopIteration :
291+ break
292+ future = executor .submit (
293+ self ._read_blob_task ,
294+ blob_name ,
295+ columns ,
296+ predicates ,
297+ )
298+ pending [future ] = blob_name
299+
300+ while pending :
301+ done , _ = wait (pending .keys (), return_when = FIRST_COMPLETED )
302+ for future in done :
303+ blob_name = pending .pop (future )
304+ try :
305+ num_rows , _ , raw_size , decoded = future .result ()
306+ except UnsupportedFileTypeError :
307+ pass
308+ except pyarrow .ArrowInvalid :
309+ with self ._stats_lock :
310+ self .statistics .unreadable_data_blobs += 1
311+ except Exception as err :
312+ for remaining_future in list (pending ):
313+ remaining_future .cancel ()
314+ raise DataError (f"Unable to read file { blob_name } ({ err } )" ) from err
315+ else :
316+ if remaining_rows > 0 :
317+ decoded = process_result (num_rows , raw_size , decoded )
318+ yield decoded
319+ if remaining_rows <= 0 :
320+ for remaining_future in list (pending ):
321+ remaining_future .cancel ()
322+ pending .clear ()
323+ break
324+
325+ if remaining_rows <= 0 :
326+ break
327+
328+ try :
329+ next_blob = next (blob_iter )
330+ except StopIteration :
331+ continue
332+ future = executor .submit (
333+ self ._read_blob_task ,
334+ next_blob ,
335+ columns ,
336+ predicates ,
337+ )
338+ pending [future ] = next_blob
339+
340+ if remaining_rows <= 0 :
341+ break
342+
343+ if last_morsel is not None :
344+ self .statistics .columns_read += last_morsel .num_columns
345+ elif columns :
346+ self .statistics .columns_read += len (columns )
347+ elif self .schema :
348+ self .statistics .columns_read += len (self .schema .columns )
349+
350+ def _read_blob_task (self , blob_name : str , columns , predicates ):
351+ decoder = get_decoder (blob_name )
352+ return self .read_blob (
353+ blob_name = blob_name ,
354+ decoder = decoder ,
355+ just_schema = False ,
356+ projection = columns ,
357+ selection = predicates ,
358+ )
251359
252360 def get_dataset_schema (self ) -> RelationSchema :
253361 """
0 commit comments