Skip to content

Commit ff11ac8

Browse files
refactor: Create class for column ids and column refs (#1022)
1 parent c89e92e commit ff11ac8

28 files changed

+654
-437
lines changed

bigframes/core/__init__.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import bigframes.core.compile
3030
import bigframes.core.expression as ex
3131
import bigframes.core.guid
32+
import bigframes.core.identifiers as ids
3233
import bigframes.core.join_def as join_def
3334
import bigframes.core.local_data as local_data
3435
import bigframes.core.nodes as nodes
@@ -169,7 +170,7 @@ def row_count(self) -> ArrayValue:
169170
# Operations
170171
def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
171172
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
172-
predicate: ex.Expression = ex.free_var(predicate_id)
173+
predicate: ex.Expression = ex.deref(predicate_id)
173174
if keep_null:
174175
predicate = ops.fillna_op.as_expr(predicate, ex.const(True))
175176
return self.filter(predicate)
@@ -200,7 +201,9 @@ def promote_offsets(self) -> Tuple[ArrayValue, str]:
200201
)
201202

202203
return (
203-
ArrayValue(nodes.PromoteOffsetsNode(child=self.node, col_id=col_id)),
204+
ArrayValue(
205+
nodes.PromoteOffsetsNode(child=self.node, col_id=ids.ColumnId(col_id))
206+
),
204207
col_id,
205208
)
206209

@@ -212,7 +215,9 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
212215

213216
def compute_values(self, assignments: Sequence[ex.Expression]):
214217
col_ids = self._gen_namespaced_uids(len(assignments))
215-
ex_id_pairs = tuple((ex, id) for ex, id in zip(assignments, col_ids))
218+
ex_id_pairs = tuple(
219+
(ex, ids.ColumnId(id)) for ex, id in zip(assignments, col_ids)
220+
)
216221
return (
217222
ArrayValue(nodes.ProjectionNode(child=self.node, assignments=ex_id_pairs)),
218223
col_ids,
@@ -228,14 +233,19 @@ def assign(self, source_id: str, destination_id: str) -> ArrayValue:
228233
if destination_id in self.column_ids: # Mutate case
229234
exprs = [
230235
(
231-
(source_id if (col_id == destination_id) else col_id),
232-
col_id,
236+
ex.deref(source_id if (col_id == destination_id) else col_id),
237+
ids.ColumnId(col_id),
233238
)
234239
for col_id in self.column_ids
235240
]
236241
else: # append case
237-
self_projection = ((col_id, col_id) for col_id in self.column_ids)
238-
exprs = [*self_projection, (source_id, destination_id)]
242+
self_projection = (
243+
(ex.deref(col_id), ids.ColumnId(col_id)) for col_id in self.column_ids
244+
)
245+
exprs = [
246+
*self_projection,
247+
(ex.deref(source_id), ids.ColumnId(destination_id)),
248+
]
239249
return ArrayValue(
240250
nodes.SelectionNode(
241251
child=self.node,
@@ -248,24 +258,15 @@ def create_constant(
248258
value: typing.Any,
249259
dtype: typing.Optional[bigframes.dtypes.Dtype],
250260
) -> Tuple[ArrayValue, str]:
251-
destination_id = self._gen_namespaced_uid()
252261
if pandas.isna(value):
253262
# Need to assign a data type when value is NaN.
254263
dtype = dtype or bigframes.dtypes.DEFAULT_DTYPE
255264

256-
return (
257-
ArrayValue(
258-
nodes.ProjectionNode(
259-
child=self.node,
260-
assignments=((ex.const(value, dtype), destination_id),),
261-
)
262-
),
263-
destination_id,
264-
)
265+
return self.project_to_id(ex.const(value, dtype))
265266

266267
def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
267268
# This basically just drops and reorders columns - logically a no-op except as a final step
268-
selections = ((col_id, col_id) for col_id in column_ids)
269+
selections = ((ex.deref(col_id), ids.ColumnId(col_id)) for col_id in column_ids)
269270
return ArrayValue(
270271
nodes.SelectionNode(
271272
child=self.node,
@@ -274,14 +275,8 @@ def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
274275
)
275276

276277
def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
277-
new_projection = (
278-
(col_id, col_id) for col_id in self.column_ids if col_id not in columns
279-
)
280-
return ArrayValue(
281-
nodes.SelectionNode(
282-
child=self.node,
283-
input_output_pairs=tuple(new_projection),
284-
)
278+
return self.select_columns(
279+
[col_id for col_id in self.column_ids if col_id not in columns]
285280
)
286281

287282
def aggregate(
@@ -297,11 +292,12 @@ def aggregate(
297292
by_column_id: column id of the aggregation key, this is preserved through the transform
298293
dropna: whether null keys should be dropped
299294
"""
295+
agg_defs = tuple((agg, ids.ColumnId(name)) for agg, name in aggregations)
300296
return ArrayValue(
301297
nodes.AggregateNode(
302298
child=self.node,
303-
aggregations=tuple(aggregations),
304-
by_column_ids=tuple(by_column_ids),
299+
aggregations=agg_defs,
300+
by_column_ids=tuple(map(ex.deref, by_column_ids)),
305301
dropna=dropna,
306302
)
307303
)
@@ -342,10 +338,10 @@ def project_window_op(
342338
ArrayValue(
343339
nodes.WindowOpNode(
344340
child=self.node,
345-
column_name=column_name,
341+
column_name=ex.deref(column_name),
346342
op=op,
347343
window_spec=window_spec,
348-
output_name=output_name,
344+
output_name=ids.ColumnId(output_name),
349345
never_skip_nulls=never_skip_nulls,
350346
skip_reproject_unsafe=skip_reproject_unsafe,
351347
)
@@ -376,7 +372,9 @@ def relational_join(
376372
join_node = nodes.JoinNode(
377373
left_child=self.node,
378374
right_child=other.node,
379-
conditions=conditions,
375+
conditions=tuple(
376+
(ex.deref(l_col), ex.deref(r_col)) for l_col, r_col in conditions
377+
),
380378
type=type,
381379
)
382380
# Maps input ids to output ids for caller convenience
@@ -414,7 +412,7 @@ def explode(self, column_ids: typing.Sequence[str]) -> ArrayValue:
414412
for column_id in column_ids:
415413
assert bigframes.dtypes.is_array_like(self.get_column_type(column_id))
416414

417-
offsets = tuple(self.get_offset_for_name(id) for id in column_ids)
415+
offsets = tuple(ex.deref(id) for id in column_ids)
418416
return ArrayValue(nodes.ExplodeNode(child=self.node, column_ids=offsets))
419417

420418
def _uniform_sampling(self, fraction: float) -> ArrayValue:
@@ -425,9 +423,6 @@ def _uniform_sampling(self, fraction: float) -> ArrayValue:
425423
"""
426424
return ArrayValue(nodes.RandomSampleNode(self.node, fraction))
427425

428-
def get_offset_for_name(self, name: str):
429-
return self.schema.names.index(name)
430-
431426
# Deterministically generate namespaced ids for new variables
432427
# These new ids are only unique within the current namespace.
433428
# Many operations, such as joins, create new namespaces. See: BigFrameNode.defines_namespace

bigframes/core/block_transforms.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def _interpolate_column(
210210
) -> typing.Tuple[blocks.Block, str]:
211211
if interpolate_method not in ["linear", "nearest", "ffill"]:
212212
raise ValueError("interpolate method not supported")
213-
window_ordering = (ordering.OrderingExpression(ex.free_var(x_values)),)
213+
window_ordering = (ordering.OrderingExpression(ex.deref(x_values)),)
214214
backwards_window = windows.rows(following=0, ordering=window_ordering)
215215
forwards_window = windows.rows(preceding=0, ordering=window_ordering)
216216

@@ -373,7 +373,7 @@ def value_counts(
373373
block = block.order_by(
374374
[
375375
ordering.OrderingExpression(
376-
ex.free_var(count_id),
376+
ex.deref(count_id),
377377
direction=ordering.OrderingDirection.ASC
378378
if ascending
379379
else ordering.OrderingDirection.DESC,
@@ -430,7 +430,7 @@ def rank(
430430
nullity_col_ids.append(nullity_col_id)
431431
window_ordering = (
432432
ordering.OrderingExpression(
433-
ex.free_var(col),
433+
ex.deref(col),
434434
ordering.OrderingDirection.ASC
435435
if ascending
436436
else ordering.OrderingDirection.DESC,
@@ -522,7 +522,7 @@ def nsmallest(
522522
block = block.reversed()
523523
order_refs = [
524524
ordering.OrderingExpression(
525-
ex.free_var(col_id), direction=ordering.OrderingDirection.ASC
525+
ex.deref(col_id), direction=ordering.OrderingDirection.ASC
526526
)
527527
for col_id in column_ids
528528
]
@@ -552,7 +552,7 @@ def nlargest(
552552
block = block.reversed()
553553
order_refs = [
554554
ordering.OrderingExpression(
555-
ex.free_var(col_id), direction=ordering.OrderingDirection.DESC
555+
ex.deref(col_id), direction=ordering.OrderingDirection.DESC
556556
)
557557
for col_id in column_ids
558558
]
@@ -849,9 +849,9 @@ def _idx_extrema(
849849
)
850850
# Have to find the min for each
851851
order_refs = [
852-
ordering.OrderingExpression(ex.free_var(value_col), direction),
852+
ordering.OrderingExpression(ex.deref(value_col), direction),
853853
*[
854-
ordering.OrderingExpression(ex.free_var(idx_col))
854+
ordering.OrderingExpression(ex.deref(idx_col))
855855
for idx_col in original_block.index_columns
856856
],
857857
]

0 commit comments

Comments
 (0)