
Commit b696fbe

Merge pull request #2883 from mabel-dev/#2877
Performance Review
2 parents e75226f + 151aae4 commit b696fbe

7 files changed (+204, -122 lines)

opteryx/__version__.py

Lines changed: 2 additions & 2 deletions

@@ -1,9 +1,9 @@
 # THIS FILE IS AUTOMATICALLY UPDATED DURING THE BUILD PROCESS
 # DO NOT EDIT THIS FILE DIRECTLY

-__build__ = 1713
+__build__ = 1715
 __author__ = "@joocer"
-__version__ = "0.26.0-beta.1713"
+__version__ = "0.26.0-beta.1715"

 # Store the version here so:
 # 1) we don't load dependencies by storing it in __init__.py

opteryx/connectors/disk_connector.py

Lines changed: 142 additions & 40 deletions

@@ -8,8 +8,13 @@
 given as a folder on local disk
 """

+import importlib
 import os
+import threading
 import time
+from concurrent.futures import FIRST_COMPLETED
+from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import wait
 from typing import Dict
 from typing import List

@@ -79,6 +84,9 @@ def __init__(self, **kwargs):
         self.blob_list = {}
         self.rows_seen = 0
         self.blobs_seen = 0
+        self._stats_lock = threading.Lock()
+        cpu_count = os.cpu_count() or 1
+        self._max_workers = max(1, min(8, (cpu_count + 1) // 2))

     def read_blob(
         self, *, blob_name: str, decoder, just_schema=False, projection=None, selection=None
@@ -113,7 +121,8 @@ def read_blob(
         OSError:
             If an I/O error occurs while reading the file.
         """
-        from opteryx.compiled.io.disk_reader import read_file_mmap
+        disk_reader = importlib.import_module("opteryx.compiled.io.disk_reader")
+        read_file_mmap = getattr(disk_reader, "read_file_mmap")

         # from opteryx.compiled.io.disk_reader import unmap_memory
         # Read using mmap for maximum speed
@@ -131,14 +140,17 @@ def read_blob(
                 use_threads=True,
             )

-            self.statistics.bytes_read += len(mv)
+            with self._stats_lock:
+                self.statistics.bytes_read += len(mv)

             if not just_schema:
                 stats = self.read_blob_statistics(
                     blob_name=blob_name, blob_bytes=mv, decoder=decoder
                 )
-                if self.relation_statistics is None:
-                    self.relation_statistics = stats
+                if stats is not None:
+                    with self._stats_lock:
+                        if self.relation_statistics is None:
+                            self.relation_statistics = stats

             return result
         finally:
@@ -200,54 +212,144 @@ def read_dataset(
             )
             self.statistics.time_pruning_blobs += time.monotonic_ns() - start

-        remaining_rows = limit if limit is not None else float("inf")
-
-        for blob_name in blob_names:
-            decoder = get_decoder(blob_name)
-            try:
-                if not just_schema:
-                    num_rows, _, raw_size, decoded = self.read_blob(
-                        blob_name=blob_name,
-                        decoder=decoder,
-                        just_schema=False,
-                        projection=columns,
-                        selection=predicates,
-                    )
-
-                    # push limits to the reader
-                    if decoded.num_rows > remaining_rows:
-                        decoded = decoded.slice(0, remaining_rows)
-                    remaining_rows -= decoded.num_rows
-
-                    self.statistics.rows_seen += num_rows
-                    self.rows_seen += num_rows
-                    self.blobs_seen += 1
-                    self.statistics.bytes_raw += raw_size
-                    yield decoded
-
-                    # if we have read all the rows we need to stop
-                    if remaining_rows <= 0:
-                        break
-                else:
+        if just_schema:
+            for blob_name in blob_names:
+                try:
+                    decoder = get_decoder(blob_name)
                     schema = self.read_blob(
                         blob_name=blob_name,
                         decoder=decoder,
                         just_schema=True,
                     )
-                    # if we have more than one blob we need to estimate the row count
                     blob_count = len(blob_names)
                     if schema.row_count_metric and blob_count > 1:
                         schema.row_count_estimate = schema.row_count_metric * blob_count
                         schema.row_count_metric = None
                         self.statistics.estimated_row_count += schema.row_count_estimate
                     yield schema
+                except UnsupportedFileTypeError:
+                    continue
+                except pyarrow.ArrowInvalid:
+                    with self._stats_lock:
+                        self.statistics.unreadable_data_blobs += 1
+                except Exception as err:
+                    raise DataError(f"Unable to read file {blob_name} ({err})") from err
+            return

-            except UnsupportedFileTypeError:
-                pass  # Skip unsupported file types
-            except pyarrow.ArrowInvalid:
-                self.statistics.unreadable_data_blobs += 1
-            except Exception as err:
-                raise DataError(f"Unable to read file {blob_name} ({err})") from err
+        remaining_rows = limit if limit is not None else float("inf")
+
+        def process_result(num_rows, raw_size, decoded):
+            nonlocal remaining_rows
+            if decoded.num_rows > remaining_rows:
+                decoded = decoded.slice(0, remaining_rows)
+            remaining_rows -= decoded.num_rows
+
+            self.statistics.rows_seen += num_rows
+            self.rows_seen += num_rows
+            self.blobs_seen += 1
+            self.statistics.bytes_raw += raw_size
+            return decoded
+
+        max_workers = min(self._max_workers, len(blob_names)) or 1
+
+        if max_workers <= 1:
+            for blob_name in blob_names:
+                try:
+                    num_rows, _, raw_size, decoded = self._read_blob_task(
+                        blob_name,
+                        columns,
+                        predicates,
+                    )
+                except UnsupportedFileTypeError:
+                    continue
+                except pyarrow.ArrowInvalid:
+                    with self._stats_lock:
+                        self.statistics.unreadable_data_blobs += 1
+                    continue
+                except Exception as err:
+                    raise DataError(f"Unable to read file {blob_name} ({err})") from err
+
+                if remaining_rows <= 0:
+                    break
+
+                decoded = process_result(num_rows, raw_size, decoded)
+                yield decoded
+
+                if remaining_rows <= 0:
+                    break
+        else:
+            blob_iter = iter(blob_names)
+            pending = {}
+
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                for _ in range(max_workers):
+                    try:
+                        blob_name = next(blob_iter)
+                    except StopIteration:
+                        break
+                    future = executor.submit(
+                        self._read_blob_task,
+                        blob_name,
+                        columns,
+                        predicates,
+                    )
+                    pending[future] = blob_name
+
+                while pending:
+                    done, _ = wait(pending.keys(), return_when=FIRST_COMPLETED)
+                    for future in done:
+                        blob_name = pending.pop(future)
+                        try:
+                            num_rows, _, raw_size, decoded = future.result()
+                        except UnsupportedFileTypeError:
+                            pass
+                        except pyarrow.ArrowInvalid:
+                            with self._stats_lock:
+                                self.statistics.unreadable_data_blobs += 1
+                        except Exception as err:
+                            for remaining_future in list(pending):
+                                remaining_future.cancel()
+                            raise DataError(f"Unable to read file {blob_name} ({err})") from err
+                        else:
+                            if remaining_rows > 0:
+                                decoded = process_result(num_rows, raw_size, decoded)
+                                yield decoded
+                            if remaining_rows <= 0:
+                                for remaining_future in list(pending):
+                                    remaining_future.cancel()
+                                pending.clear()
+                                break
+
+                        if remaining_rows <= 0:
+                            break
+
+                        try:
+                            next_blob = next(blob_iter)
+                        except StopIteration:
+                            continue
+                        future = executor.submit(
+                            self._read_blob_task,
+                            next_blob,
+                            columns,
+                            predicates,
+                        )
+                        pending[future] = next_blob
+
+                    if remaining_rows <= 0:
+                        break
+
+        # column-level statistics are recorded by the read node after morsels
+        # leave connector-level accounting to avoid double counting
+
+    def _read_blob_task(self, blob_name: str, columns, predicates):
+        decoder = get_decoder(blob_name)
+        return self.read_blob(
+            blob_name=blob_name,
+            decoder=decoder,
+            just_schema=False,
+            projection=columns,
+            selection=predicates,
+        )

     def get_dataset_schema(self) -> RelationSchema:
         """

opteryx/operators/aggregate_and_group_node.py

Lines changed: 4 additions & 2 deletions

@@ -68,7 +68,9 @@ def __init__(self, properties: QueryProperties, **parameters):
         self.column_map, self.aggregate_functions = build_aggregations(self.aggregates)

         self.buffer = []
-        self.max_buffer_size = 250  # Buffer size before partial aggregation (kept for future parallelization)
+        self.max_buffer_size = (
+            250  # Buffer size before partial aggregation (kept for future parallelization)
+        )
         self._partial_aggregated = False  # Track if we've done a partial aggregation
         self._disable_partial_agg = False  # Can disable if partial agg isn't helping

@@ -231,7 +233,7 @@ def execute(self, morsel: pyarrow.Table, **kwargs):

         groups = table.group_by(self.group_by_columns)
         groups = groups.aggregate(self.aggregate_functions)
-
+
         # Check if partial aggregation is effective
         # If we're not reducing the row count significantly, stop doing partial aggs
         reduction_ratio = groups.num_rows / table.num_rows if table.num_rows > 0 else 1
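
The second hunk sits next to the partial-aggregation heuristic: after grouping the buffered morsel, the node compares the grouped row count to the input row count and backs off if grouping is not shrinking the data. A minimal sketch of that reduction-ratio check with pyarrow follows; the 0.8 cut-off and the sample table are illustrative, not values taken from this commit.

# Sketch of the reduction-ratio check: group a morsel, then decide whether
# partial aggregation is worth continuing. Threshold and data are illustrative.
import pyarrow

table = pyarrow.table({"key": ["a", "b", "a", "c", "b", "a"], "value": [1, 2, 3, 4, 5, 6]})

groups = table.group_by(["key"]).aggregate([("value", "sum")])

reduction_ratio = groups.num_rows / table.num_rows if table.num_rows > 0 else 1
disable_partial_agg = reduction_ratio > 0.8  # little benefit, pass rows straight through

print(groups)
print(f"reduction ratio: {reduction_ratio:.2f}, disable partial agg: {disable_partial_agg}")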
