Skip to content

Commit 1c3879d

Browse files
perf: Reduce schema tracking overhead (#1056)
1 parent 4379438 commit 1c3879d

File tree

1 file changed

+48
-37
lines changed

1 file changed

+48
-37
lines changed

bigframes/core/nodes.py

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,10 @@ def roots(self) -> typing.Set[BigFrameNode]:
127127
)
128128
return set(roots)
129129

130-
# TODO: For deep trees, this can create a lot of overhead, maybe use zero-copy persistent datastructure?
130+
# TODO: Store some local data lazily for select, aggregate nodes.
131131
@property
132132
@abc.abstractmethod
133-
def fields(self) -> Tuple[Field, ...]:
133+
def fields(self) -> Iterable[Field]:
134134
...
135135

136136
@property
@@ -252,8 +252,8 @@ class UnaryNode(BigFrameNode):
252252
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
253253
return (self.child,)
254254

255-
@functools.cached_property
256-
def fields(self) -> Tuple[Field, ...]:
255+
@property
256+
def fields(self) -> Iterable[Field]:
257257
return self.child.fields
258258

259259
@property
@@ -303,9 +303,9 @@ def explicitly_ordered(self) -> bool:
303303
# Do not consider user pre-join ordering intent - they need to re-order post-join in unordered mode.
304304
return False
305305

306-
@functools.cached_property
307-
def fields(self) -> Tuple[Field, ...]:
308-
return tuple(itertools.chain(self.left_child.fields, self.right_child.fields))
306+
@property
307+
def fields(self) -> Iterable[Field]:
308+
return itertools.chain(self.left_child.fields, self.right_child.fields)
309309

310310
@functools.cached_property
311311
def variables_introduced(self) -> int:
@@ -360,10 +360,10 @@ def explicitly_ordered(self) -> bool:
360360
# Consider concat an ordered operation (even though input frames may not be ordered)
361361
return True
362362

363-
@functools.cached_property
364-
def fields(self) -> Tuple[Field, ...]:
363+
@property
364+
def fields(self) -> Iterable[Field]:
365365
# TODO: Output names should probably be aligned beforehand or be part of concat definition
366-
return tuple(
366+
return (
367367
Field(bfet_ids.ColumnId(f"column_{i}"), field.dtype)
368368
for i, field in enumerate(self.children[0].fields)
369369
)
@@ -407,8 +407,10 @@ def explicitly_ordered(self) -> bool:
407407
return True
408408

409409
@functools.cached_property
410-
def fields(self) -> Tuple[Field, ...]:
411-
return (Field(bfet_ids.ColumnId("labels"), self.start.fields[0].dtype),)
410+
def fields(self) -> Iterable[Field]:
411+
return (
412+
Field(bfet_ids.ColumnId("labels"), next(iter(self.start.fields)).dtype),
413+
)
412414

413415
@functools.cached_property
414416
def variables_introduced(self) -> int:
@@ -469,11 +471,11 @@ class ReadLocalNode(LeafNode):
469471
scan_list: ScanList
470472
session: typing.Optional[bigframes.session.Session] = None
471473

472-
@functools.cached_property
473-
def fields(self) -> Tuple[Field, ...]:
474-
return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
474+
@property
475+
def fields(self) -> Iterable[Field]:
476+
return (Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
475477

476-
@functools.cached_property
478+
@property
477479
def variables_introduced(self) -> int:
478480
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
479481
return len(self.scan_list.items) + 1
@@ -576,9 +578,9 @@ def __post_init__(self):
576578
def session(self):
577579
return self.table_session
578580

579-
@functools.cached_property
580-
def fields(self) -> Tuple[Field, ...]:
581-
return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
581+
@property
582+
def fields(self) -> Iterable[Field]:
583+
return (Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
582584

583585
@property
584586
def relation_ops_created(self) -> int:
@@ -644,8 +646,10 @@ def non_local(self) -> bool:
644646
return True
645647

646648
@property
647-
def fields(self) -> Tuple[Field, ...]:
648-
return (*self.child.fields, Field(self.col_id, bigframes.dtypes.INT_DTYPE))
649+
def fields(self) -> Iterable[Field]:
650+
return itertools.chain(
651+
self.child.fields, [Field(self.col_id, bigframes.dtypes.INT_DTYPE)]
652+
)
649653

650654
@property
651655
def relation_ops_created(self) -> int:
@@ -729,7 +733,7 @@ class SelectionNode(UnaryNode):
729733
]
730734

731735
@functools.cached_property
732-
def fields(self) -> Tuple[Field, ...]:
736+
def fields(self) -> Iterable[Field]:
733737
return tuple(
734738
Field(output, self.child.get_type(input.id))
735739
for input, output in self.input_output_pairs
@@ -774,13 +778,16 @@ def __post_init__(self):
774778
assert all(name not in self.child.schema.names for _, name in self.assignments)
775779

776780
@functools.cached_property
777-
def fields(self) -> Tuple[Field, ...]:
781+
def added_fields(self) -> Tuple[Field, ...]:
778782
input_types = self.child._dtype_lookup
779-
new_fields = (
783+
return tuple(
780784
Field(id, bigframes.dtypes.dtype_for_etype(ex.output_type(input_types)))
781785
for ex, id in self.assignments
782786
)
783-
return (*self.child.fields, *new_fields)
787+
788+
@property
789+
def fields(self) -> Iterable[Field]:
790+
return itertools.chain(self.child.fields, self.added_fields)
784791

785792
@property
786793
def variables_introduced(self) -> int:
@@ -811,8 +818,8 @@ def row_preserving(self) -> bool:
811818
def non_local(self) -> bool:
812819
return True
813820

814-
@functools.cached_property
815-
def fields(self) -> Tuple[Field, ...]:
821+
@property
822+
def fields(self) -> Iterable[Field]:
816823
return (Field(bfet_ids.ColumnId("count"), bigframes.dtypes.INT_DTYPE),)
817824

818825
@property
@@ -841,7 +848,7 @@ def non_local(self) -> bool:
841848
return True
842849

843850
@functools.cached_property
844-
def fields(self) -> Tuple[Field, ...]:
851+
def fields(self) -> Iterable[Field]:
845852
by_items = (
846853
Field(ref.id, self.child.get_type(ref.id)) for ref in self.by_column_ids
847854
)
@@ -854,7 +861,7 @@ def fields(self) -> Tuple[Field, ...]:
854861
)
855862
for agg, id in self.aggregations
856863
)
857-
return (*by_items, *agg_items)
864+
return tuple(itertools.chain(by_items, agg_items))
858865

859866
@property
860867
def variables_introduced(self) -> int:
@@ -896,11 +903,9 @@ class WindowOpNode(UnaryNode):
896903
def non_local(self) -> bool:
897904
return True
898905

899-
@functools.cached_property
900-
def fields(self) -> Tuple[Field, ...]:
901-
input_type = self.child.get_type(self.column_name.id)
902-
new_item_dtype = self.op.output_type(input_type)
903-
return (*self.child.fields, Field(self.output_name, new_item_dtype))
906+
@property
907+
def fields(self) -> Iterable[Field]:
908+
return itertools.chain(self.child.fields, [self.added_field])
904909

905910
@property
906911
def variables_introduced(self) -> int:
@@ -911,6 +916,12 @@ def relation_ops_created(self) -> int:
911916
# Assume that, if not reprojecting, there is a sequence of window operations sharing the same window
912917
return 0 if self.skip_reproject_unsafe else 4
913918

919+
@functools.cached_property
920+
def added_field(self) -> Field:
921+
input_type = self.child.get_type(self.column_name.id)
922+
new_item_dtype = self.op.output_type(input_type)
923+
return Field(self.output_name, new_item_dtype)
924+
914925
def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
915926
if self.output_name not in used_cols:
916927
return self.child
@@ -959,9 +970,9 @@ class ExplodeNode(UnaryNode):
959970
def row_preserving(self) -> bool:
960971
return False
961972

962-
@functools.cached_property
963-
def fields(self) -> Tuple[Field, ...]:
964-
return tuple(
973+
@property
974+
def fields(self) -> Iterable[Field]:
975+
return (
965976
Field(
966977
field.id,
967978
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(

0 commit comments

Comments
 (0)