Skip to content

Commit 7c9b816

Browse files
authored
refactor: subclass DerefOp as ResolvedDerefOp (#1874)
* refactor: subclass DerefOp as ResolvedDerefOp * replace the `field` attribute by id, dtype, nullable * final cleanup
1 parent a4682e9 commit 7c9b816

File tree

5 files changed

+29
-51
lines changed

5 files changed

+29
-51
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def _(
120120
@compile_expression.register
121121
def _(
122122
self,
123-
expression: ex.SchemaFieldRefExpression,
123+
expression: ex.ResolvedDerefOp,
124124
) -> pl.Expr:
125-
return pl.col(expression.field.id.sql)
125+
return pl.col(expression.id.sql)
126126

127127
@compile_expression.register
128128
def _(

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,6 @@ def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression:
4242
return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True))
4343

4444

45-
@compile_scalar_expression.register
46-
def compile_field_ref_expression(
47-
expr: expression.SchemaFieldRefExpression,
48-
) -> sge.Expression:
49-
return sge.ColumnDef(this=sge.to_identifier(expr.field.id.sql, quoted=True))
50-
51-
5245
@compile_scalar_expression.register
5346
def compile_constant_expression(
5447
expr: expression.ScalarConstantExpression,

bigframes/core/expression.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -429,55 +429,27 @@ def transform_children(self, t: Callable[[Expression], Expression]) -> Expressio
429429

430430

431431
@dataclasses.dataclass(frozen=True)
432-
class SchemaFieldRefExpression(Expression):
433-
"""An expression representing a schema field. This is essentially a DerefOp with input schema bound."""
432+
class ResolvedDerefOp(DerefOp):
433+
"""An expression that refers to a column by ID and resolved with schema bound."""
434434

435-
field: field.Field
435+
dtype: dtypes.Dtype
436+
is_nullable: bool
436437

437-
@property
438-
def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
439-
return (self.field.id,)
440-
441-
@property
442-
def is_const(self) -> bool:
443-
return False
444-
445-
@property
446-
def nullable(self) -> bool:
447-
return self.field.nullable
438+
@classmethod
439+
def from_field(cls, f: field.Field):
440+
return cls(id=f.id, dtype=f.dtype, is_nullable=f.nullable)
448441

449442
@property
450443
def is_resolved(self) -> bool:
451444
return True
452445

453446
@property
454-
def output_type(self) -> dtypes.ExpressionType:
455-
return self.field.dtype
456-
457-
def bind_variables(
458-
self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
459-
) -> Expression:
460-
return self
461-
462-
def bind_refs(
463-
self,
464-
bindings: Mapping[ids.ColumnId, Expression],
465-
allow_partial_bindings: bool = False,
466-
) -> Expression:
467-
if self.field.id in bindings.keys():
468-
return bindings[self.field.id]
469-
return self
470-
471-
@property
472-
def is_bijective(self) -> bool:
473-
return True
447+
def nullable(self) -> bool:
448+
return self.is_nullable
474449

475450
@property
476-
def is_identity(self) -> bool:
477-
return True
478-
479-
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
480-
return self
451+
def output_type(self) -> dtypes.ExpressionType:
452+
return self.dtype
481453

482454

483455
@dataclasses.dataclass(frozen=True)
@@ -589,7 +561,7 @@ def bind_schema_fields(
589561
return expr
590562

591563
expr_by_id = {
592-
id: SchemaFieldRefExpression(field) for id, field in field_by_id.items()
564+
id: ResolvedDerefOp.from_field(field) for id, field in field_by_id.items()
593565
}
594566
return expr.bind_refs(expr_by_id)
595567

bigframes/core/rewrite/schema_binding.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,17 @@ def bind_schema_to_node(
5252

5353
return dataclasses.replace(node, by=tuple(bound_bys))
5454

55+
if isinstance(node, nodes.JoinNode):
56+
conditions = tuple(
57+
(
58+
ex.ResolvedDerefOp.from_field(node.left_child.field_by_id[left.id]),
59+
ex.ResolvedDerefOp.from_field(node.right_child.field_by_id[right.id]),
60+
)
61+
for left, right in node.conditions
62+
)
63+
return dataclasses.replace(
64+
node,
65+
conditions=conditions,
66+
)
67+
5568
return node

tests/unit/core/test_expression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def test_deref_op_dtype_resolution():
7777

7878

7979
def test_field_ref_expr_dtype_resolution_short_circuit():
80-
expression = ex.SchemaFieldRefExpression(
81-
field.Field(ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE)
80+
expression = ex.ResolvedDerefOp(
81+
id=ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE, is_nullable=True
8282
)
8383
field_bindings = _create_field_bindings({"anotherCol": dtypes.STRING_DTYPE})
8484

0 commit comments

Comments
 (0)