
Commit 945862e

Merge pull request #2876 from mabel-dev/clickbench-performance-regression-investigation-1
fix agg bug
2 parents c64dcba + 6f24deb commit 945862e

9 files changed: +286 additions, -83 deletions

opteryx/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 
 import datetime
 import os
+import random
 import time
 import warnings
 import platform
@@ -34,7 +35,7 @@
 getcontext().prec = 28
 
 # end-of-stream marker
-EOS: int = 0
+EOS: int = random.randint(-(2**63), 2**63 - 1)
 
 
 def is_mac() -> bool:  # pragma: no cover
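
The EOS change swaps the fixed end-of-stream sentinel 0 for a random 64-bit integer chosen at import time, presumably so that a genuine data value of 0 can no longer be mistaken for the marker. A minimal sketch of the idea, assuming a toy drain() helper that is not part of Opteryx:

    import random

    # A random 64-bit sentinel is vanishingly unlikely to collide with real
    # payloads, unlike the previous fixed value of 0.
    EOS: int = random.randint(-(2**63), 2**63 - 1)

    def drain(stream):
        """Illustrative helper: yield items until the end-of-stream marker."""
        for item in stream:
            if item == EOS:
                return
            yield item

    # With EOS = 0, the literal 0 below would have ended the stream early.
    print(list(drain([5, 0, 7, EOS])))  # -> [5, 0, 7]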

opteryx/__version__.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 # THIS FILE IS AUTOMATICALLY UPDATED DURING THE BUILD PROCESS
 # DO NOT EDIT THIS FILE DIRECTLY
 
-__build__ = 1706
+__build__ = 1707
 __author__ = "@joocer"
-__version__ = "0.26.0-beta.1706"
+__version__ = "0.26.0-beta.1707"
 
 # Store the version here so:
 # 1) we don't load dependencies by storing it in __init__.py

opteryx/operators/aggregate_and_group_node.py

Lines changed: 116 additions & 28 deletions
@@ -68,7 +68,8 @@ def __init__(self, properties: QueryProperties, **parameters):
         self.column_map, self.aggregate_functions = build_aggregations(self.aggregates)
 
         self.buffer = []
-        self.max_buffer_size = 50  # Process in chunks to avoid excessive memory usage
+        self.max_buffer_size = 100  # Process in chunks to avoid excessive memory usage
+        self._partial_aggregated = False  # Track if we've done a partial aggregation
 
     @property
     def config(self):  # pragma: no cover
@@ -86,38 +87,122 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
                 yield EOS
                 return
 
-            # If we have partial results in buffer, do final aggregation
-            if len(self.buffer) > 0:
-                table = pyarrow.concat_tables(
-                    self.buffer,
-                    promote_options="permissive",
-                )
+            # Do final aggregation if we have buffered data
+            table = pyarrow.concat_tables(
+                self.buffer,
+                promote_options="permissive",
+            )
+            # Only combine chunks if we haven't done partial aggregation yet
+            # combine_chunks can fail after partial aggregation due to buffer structure
+            if not self._partial_aggregated:
                 table = table.combine_chunks()
+
+            # If we've done partial aggregations, the aggregate functions need adjusting
+            # because columns like "*" have been renamed to "*_count"
+            if self._partial_aggregated:
+                # Build new aggregate functions for re-aggregating partial results
+                adjusted_aggs = []
+                adjusted_column_map = {}
+
+                for field_name, function, _count_options in self.aggregate_functions:
+                    # For COUNT aggregates, the column is now named "*_count" and we need to SUM it
+                    if function == "count":
+                        renamed_field = f"{field_name}_count"
+                        adjusted_aggs.append((renamed_field, "sum", None))
+                        # The final column will be named "*_count_sum", need to track for renaming
+                        for orig_name, mapped_name in self.column_map.items():
+                            if mapped_name == f"{field_name}_count":
+                                adjusted_column_map[orig_name] = f"{renamed_field}_sum"
+                    # For other aggregates, we can re-aggregate with the same function
+                    else:
+                        renamed_field = f"{field_name}_{function}".replace("_hash_", "_")
+                        # Some aggregates can be re-aggregated (sum, max, min)
+                        if function in ("sum", "max", "min", "hash_one", "all", "any"):
+                            adjusted_aggs.append((renamed_field, function, None))
+                            # Track the mapping: original -> intermediate -> final
+                            for orig_name, mapped_name in self.column_map.items():
+                                if mapped_name == renamed_field:
+                                    # sum->sum, max->max, etc. means same name
+                                    adjusted_column_map[orig_name] = (
+                                        f"{renamed_field}_{function}".replace("_hash_", "_")
+                                    )
+                        elif function == "mean":
+                            # For mean, just take one of the existing values (not ideal)
+                            adjusted_aggs.append((renamed_field, "hash_one", None))
+                            for orig_name, mapped_name in self.column_map.items():
+                                if mapped_name == renamed_field:
+                                    adjusted_column_map[orig_name] = f"{renamed_field}_one"
+                        elif function == "hash_list":
+                            # For ARRAY_AGG, we need to flatten lists
+                            adjusted_aggs.append((renamed_field, "hash_list", None))
+                            for orig_name, mapped_name in self.column_map.items():
+                                if mapped_name == renamed_field:
+                                    adjusted_column_map[orig_name] = f"{renamed_field}_list"
+                        else:
+                            # For other aggregates, take one value
+                            adjusted_aggs.append((renamed_field, "hash_one", None))
+                            for orig_name, mapped_name in self.column_map.items():
+                                if mapped_name == renamed_field:
+                                    adjusted_column_map[orig_name] = f"{renamed_field}_one"
+
+                groups = table.group_by(self.group_by_columns)
+                groups = groups.aggregate(adjusted_aggs)
+
+                # Use the adjusted column map for selecting/renaming
+                groups = groups.select(list(adjusted_column_map.values()) + self.group_by_columns)
+                groups = groups.rename_columns(
+                    list(adjusted_column_map.keys()) + self.group_by_columns
+                )
+            else:
                 groups = table.group_by(self.group_by_columns)
                 groups = groups.aggregate(self.aggregate_functions)
-                self.buffer = [groups]  # Replace buffer with final result
-
-            # Now buffer has the final aggregated result
-            groups = self.buffer[0]
-
-            # do the secondary activities for ARRAY_AGG
-            for node in get_all_nodes_of_type(self.aggregates, select_nodes=(NodeType.AGGREGATOR,)):
-                if node.value == "ARRAY_AGG" and node.order or node.limit:
-                    # rip the column out of the table
-                    column_name = self.column_map[node.schema_column.identity]
-                    column_def = groups.field(column_name)  # this is used
-                    column = groups.column(column_name).to_pylist()
-                    groups = groups.drop([column_name])
+
+                # project to the desired column names from the pyarrow names
+                groups = groups.select(list(self.column_map.values()) + self.group_by_columns)
+                groups = groups.rename_columns(list(self.column_map.keys()) + self.group_by_columns)
+
+            # do the secondary activities for ARRAY_AGG (order and limit)
+            array_agg_nodes = [
+                node
+                for node in get_all_nodes_of_type(
+                    self.aggregates, select_nodes=(NodeType.AGGREGATOR,)
+                )
+                if node.value == "ARRAY_AGG" and (node.order or node.limit)
+            ]
+
+            if array_agg_nodes:
+                # Process all ARRAY_AGG columns that need ordering/limiting
+                arrays_to_update = {}
+                field_defs = {}
+
+                for node in array_agg_nodes:
+                    column_name = node.schema_column.identity
+
+                    # Store field definition before we drop the column
+                    field_defs[column_name] = groups.field(column_name)
+
+                    # Extract and process the data
+                    column_data = groups.column(column_name).to_pylist()
+
+                    # Apply ordering if specified
                     if node.order:
-                        column = [sorted(c, reverse=bool(node.order[0][1])) for c in column]
+                        column_data = [
+                            sorted(c, reverse=bool(node.order[0][1])) for c in column_data
+                        ]
+
+                    # Apply limit if specified
                     if node.limit:
-                        column = [c[: node.limit] for c in column]
-                    # put the new column into the table
-                    groups = groups.append_column(column_def, [column])
+                        column_data = [c[: node.limit] for c in column_data]
+
+                    arrays_to_update[column_name] = column_data
+
+                # Drop all columns we're updating
+                columns_to_drop = list(arrays_to_update.keys())
+                groups = groups.drop(columns_to_drop)
 
-            # project to the desired column names from the pyarrow names
-            groups = groups.select(list(self.column_map.values()) + self.group_by_columns)
-            groups = groups.rename_columns(list(self.column_map.keys()) + self.group_by_columns)
+                # Append all updated columns back
+                for column_name, column_data in arrays_to_update.items():
+                    groups = groups.append_column(field_defs[column_name], [column_data])
 
             num_rows = groups.num_rows
             for start in range(0, num_rows, CHUNK_SIZE):
@@ -128,9 +213,10 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
 
         morsel = project(morsel, self.all_identifiers)
         # Add a "*" column, this is an int because when a bool it miscounts
+        # FIX: Use int8 as the comment states (bool can miscount)
        if "*" not in morsel.column_names:
             morsel = morsel.append_column(
-                "*", [numpy.ones(shape=morsel.num_rows, dtype=numpy.bool_)]
+                "*", [numpy.ones(shape=morsel.num_rows, dtype=numpy.int8)]
             )
         if self.evaluatable_nodes:
             morsel = evaluate_and_append(self.evaluatable_nodes, morsel)
@@ -144,9 +230,11 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
                 self.buffer,
                 promote_options="permissive",
             )
+            # Only combine chunks once before aggregation
             table = table.combine_chunks()
             groups = table.group_by(self.group_by_columns)
             groups = groups.aggregate(self.aggregate_functions)
             self.buffer = [groups]  # Replace buffer with partial result
+            self._partial_aggregated = True  # Mark that we've done a partial aggregation
 
         yield None
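
The heart of the aggregation fix is that buffers which have already been partially aggregated cannot simply be re-aggregated with the original functions: a partial COUNT is already materialised in a column such as "*_count", and counting it again would count groups rather than rows, so the final pass SUMs the partial counts instead. A minimal sketch of that two-phase pattern using plain PyArrow (the table and column names are illustrative, not Opteryx's internals):

    import pyarrow

    # Two morsels that arrive separately and are partially aggregated first.
    morsel_1 = pyarrow.table({"k": ["a", "a", "b"], "v": [1, 2, 3]})
    morsel_2 = pyarrow.table({"k": ["a", "b", "b"], "v": [4, 5, 6]})

    # Phase 1: partial aggregation per morsel yields k, v_sum and v_count columns.
    partials = [
        m.group_by(["k"]).aggregate([("v", "sum"), ("v", "count")])
        for m in (morsel_1, morsel_2)
    ]

    # Phase 2: re-aggregate the partials; a partial SUM can be summed again,
    # but the partial COUNT must also be SUMmed (re-counting would count groups).
    combined = pyarrow.concat_tables(partials)
    final = combined.group_by(["k"]).aggregate([("v_sum", "sum"), ("v_count", "sum")])
    print(final.sort_by("k").to_pydict())
    # k=a: v_sum_sum=7, v_count_sum=3; k=b: v_sum_sum=14, v_count_sum=3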

opteryx/operators/heap_sort_node.py

Lines changed: 3 additions & 1 deletion
@@ -108,7 +108,9 @@ def _sort_and_slice(self, table: pyarrow.Table) -> pyarrow.Table:
                 indices = pyarrow.compute.sort_indices(column)
             else:
                 indices = pyarrow.compute.sort_indices(column)[::-1]
-            return table.take(indices.slice(0, self.limit))
+            # Take min of limit and available indices to avoid index errors
+            take_count = min(self.limit, len(indices))
+            return table.take(indices.slice(0, take_count))
 
         np_column = column.to_numpy()
         if use_decimal:
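
The heap-sort change clamps the slice length to the number of indices actually available, so a LIMIT larger than the remaining rows can no longer ask take() for positions that do not exist. A standalone sketch of the clamping (the column name and limit are illustrative):

    import pyarrow
    import pyarrow.compute

    table = pyarrow.table({"x": [30, 10, 20]})
    limit = 10  # LIMIT larger than the number of rows in the table

    indices = pyarrow.compute.sort_indices(table["x"])
    take_count = min(limit, len(indices))  # clamp to what is actually there
    print(table.take(indices.slice(0, take_count)).to_pydict())  # {'x': [10, 20, 30]}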

opteryx/utils/file_decoders.py

Lines changed: 10 additions & 23 deletions
@@ -174,20 +174,12 @@ def zstd_decoder(
 
     import zstandard
 
-    # zstandard.open expects a file-like; we open on a BytesIO constructed from
-    # the provided buffer and then pass the decompressed bytes as a memoryview
-    if isinstance(buffer, memoryview):
-        buf_bytes = buffer.tobytes()
-    elif isinstance(buffer, bytes):
-        buf_bytes = buffer
-    else:
-        # fallback, try to read
-        try:
-            buf_bytes = buffer.read()
-        except Exception:
-            buf_bytes = bytes(buffer)
+    # zstandard.open expects a file-like
+    if not isinstance(buffer, memoryview):
+        buffer = memoryview(buffer)
+    buffer = MemoryViewStream(buffer)
 
-    with zstandard.open(io.BytesIO(buf_bytes), "rb") as file:
+    with zstandard.open(buffer, "rb") as file:
         decompressed = file.read()
     return jsonl_decoder(
         memoryview(decompressed),
@@ -215,17 +207,12 @@ def lzma_decoder(
     import lzma
 
     # similar to zstd path: read bytes and pass decompressed data as memoryview
-    if isinstance(buffer, memoryview):
-        buf_bytes = buffer.tobytes()
-    elif isinstance(buffer, bytes):
-        buf_bytes = buffer
-    else:
-        try:
-            buf_bytes = buffer.read()
-        except Exception:
-            buf_bytes = bytes(buffer)
+    # zstandard.open expects a file-like
+    if not isinstance(buffer, memoryview):
+        buffer = memoryview(buffer)
+    buffer = MemoryViewStream(buffer)
 
-    with lzma.open(io.BytesIO(buf_bytes), "rb") as file:
+    with lzma.open(buffer, "rb") as file:
         decompressed = file.read()
     return jsonl_decoder(
         memoryview(decompressed),
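
Both decoders now wrap the buffer in MemoryViewStream (an Opteryx utility referenced in the diff) rather than copying it into a BytesIO before handing it to zstandard.open or lzma.open, which only need a readable file-like object. A rough sketch of what such a wrapper can look like, as a simplified stand-in rather than Opteryx's actual MemoryViewStream:

    import io
    import lzma

    class SimpleMemoryViewStream(io.RawIOBase):
        """Read-only file-like view over an existing buffer (no up-front copy)."""

        def __init__(self, view: memoryview):
            self._view = view
            self._pos = 0

        def readable(self) -> bool:
            return True

        def readinto(self, b) -> int:
            n = min(len(b), len(self._view) - self._pos)
            b[:n] = self._view[self._pos : self._pos + n]
            self._pos += n
            return n

    # lzma.open (like zstandard.open) accepts any readable file-like object.
    compressed = lzma.compress(b'{"a": 1}\n{"a": 2}\n')
    with lzma.open(SimpleMemoryViewStream(memoryview(compressed)), "rb") as f:
        print(f.read())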

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "opteryx"
-version = "0.26.0-beta.1706"
+version = "0.26.0-beta.1707"
 description = "Query your data, where it lives"
 requires-python = '>=3.11'
 readme = {file = "README.md", content-type = "text/markdown"}

tests/fuzzing/test_sql_fuzzer_compare_engines.py

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def test_sql_fuzzing_connector_comparisons(i):
 
     try:
         duck_statement = statement.replace(table_name, table["duckdb_name"])
-        duck_result = conn.query(duck_statement).arrow()
+        duck_result = conn.query(duck_statement).arrow().read_all()
         opteryx_statement = statement.replace(table_name, table["opteryx_name"])
        opteryx_result = opteryx.query(opteryx_statement).arrow()
 
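
Chaining read_all() onto the DuckDB result suggests that, in the DuckDB version the fuzzer runs against, arrow() hands back a streaming pyarrow.RecordBatchReader rather than a materialised table; read_all() collapses the stream into a pyarrow.Table so the two engines' results can be compared like-for-like. A small illustration of that materialisation step using plain PyArrow (no DuckDB involved):

    import pyarrow

    expected = pyarrow.table({"x": [1, 2, 3]})

    # Stand-in for a streaming query result.
    reader = pyarrow.RecordBatchReader.from_batches(expected.schema, expected.to_batches())

    materialised = reader.read_all()  # a concrete pyarrow.Table
    assert materialised.equals(expected)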
