Skip to content

Commit c67656e

Browse files
authored
Merge pull request #2884 from mabel-dev/#2877
limit pushdown
2 parents b696fbe + b710794 commit c67656e

File tree

5 files changed

+220
-41
lines changed

5 files changed

+220
-41
lines changed

opteryx/__version__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# THIS FILE IS AUTOMATICALLY UPDATED DURING THE BUILD PROCESS
22
# DO NOT EDIT THIS FILE DIRECTLY
33

4-
__build__ = 1715
4+
__build__ = 1716
55
__author__ = "@joocer"
6-
__version__ = "0.26.0-beta.1715"
6+
__version__ = "0.26.0-beta.1716"
77

88
# Store the version here so:
99
# 1) we don't load dependencies by storing it in __init__.py

opteryx/planner/optimizer/strategies/limit_pushdown.py

Lines changed: 148 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
We try to push the limit to the other side of PROJECTS
1313
"""
1414

15+
from typing import Optional
16+
from typing import Set
17+
1518
from opteryx.connectors.capabilities import LimitPushable
1619
from opteryx.planner.logical_planner import LogicalPlan
1720
from opteryx.planner.logical_planner import LogicalPlanNode
@@ -23,55 +26,165 @@
2326

2427

2528
class LimitPushdownStrategy(OptimizationStrategy):
29+
"""Push LIMIT operators towards scans when it is safe to do so."""
30+
31+
_BARRIER_TYPES = {
32+
LogicalPlanStepType.Aggregate,
33+
LogicalPlanStepType.AggregateAndGroup,
34+
LogicalPlanStepType.Distinct,
35+
LogicalPlanStepType.Filter,
36+
LogicalPlanStepType.FunctionDataset,
37+
LogicalPlanStepType.HeapSort,
38+
LogicalPlanStepType.Limit,
39+
LogicalPlanStepType.MetadataWriter,
40+
LogicalPlanStepType.Order,
41+
LogicalPlanStepType.Set,
42+
LogicalPlanStepType.Union,
43+
}
44+
2645
def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerContext:
2746
if not context.optimized_plan:
28-
context.optimized_plan = context.pre_optimized_tree.copy() # type: ignore
47+
context.optimized_plan = context.pre_optimized_tree.copy() # type: ignore[arg-type]
2948

3049
if node.node_type == LogicalPlanStepType.Limit:
31-
if node.offset is not None:
32-
# we can't push down limits with offset
50+
if node.offset is not None or node.limit in (None, 0):
3351
return context
3452
node.nid = context.node_id
53+
if not hasattr(node, "pushdown_targets"):
54+
node.pushdown_targets = set(node.all_relations or [])
3555
context.collected_limits.append(node)
3656
return context
3757

38-
if (
39-
node.node_type == LogicalPlanStepType.Scan
40-
and LimitPushable in node.connector.__class__.mro()
41-
):
42-
for limit_node in context.collected_limits:
43-
if node.relation in limit_node.all_relations:
44-
self.statistics.optimization_limit_pushdown += 1
45-
context.optimized_plan.remove_node(limit_node.nid, heal=True)
46-
node.limit = limit_node.limit
47-
context.optimized_plan[context.node_id] = node
48-
elif node.node_type in (
49-
LogicalPlanStepType.Aggregate,
50-
LogicalPlanStepType.AggregateAndGroup,
51-
LogicalPlanStepType.Distinct,
52-
LogicalPlanStepType.Filter,
53-
LogicalPlanStepType.Join,
54-
LogicalPlanStepType.Order,
55-
LogicalPlanStepType.Union,
56-
LogicalPlanStepType.Scan,
57-
):
58-
# we don't push past here
59-
for limit_node in context.collected_limits:
60-
self.statistics.optimization_limit_pushdown += 1
61-
context.optimized_plan.remove_node(limit_node.nid, heal=True)
62-
context.optimized_plan.insert_node_after(
63-
limit_node.nid, limit_node, context.node_id
64-
)
65-
limit_node.columns = []
66-
context.collected_limits.clear()
58+
remaining_limits = []
59+
for limit_node in context.collected_limits:
60+
if self._should_skip_branch(limit_node, node):
61+
remaining_limits.append(limit_node)
62+
continue
63+
64+
if node.node_type == LogicalPlanStepType.Scan:
65+
outcome = self._apply_to_scan(limit_node, node, context)
66+
if outcome is True:
67+
continue
68+
if outcome is None:
69+
remaining_limits.append(limit_node)
70+
continue
71+
self._place_before_node(limit_node, node, context)
72+
continue
73+
74+
if node.node_type == LogicalPlanStepType.Join:
75+
if self._refine_targets_for_join(limit_node, node):
76+
remaining_limits.append(limit_node)
77+
continue
78+
self._place_before_node(limit_node, node, context)
79+
continue
6780

81+
if node.node_type in self._BARRIER_TYPES:
82+
self._place_before_node(limit_node, node, context)
83+
continue
84+
85+
remaining_limits.append(limit_node)
86+
87+
context.collected_limits = remaining_limits
6888
return context
6989

7090
def complete(self, plan: LogicalPlan, context: OptimizerContext) -> LogicalPlan:
71-
# No finalization needed for this strategy
91+
context.collected_limits.clear()
7292
return plan
7393

74-
def should_i_run(self, plan):
75-
# only run if there are LIMIT clauses in the plan
94+
def should_i_run(self, plan: LogicalPlan) -> bool:
7695
candidates = get_nodes_of_type_from_logical_plan(plan, (LogicalPlanStepType.Limit,))
7796
return len(candidates) > 0
97+
98+
@staticmethod
99+
def _collect_relations(node: LogicalPlanNode) -> Set[str]:
100+
relations = getattr(node, "all_relations", None)
101+
if relations:
102+
return set(relations)
103+
return set()
104+
105+
def _should_skip_branch(self, limit_node: LogicalPlanNode, node: LogicalPlanNode) -> bool:
106+
targets: Set[str] = getattr(limit_node, "pushdown_targets", set())
107+
if not targets:
108+
return False
109+
node_relations = self._collect_relations(node)
110+
return bool(node_relations) and targets.isdisjoint(node_relations)
111+
112+
def _apply_to_scan(
113+
self,
114+
limit_node: LogicalPlanNode,
115+
scan_node: LogicalPlanNode,
116+
context: OptimizerContext,
117+
) -> Optional[bool]:
118+
targets: Set[str] = getattr(
119+
limit_node, "pushdown_targets", set(limit_node.all_relations or [])
120+
)
121+
relation_names = {scan_node.relation, getattr(scan_node, "alias", None)}
122+
if targets and targets.isdisjoint({name for name in relation_names if name}):
123+
return None
124+
125+
connector = getattr(scan_node, "connector", None)
126+
if connector and LimitPushable in connector.__class__.mro():
127+
current_limit = getattr(scan_node, "limit", None)
128+
scan_node.limit = (
129+
limit_node.limit if current_limit is None else min(current_limit, limit_node.limit)
130+
)
131+
if limit_node.nid in context.optimized_plan:
132+
context.optimized_plan.remove_node(limit_node.nid, heal=True)
133+
context.optimized_plan[context.node_id] = scan_node
134+
self.statistics.optimization_limit_pushdown += 1
135+
return True
136+
137+
return False
138+
139+
def _refine_targets_for_join(
140+
self, limit_node: LogicalPlanNode, join_node: LogicalPlanNode
141+
) -> bool:
142+
join_type = getattr(join_node, "type", None)
143+
if not join_type:
144+
return False
145+
146+
targets: Set[str] = getattr(
147+
limit_node, "pushdown_targets", set(limit_node.all_relations or [])
148+
)
149+
if not targets:
150+
targets = set(limit_node.all_relations or [])
151+
152+
left_relations = set(getattr(join_node, "left_relation_names", []) or [])
153+
right_relations = set(getattr(join_node, "right_relation_names", []) or [])
154+
155+
new_targets: Optional[Set[str]] = None
156+
157+
if join_type == "left outer":
158+
new_targets = targets & left_relations
159+
elif join_type == "right outer":
160+
new_targets = targets & right_relations
161+
elif join_type == "cross join":
162+
left_size = getattr(join_node, "left_size", float("inf"))
163+
right_size = getattr(join_node, "right_size", float("inf"))
164+
left_choice = targets & left_relations
165+
right_choice = targets & right_relations
166+
if left_choice and right_choice:
167+
new_targets = left_choice if left_size <= right_size else right_choice
168+
elif left_choice:
169+
new_targets = left_choice
170+
elif right_choice:
171+
new_targets = right_choice
172+
else:
173+
return False
174+
175+
if not new_targets:
176+
return False
177+
178+
limit_node.pushdown_targets = new_targets
179+
limit_node.all_relations = set(new_targets)
180+
return True
181+
182+
def _place_before_node(
183+
self, limit_node: LogicalPlanNode, _: LogicalPlanNode, context: OptimizerContext
184+
) -> None:
185+
if limit_node.nid in context.optimized_plan:
186+
context.optimized_plan.remove_node(limit_node.nid, heal=True)
187+
context.optimized_plan.insert_node_after(limit_node.nid, limit_node, context.node_id)
188+
limit_node.columns = []
189+
limit_node.pushdown_targets = set(limit_node.all_relations or [])
190+
self.statistics.optimization_limit_pushdown += 1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "opteryx"
3-
version = "0.26.0-beta.1715"
3+
version = "0.26.0-beta.1716"
44
description = "Query your data, where it lives"
55
requires-python = '>=3.11'
66
readme = {file = "README.md", content-type = "text/markdown"}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import os
2+
import sys
3+
4+
sys.path.insert(1, os.path.join(sys.path[0], "../../.."))
5+
6+
import opteryx # noqa: E402
7+
import pytest # noqa: E402
8+
9+
10+
def _materialize(query: str):
11+
cursor = opteryx.query(query)
12+
cursor.materialize()
13+
return cursor
14+
15+
16+
def test_limit_pushdown_left_outer_join():
17+
query = (
18+
"SELECT s.name FROM testdata.satellites AS s "
19+
"LEFT JOIN testdata.planets AS p ON s.planetId = p.id LIMIT 5;"
20+
)
21+
cursor = _materialize(query)
22+
plan_lines = cursor.stats["executed_plan"].splitlines()
23+
scan_line = next(
24+
line for line in plan_lines if "READ (testdata.satellites AS s)" in line
25+
)
26+
assert "LIMIT 5" in scan_line, cursor.stats["executed_plan"]
27+
assert cursor.stats["rows_read"] <= 14, cursor.stats
28+
29+
30+
def test_limit_pushdown_cross_join_prefers_smaller_side():
31+
query = (
32+
"SELECT * FROM testdata.planets AS p CROSS JOIN testdata.satellites AS s LIMIT 5;"
33+
)
34+
cursor = _materialize(query)
35+
plan_lines = cursor.stats["executed_plan"].splitlines()
36+
scan_line = next(
37+
line for line in plan_lines if "READ (testdata.planets AS p)" in line
38+
)
39+
assert "LIMIT 5" in scan_line, cursor.stats["executed_plan"]
40+
assert cursor.stats["rows_read"] <= 182, cursor.stats
41+
42+
if __name__ == "__main__": # pragma: no cover
43+
pytest.main([__file__])

tests/unit/planner/test_limit_pushdown_parquet_disk.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
import sys
33
import pytest
44

5-
sys.path.insert(1, os.path.join(sys.path[0], "../.."))
5+
sys.path.insert(1, os.path.join(sys.path[0], "../../.."))
66

77
import opteryx
8-
from opteryx.utils.formatter import format_sql
98
from tests import is_arm, is_mac, is_windows, skip_if
109

1110

@@ -35,6 +34,30 @@ def test_parquet_disk_limit_pushdown(query, expected_rows):
3534
cur.materialize()
3635
assert cur.stats["rows_read"] == expected_rows, cur.stats
3736

37+
38+
@skip_if(is_arm() or is_windows() or is_mac())
39+
def test_limit_pushdown_projection_plan():
40+
query = "SELECT name FROM (SELECT name FROM testdata.planets) AS s LIMIT 3;"
41+
cur = opteryx.query(query)
42+
cur.materialize()
43+
plan_lines = cur.stats["executed_plan"].splitlines()
44+
scan_line = next(line for line in plan_lines if "READ (testdata.planets)" in line)
45+
assert "LIMIT 3" in scan_line, cur.stats["executed_plan"]
46+
assert cur.stats["rows_read"] == 3, cur.stats
47+
48+
49+
@skip_if(is_arm() or is_windows() or is_mac())
50+
def test_limit_not_pushed_past_heap_sort():
51+
query = "SELECT name FROM testdata.planets ORDER BY name LIMIT 3;"
52+
cur = opteryx.query(query)
53+
cur.materialize()
54+
plan_lines = cur.stats["executed_plan"].splitlines()
55+
heap_sort_line = next(line for line in plan_lines if "HEAP SORT" in line)
56+
scan_line = next(line for line in plan_lines if "READ (testdata.planets)" in line)
57+
assert "LIMIT" in heap_sort_line # fused limit stays with heap sort
58+
assert "LIMIT" not in scan_line, cur.stats["executed_plan"]
59+
assert cur.stats["rows_read"] == 9, cur.stats
60+
3861
if __name__ == "__main__": # pragma: no cover
3962
import shutil
4063
import time
@@ -68,7 +91,7 @@ def test_parquet_disk_limit_pushdown(query, expected_rows):
6891
print(" \033[0;31m*\033[0m")
6992
else:
7093
print()
71-
except Exception as err:
94+
except (AssertionError, opteryx.exceptions.Error) as err:
7295
print(f"\033[0;31m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms ❌ *\033[0m")
7396
print(">", err)
7497
failed += 1

0 commit comments

Comments
 (0)