Skip to content

Commit 64b5ff1

Browse files
refactor: Tree traversals now non-recursive (#1386)
1 parent 4c8e6c3 commit 64b5ff1

File tree

4 files changed

+379
-278
lines changed

4 files changed

+379
-278
lines changed

bigframes/core/bigframe_node.py

Lines changed: 368 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,368 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import abc
18+
import collections
19+
import dataclasses
20+
import functools
21+
import itertools
22+
import typing
23+
from typing import Callable, Dict, Generator, Iterable, Mapping, Set, Tuple
24+
25+
from bigframes.core import identifiers
26+
import bigframes.core.guid
27+
import bigframes.core.schema as schemata
28+
import bigframes.dtypes
29+
30+
if typing.TYPE_CHECKING:
31+
import bigframes.session
32+
33+
# Immutable, hashable set of column ids; used as the value type for id-set
# node properties such as `BigFrameNode.referenced_ids`.
COLUMN_SET = frozenset[identifiers.ColumnId]
34+
35+
36+
@dataclasses.dataclass(frozen=True)
class Field:
    """A single named, typed column produced by a plan node."""

    # Identifier for the column within its namespace.
    id: identifiers.ColumnId
    # BigFrames dtype of the column's values.
    dtype: bigframes.dtypes.Dtype
41+
42+
@dataclasses.dataclass(eq=False, frozen=True)
class BigFrameNode:
    """
    Immutable node for representing 2D typed array as a tree of operators.

    All subclasses must be hashable so as to be usable as caching key.
    """

    @property
    def deterministic(self) -> bool:
        """Whether this node will evaluate deterministically."""
        return True

    @property
    def row_preserving(self) -> bool:
        """Whether this node preserves input rows."""
        return True

    @property
    def non_local(self) -> bool:
        """
        Whether this node combines information across multiple rows instead of processing rows independently.
        Used as an approximation for whether the expression may require shuffling to execute (and therefore be expensive).
        """
        return False

    @property
    def child_nodes(self) -> typing.Sequence[BigFrameNode]:
        """Direct children of this node."""
        return tuple([])

    @property
    @abc.abstractmethod
    def row_count(self) -> typing.Optional[int]:
        """Number of output rows if statically known, otherwise None."""
        return None

    @abc.abstractmethod
    def remap_refs(
        self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
    ) -> BigFrameNode:
        """Remap variable references."""
        ...

    @property
    @abc.abstractmethod
    def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
        """The variables defined in this node (as opposed to by child nodes)."""
        ...

    @functools.cached_property
    def session(self):
        """The session shared by this tree's sources, or None if no source has one.

        Raises:
            ValueError: if children report more than one distinct session.
        """
        sessions = []
        for child in self.child_nodes:
            if child.session is not None:
                sessions.append(child.session)
        unique_sessions = len(set(sessions))
        if unique_sessions > 1:
            # Message previously read "Cannot use combine sources ..." (garbled).
            raise ValueError("Cannot combine sources from multiple sessions.")
        elif unique_sessions == 1:
            return sessions[0]
        return None

    def _validate(self):
        """Validate the local data in the node."""
        return

    @functools.cache
    def validate_tree(self) -> bool:
        """Validate this node and all descendants, raising on invalid state.

        NOTE(review): functools.cache on an instance method retains a reference
        to every validated node for the process lifetime (ruff B019) — confirm
        this retention is intended before changing.
        """
        for child in self.child_nodes:
            child.validate_tree()
        self._validate()
        field_list = list(self.fields)
        if len(set(field_list)) != len(field_list):
            raise ValueError(f"Non unique field ids {list(self.fields)}")
        return True

    def _as_tuple(self) -> Tuple:
        """Get all dataclass field values as a tuple (basis for hash/eq)."""
        return tuple(getattr(self, field.name) for field in dataclasses.fields(self))

    def __hash__(self) -> int:
        # Custom hash that uses cache to avoid costly recomputation
        return self._cached_hash

    def __eq__(self, other) -> bool:
        # Custom eq that tries to short-circuit full structural comparison
        if not isinstance(other, self.__class__):
            return False
        if self is other:
            return True
        if hash(self) != hash(other):
            return False
        return self._as_tuple() == other._as_tuple()

    # BigFrameNode trees can be very deep so it's important to avoid recalculating the hash from scratch.
    # Each subclass of BigFrameNode should use this property to implement __hash__.
    # The default dataclass-generated __hash__ method is not cached.
    @functools.cached_property
    def _cached_hash(self):
        return hash(self._as_tuple())

    @property
    def roots(self) -> typing.Set[BigFrameNode]:
        # NOTE(review): a node with no children contributes an empty set here,
        # so this base implementation yields the union of descendants' roots;
        # leaf subclasses presumably override this to return {self} — confirm.
        roots = itertools.chain.from_iterable(
            map(lambda child: child.roots, self.child_nodes)
        )
        return set(roots)

    # TODO: Store some local data lazily for select, aggregate nodes.
    @property
    @abc.abstractmethod
    def fields(self) -> Iterable[Field]:
        """All output fields (id + dtype) of the node."""
        ...

    @property
    def ids(self) -> Iterable[identifiers.ColumnId]:
        """All output ids from the node."""
        return (field.id for field in self.fields)

    @property
    @abc.abstractmethod
    def variables_introduced(self) -> int:
        """
        Defines number of values created by the current node. Helps represent the "width" of a query.
        """
        ...

    @property
    def relation_ops_created(self) -> int:
        """
        Defines the number of relational ops generated by the current node. Used to estimate query planning complexity.
        """
        return 1

    @property
    def joins(self) -> bool:
        """
        Defines whether the node joins data.
        """
        return False

    @property
    @abc.abstractmethod
    def order_ambiguous(self) -> bool:
        """
        Whether row ordering is potentially ambiguous. For example, ReadTable (without a primary key) could be ordered in different ways.
        """
        ...

    @property
    @abc.abstractmethod
    def explicitly_ordered(self) -> bool:
        """
        Whether the node's output row ordering is explicitly defined, rather
        than an arbitrary implementation-defined order.

        (Previous docstring was a copy-paste of ``order_ambiguous``.)
        """
        ...

    @functools.cached_property
    def height(self) -> int:
        """Length of the longest path from this node down to a leaf."""
        if len(self.child_nodes) == 0:
            return 0
        return max(child.height for child in self.child_nodes) + 1

    @functools.cached_property
    def total_variables(self) -> int:
        """Total variables introduced by this node and all descendants."""
        return self.variables_introduced + sum(
            map(lambda x: x.total_variables, self.child_nodes)
        )

    @functools.cached_property
    def total_relational_ops(self) -> int:
        """Total relational ops generated by this node and all descendants."""
        return self.relation_ops_created + sum(
            map(lambda x: x.total_relational_ops, self.child_nodes)
        )

    @functools.cached_property
    def total_joins(self) -> int:
        """Total number of joining nodes in this subtree."""
        return int(self.joins) + sum(map(lambda x: x.total_joins, self.child_nodes))

    @functools.cached_property
    def schema(self) -> schemata.ArraySchema:
        # TODO: Make schema just a view on fields
        return schemata.ArraySchema(
            tuple(schemata.SchemaItem(i.id.name, i.dtype) for i in self.fields)
        )

    @property
    def planning_complexity(self) -> int:
        """
        Empirical heuristic measure of planning complexity.

        Used to determine when to decompose overly complex computations. May require tuning.
        """
        return self.total_variables * self.total_relational_ops * (1 + self.total_joins)

    @abc.abstractmethod
    def transform_children(
        self, t: Callable[[BigFrameNode], BigFrameNode]
    ) -> BigFrameNode:
        """Apply a function to each child node."""
        ...

    @abc.abstractmethod
    def remap_vars(
        self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
    ) -> BigFrameNode:
        """Remap defined (in this node only) variables."""
        ...

    @property
    def defines_namespace(self) -> bool:
        """
        If true, this node establishes a new column id namespace.

        If false, this node consumes and produces ids in the namespace.
        """
        return False

    @property
    def referenced_ids(self) -> COLUMN_SET:
        """Ids of columns this node reads from its children (none by default)."""
        return frozenset()

    @functools.cached_property
    def defined_variables(self) -> set[str]:
        """Full set of variables defined in the namespace, even if not selected."""
        self_defined_variables = set(self.schema.names)
        if self.defines_namespace:
            return self_defined_variables
        return self_defined_variables.union(
            *(child.defined_variables for child in self.child_nodes)
        )

    def get_type(self, id: identifiers.ColumnId) -> bigframes.dtypes.Dtype:
        """Look up the dtype of an output column by id (KeyError if absent)."""
        return self._dtype_lookup[id]

    @functools.cached_property
    def _dtype_lookup(self):
        # Cached id -> dtype map backing get_type.
        return {field.id: field.dtype for field in self.fields}

    # Plan algorithms
    def unique_nodes(
        self: BigFrameNode,
    ) -> Generator[BigFrameNode, None, None]:
        """Walks the tree, yielding each unique node exactly once (DFS order)."""
        seen = set()
        stack: list[BigFrameNode] = [self]
        while stack:
            item = stack.pop()
            if item not in seen:
                yield item
                seen.add(item)
                stack.extend(item.child_nodes)

    def edges(
        self: BigFrameNode,
    ) -> Generator[Tuple[BigFrameNode, BigFrameNode], None, None]:
        """Yields every (parent, child) edge in the tree once per unique parent."""
        for item in self.unique_nodes():
            for child in item.child_nodes:
                yield (item, child)

    def iter_nodes_topo(self: BigFrameNode) -> Generator[BigFrameNode, None, None]:
        """Returns nodes from bottom up (children always before parents)."""
        # Kahn-style: start from childless nodes, release a parent once all of
        # its children have been yielded.
        queue = collections.deque(
            [node for node in self.unique_nodes() if not node.child_nodes]
        )

        child_to_parents: Dict[
            BigFrameNode, Set[BigFrameNode]
        ] = collections.defaultdict(set)
        for parent, child in self.edges():
            child_to_parents[child].add(parent)

        yielded = set()

        while queue:
            item = queue.popleft()
            yield item
            yielded.add(item)
            for parent in child_to_parents[item]:
                if set(parent.child_nodes).issubset(yielded):
                    queue.append(parent)

    def top_down(
        self: BigFrameNode,
        transform: Callable[[BigFrameNode], BigFrameNode],
    ) -> BigFrameNode:
        """
        Perform a top-down transformation of the BigFrameNode tree.

        ``transform`` is applied to each unique node before its children are
        rebuilt; each result's children are then replaced with their own fully
        transformed versions. Returns the transformed root.
        """
        # Phase 1: apply transform to every reachable unique node, walking the
        # *transformed* nodes' children (transform may change the tree shape).
        transformed: Dict[BigFrameNode, BigFrameNode] = {}
        stack: list[BigFrameNode] = [self]
        while stack:
            node = stack.pop()
            if node not in transformed:
                transformed[node] = transform(node)
                stack.extend(transformed[node].child_nodes)

        # Phase 2: finalize children bottom-up with an explicit post-order
        # walk. (Replaying discovery order in reverse — as previously done —
        # is unsafe when a node is shared between subtrees at different
        # depths: a parent discovered later could be finalized before its
        # shared child, baking stale children into the result.)
        results: Dict[BigFrameNode, BigFrameNode] = {}
        walk: list[BigFrameNode] = [self]
        while walk:
            node = walk[-1]
            if node in results:
                walk.pop()
                continue
            pending = [
                child
                for child in transformed[node].child_nodes
                if child not in results
            ]
            if pending:
                walk.extend(pending)
            else:
                results[node] = transformed[node].transform_children(
                    lambda x: results[x]
                )
                walk.pop()

        return results[self]

    def bottom_up(
        self: BigFrameNode,
        transform: Callable[[BigFrameNode], BigFrameNode],
    ) -> BigFrameNode:
        """
        Perform a bottom-up transformation of the BigFrameNode tree.

        The `transform` function is applied to each node *after* its children
        have been transformed. This allows for transformations that depend
        on the results of transforming subtrees.

        Returns the transformed root node.
        """
        results: dict[BigFrameNode, BigFrameNode] = {}
        # iter_nodes_topo guarantees children are visited (and so transformed)
        # before their parents; no need to materialize it into a list first.
        for node in self.iter_nodes_topo():
            result = node.transform_children(lambda x: results[x])
            result = transform(result)
            results[node] = result

        return results[self]

0 commit comments

Comments
 (0)