1515from __future__ import annotations
1616
1717import abc
18- from dataclasses import dataclass , field , fields
18+ from dataclasses import dataclass , field , fields , replace
1919import functools
2020import itertools
2121import typing
22- from typing import Tuple
22+ from typing import Callable , Tuple
2323
2424import pandas
2525
3939 import bigframes .session
4040
4141
42+ # A fixed number of variable to assume for overhead on some operations
43+ OVERHEAD_VARIABLES = 5
44+
45+
4246@dataclass (frozen = True )
4347class BigFrameNode :
4448 """
@@ -102,6 +106,60 @@ def roots(self) -> typing.Set[BigFrameNode]:
102106 def schema (self ) -> schemata .ArraySchema :
103107 ...
104108
109+ @property
110+ @abc .abstractmethod
111+ def variables_introduced (self ) -> int :
112+ """
113+ Defines number of values created by the current node. Helps represent the "width" of a query
114+ """
115+ ...
116+
117+ @property
118+ def relation_ops_created (self ) -> int :
119+ """
120+ Defines the number of relational ops generated by the current node. Used to estimate query planning complexity.
121+ """
122+ return 1
123+
124+ @property
125+ def joins (self ) -> bool :
126+ """
127+ Defines whether the node joins data.
128+ """
129+ return False
130+
131+ @functools .cached_property
132+ def total_variables (self ) -> int :
133+ return self .variables_introduced + sum (
134+ map (lambda x : x .total_variables , self .child_nodes )
135+ )
136+
137+ @functools .cached_property
138+ def total_relational_ops (self ) -> int :
139+ return self .relation_ops_created + sum (
140+ map (lambda x : x .total_relational_ops , self .child_nodes )
141+ )
142+
143+ @functools .cached_property
144+ def total_joins (self ) -> int :
145+ return int (self .joins ) + sum (map (lambda x : x .total_joins , self .child_nodes ))
146+
147+ @property
148+ def planning_complexity (self ) -> int :
149+ """
150+ Empirical heuristic measure of planning complexity.
151+
152+ Used to determine when to decompose overly complex computations. May require tuning.
153+ """
154+ return self .total_variables * self .total_relational_ops * (1 + self .total_joins )
155+
156+ @abc .abstractmethod
157+ def transform_children (
158+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
159+ ) -> BigFrameNode :
160+ """Apply a function to each child node."""
161+ ...
162+
105163
106164@dataclass (frozen = True )
107165class UnaryNode (BigFrameNode ):
@@ -115,6 +173,11 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
115173 def schema (self ) -> schemata .ArraySchema :
116174 return self .child .schema
117175
176+ def transform_children (
177+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
178+ ) -> BigFrameNode :
179+ return replace (self , child = t (self .child ))
180+
118181
119182@dataclass (frozen = True )
120183class JoinNode (BigFrameNode ):
@@ -154,6 +217,22 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
154217 )
155218 return schemata .ArraySchema (items )
156219
220+ @functools .cached_property
221+ def variables_introduced (self ) -> int :
222+ """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
223+ return OVERHEAD_VARIABLES
224+
225+ @property
226+ def joins (self ) -> bool :
227+ return True
228+
229+ def transform_children (
230+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
231+ ) -> BigFrameNode :
232+ return replace (
233+ self , left_child = t (self .left_child ), right_child = t (self .right_child )
234+ )
235+
157236
158237@dataclass (frozen = True )
159238class ConcatNode (BigFrameNode ):
@@ -182,6 +261,16 @@ def schema(self) -> schemata.ArraySchema:
182261 )
183262 return schemata .ArraySchema (items )
184263
264+ @functools .cached_property
265+ def variables_introduced (self ) -> int :
266+ """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
267+ return len (self .schema .items ) + OVERHEAD_VARIABLES
268+
269+ def transform_children (
270+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
271+ ) -> BigFrameNode :
272+ return replace (self , children = tuple (t (child ) for child in self .children ))
273+
185274
186275# Input Nodex
187276@dataclass (frozen = True )
@@ -201,6 +290,16 @@ def roots(self) -> typing.Set[BigFrameNode]:
201290 def schema (self ) -> schemata .ArraySchema :
202291 return self .data_schema
203292
293+ @functools .cached_property
294+ def variables_introduced (self ) -> int :
295+ """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
296+ return len (self .schema .items ) + 1
297+
298+ def transform_children (
299+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
300+ ) -> BigFrameNode :
301+ return self
302+
204303
205304# TODO: Refactor to take raw gbq object reference
206305@dataclass (frozen = True )
@@ -233,6 +332,20 @@ def schema(self) -> schemata.ArraySchema:
233332 )
234333 return schemata .ArraySchema (items )
235334
335+ @functools .cached_property
336+ def variables_introduced (self ) -> int :
337+ return len (self .columns ) + len (self .hidden_ordering_columns )
338+
339+ @property
340+ def relation_ops_created (self ) -> int :
341+ # Assume worst case, where readgbq actually has baked in analytic operation to generate index
342+ return 2
343+
344+ def transform_children (
345+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
346+ ) -> BigFrameNode :
347+ return self
348+
236349
237350# Unary nodes
238351@dataclass (frozen = True )
@@ -252,6 +365,14 @@ def schema(self) -> schemata.ArraySchema:
252365 schemata .SchemaItem (self .col_id , bigframes .dtypes .INT_DTYPE )
253366 )
254367
368+ @property
369+ def relation_ops_created (self ) -> int :
370+ return 2
371+
372+ @functools .cached_property
373+ def variables_introduced (self ) -> int :
374+ return 1
375+
255376
256377@dataclass (frozen = True )
257378class FilterNode (UnaryNode ):
@@ -264,6 +385,10 @@ def row_preserving(self) -> bool:
264385 def __hash__ (self ):
265386 return self ._node_hash
266387
388+ @property
389+ def variables_introduced (self ) -> int :
390+ return 1
391+
267392
268393@dataclass (frozen = True )
269394class OrderByNode (UnaryNode ):
@@ -281,6 +406,15 @@ def __post_init__(self):
281406 def __hash__ (self ):
282407 return self ._node_hash
283408
409+ @property
410+ def variables_introduced (self ) -> int :
411+ return 0
412+
413+ @property
414+ def relation_ops_created (self ) -> int :
415+ # Doesnt directly create any relational operations
416+ return 0
417+
284418
285419@dataclass (frozen = True )
286420class ReversedNode (UnaryNode ):
@@ -290,6 +424,15 @@ class ReversedNode(UnaryNode):
290424 def __hash__ (self ):
291425 return self ._node_hash
292426
427+ @property
428+ def variables_introduced (self ) -> int :
429+ return 0
430+
431+ @property
432+ def relation_ops_created (self ) -> int :
433+ # Doesnt directly create any relational operations
434+ return 0
435+
293436
294437@dataclass (frozen = True )
295438class ProjectionNode (UnaryNode ):
@@ -315,6 +458,12 @@ def schema(self) -> schemata.ArraySchema:
315458 )
316459 return schemata .ArraySchema (items )
317460
461+ @property
462+ def variables_introduced (self ) -> int :
463+ # ignore passthrough expressions
464+ new_vars = sum (1 for i in self .assignments if not i [0 ].is_identity )
465+ return new_vars
466+
318467
319468# TODO: Merge RowCount into Aggregate Node?
320469# Row count can be compute from table metadata sometimes, so it is a bit special.
@@ -334,6 +483,10 @@ def schema(self) -> schemata.ArraySchema:
334483 (schemata .SchemaItem ("count" , bigframes .dtypes .INT_DTYPE ),)
335484 )
336485
486+ @property
487+ def variables_introduced (self ) -> int :
488+ return 1
489+
337490
338491@dataclass (frozen = True )
339492class AggregateNode (UnaryNode ):
@@ -367,6 +520,10 @@ def schema(self) -> schemata.ArraySchema:
367520 )
368521 return schemata .ArraySchema (tuple ([* by_items , * agg_items ]))
369522
523+ @property
524+ def variables_introduced (self ) -> int :
525+ return len (self .aggregations ) + len (self .by_column_ids )
526+
370527
371528@dataclass (frozen = True )
372529class WindowOpNode (UnaryNode ):
@@ -396,12 +553,31 @@ def schema(self) -> schemata.ArraySchema:
396553 schemata .SchemaItem (self .output_name , new_item_dtype )
397554 )
398555
556+ @property
557+ def variables_introduced (self ) -> int :
558+ return 1
559+
560+ @property
561+ def relation_ops_created (self ) -> int :
562+ # Assume that if not reprojecting, that there is a sequence of window operations sharing the same window
563+ return 0 if self .skip_reproject_unsafe else 4
564+
399565
566+ # TODO: Remove this op
400567@dataclass (frozen = True )
401568class ReprojectOpNode (UnaryNode ):
402569 def __hash__ (self ):
403570 return self ._node_hash
404571
572+ @property
573+ def variables_introduced (self ) -> int :
574+ return 0
575+
576+ @property
577+ def relation_ops_created (self ) -> int :
578+ # This op is not a real transformation, just a hint to the sql generator
579+ return 0
580+
405581
406582@dataclass (frozen = True )
407583class UnpivotNode (UnaryNode ):
@@ -428,6 +604,10 @@ def row_preserving(self) -> bool:
428604 def non_local (self ) -> bool :
429605 return True
430606
607+ @property
608+ def joins (self ) -> bool :
609+ return True
610+
431611 @functools .cached_property
432612 def schema (self ) -> schemata .ArraySchema :
433613 def infer_dtype (
@@ -469,6 +649,17 @@ def infer_dtype(
469649 ]
470650 return schemata .ArraySchema ((* index_items , * value_items , * passthrough_items ))
471651
652+ @property
653+ def variables_introduced (self ) -> int :
654+ return (
655+ len (self .schema .items ) - len (self .passthrough_columns ) + OVERHEAD_VARIABLES
656+ )
657+
658+ @property
659+ def relation_ops_created (self ) -> int :
660+ # Unpivot is essentially a cross join and a projection.
661+ return 2
662+
472663
473664@dataclass (frozen = True )
474665class RandomSampleNode (UnaryNode ):
@@ -485,6 +676,10 @@ def row_preserving(self) -> bool:
485676 def __hash__ (self ):
486677 return self ._node_hash
487678
679+ @property
680+ def variables_introduced (self ) -> int :
681+ return 1
682+
488683
489684@dataclass (frozen = True )
490685class ExplodeNode (UnaryNode ):
@@ -511,3 +706,11 @@ def schema(self) -> schemata.ArraySchema:
511706 for name in self .child .schema .names
512707 )
513708 return schemata .ArraySchema (items )
709+
710+ @property
711+ def relation_ops_created (self ) -> int :
712+ return 3
713+
714+ @functools .cached_property
715+ def variables_introduced (self ) -> int :
716+ return len (self .column_ids ) + 1
0 commit comments