Skip to content

Commit 3a4a9de

Browse files
refactor: Simplify join node definition (#966)
1 parent d42d674 commit 3a4a9de

File tree

10 files changed

+257
-346
lines changed

10 files changed

+257
-346
lines changed

bigframes/core/__init__.py

Lines changed: 26 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
import datetime
1818
import functools
1919
import io
20-
import itertools
2120
import typing
22-
from typing import Iterable, Optional, Sequence
21+
from typing import Iterable, Optional, Sequence, Tuple
2322
import warnings
2423

2524
import google.cloud.bigquery
@@ -191,19 +190,14 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
191190
nodes.ConcatNode(children=tuple([self.node, *[val.node for val in other]]))
192191
)
193192

194-
def project_to_id(self, expression: ex.Expression, output_id: str):
193+
def compute_values(self, assignments: Sequence[Tuple[ex.Expression, str]]):
195194
return ArrayValue(
196-
nodes.ProjectionNode(
197-
child=self.node,
198-
assignments=(
199-
(
200-
expression,
201-
output_id,
202-
),
203-
),
204-
)
195+
nodes.ProjectionNode(child=self.node, assignments=tuple(assignments))
205196
)
206197

198+
def project_to_id(self, expression: ex.Expression, output_id: str):
199+
return self.compute_values(((expression, output_id),))
200+
207201
def assign(self, source_id: str, destination_id: str) -> ArrayValue:
208202
if destination_id in self.column_ids: # Mutate case
209203
exprs = [
@@ -341,124 +335,33 @@ def _reproject_to_table(self) -> ArrayValue:
341335
)
342336
)
343337

344-
def unpivot(
345-
self,
346-
row_labels: typing.Sequence[typing.Hashable],
347-
unpivot_columns: typing.Sequence[
348-
typing.Tuple[str, typing.Tuple[typing.Optional[str], ...]]
349-
],
350-
*,
351-
passthrough_columns: typing.Sequence[str] = (),
352-
index_col_ids: typing.Sequence[str] = ["index"],
353-
join_side: typing.Literal["left", "right"] = "left",
354-
) -> ArrayValue:
355-
"""
356-
Unpivot ArrayValue columns.
357-
358-
Args:
359-
row_labels: Identifies the source of the row. Must be equal in length to the source column list in the unpivot_columns argument.
360-
unpivot_columns: Mapping of column id to list of input column ids. Lists of input columns may use None.
361-
passthrough_columns: Columns that will not be unpivoted. Column id will be preserved.
362-
index_col_ids: The column ids to be used for the row labels.
363-
364-
Returns:
365-
ArrayValue: The unpivoted ArrayValue
366-
"""
367-
# There will be N labels, used to disambiguate which of N source columns produced each output row
368-
explode_offsets_id = bigframes.core.guid.generate_guid("unpivot_offsets_")
369-
labels_array = self._create_unpivot_labels_array(
370-
row_labels, index_col_ids, explode_offsets_id
371-
)
372-
373-
# Unpivot creates N output rows for each input row, labels disambiguate these N rows
374-
joined_array = self._cross_join_w_labels(labels_array, join_side)
375-
376-
# Build the output rows as a case statement that selects between the N input columns
377-
unpivot_exprs = []
378-
# Supports producing multiple stacked output columns for stacking only part of a hierarchical index
379-
for col_id, input_ids in unpivot_columns:
380-
# row explode offset used to choose the input column
381-
# we use offset instead of label as labels are not necessarily unique
382-
cases = itertools.chain(
383-
*(
384-
(
385-
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
386-
ex.free_var(id_or_null)
387-
if (id_or_null is not None)
388-
else ex.const(None),
389-
)
390-
for i, id_or_null in enumerate(input_ids)
391-
)
392-
)
393-
col_expr = ops.case_when_op.as_expr(*cases)
394-
unpivot_exprs.append((col_expr, col_id))
395-
396-
unpivot_col_ids = [id for id, _ in unpivot_columns]
397-
return ArrayValue(
398-
nodes.ProjectionNode(
399-
child=joined_array.node,
400-
assignments=(*unpivot_exprs,),
401-
)
402-
).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns])
403-
404-
def _cross_join_w_labels(
405-
self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"]
406-
) -> ArrayValue:
407-
"""
408-
Convert each row in self to N rows, one for each label in labels array.
409-
"""
410-
table_join_side = (
411-
join_def.JoinSide.LEFT if join_side == "left" else join_def.JoinSide.RIGHT
412-
)
413-
labels_join_side = table_join_side.inverse()
414-
labels_mappings = tuple(
415-
join_def.JoinColumnMapping(labels_join_side, id, id)
416-
for id in labels_array.schema.names
417-
)
418-
table_mappings = tuple(
419-
join_def.JoinColumnMapping(table_join_side, id, id)
420-
for id in self.schema.names
421-
)
422-
join = join_def.JoinDefinition(
423-
conditions=(), mappings=(*labels_mappings, *table_mappings), type="cross"
424-
)
425-
if join_side == "left":
426-
joined_array = self.relational_join(labels_array, join_def=join)
427-
else:
428-
joined_array = labels_array.relational_join(self, join_def=join)
429-
return joined_array
430-
431-
def _create_unpivot_labels_array(
432-
self,
433-
former_column_labels: typing.Sequence[typing.Hashable],
434-
col_ids: typing.Sequence[str],
435-
offsets_id: str,
436-
) -> ArrayValue:
437-
"""Create an ArrayValue from a list of label tuples."""
438-
rows = []
439-
for row_offset in range(len(former_column_labels)):
440-
row_label = former_column_labels[row_offset]
441-
row_label = (row_label,) if not isinstance(row_label, tuple) else row_label
442-
row = {
443-
col_ids[i]: (row_label[i] if pandas.notnull(row_label[i]) else None)
444-
for i in range(len(col_ids))
445-
}
446-
row[offsets_id] = row_offset
447-
rows.append(row)
448-
449-
return ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=self.session)
450-
451338
def relational_join(
452339
self,
453340
other: ArrayValue,
454-
join_def: join_def.JoinDefinition,
455-
) -> ArrayValue:
341+
conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
342+
type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner",
343+
) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]:
456344
join_node = nodes.JoinNode(
457345
left_child=self.node,
458346
right_child=other.node,
459-
join=join_def,
347+
conditions=conditions,
348+
type=type,
460349
)
461-
return ArrayValue(join_node)
350+
# Maps input ids to output ids for caller convenience
351+
l_size = len(self.node.schema)
352+
l_mapping = {
353+
lcol: ocol
354+
for lcol, ocol in zip(
355+
self.node.schema.names, join_node.schema.names[:l_size]
356+
)
357+
}
358+
r_mapping = {
359+
rcol: ocol
360+
for rcol, ocol in zip(
361+
other.node.schema.names, join_node.schema.names[l_size:]
362+
)
363+
}
364+
return ArrayValue(join_node), (l_mapping, r_mapping)
462365

463366
def try_align_as_projection(
464367
self,

0 commit comments

Comments
 (0)