
Commit 7e8658b

refactor: add compile_random_sample (#1884)
Fixes internal issue 429248387
1 parent 3e6dfe7 commit 7e8658b

File tree

8 files changed: +380 -68 lines changed


bigframes/core/compile/sqlglot/compiler.py

Lines changed: 6 additions & 0 deletions
@@ -261,6 +261,12 @@ def compile_explode(
         columns = tuple(ref.id.sql for ref in node.column_ids)
         return child.explode(columns, offsets_col)
 
+    @_compile_node.register
+    def compile_random_sample(
+        self, node: nodes.RandomSampleNode, child: ir.SQLGlotIR
+    ) -> ir.SQLGlotIR:
+        return child.sample(node.fraction)
+
 
 def _replace_unsupported_ops(node: nodes.BigFrameNode):
     node = nodes.bottom_up(node, rewrite.rewrite_slice)
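The new `compile_random_sample` handler plugs into the compiler's single-dispatch registry, so each node type gets its own compile method. A minimal standalone sketch of that dispatch pattern (the `Compiler` and `RandomSampleNode` stand-ins below are illustrative, not the BigFrames classes):

```python
import dataclasses
import functools


@dataclasses.dataclass(frozen=True)
class RandomSampleNode:
    """Stand-in for bigframes.core.nodes.RandomSampleNode."""

    fraction: float


class Compiler:
    @functools.singledispatchmethod
    def _compile_node(self, node, child):
        raise NotImplementedError(f"no compiler registered for {type(node).__name__}")

    @_compile_node.register
    def compile_random_sample(self, node: RandomSampleNode, child):
        # Dispatch selects this overload for RandomSampleNode and delegates to the
        # IR helper, which appends a RAND() column and filters it by the fraction.
        return child.sample(node.fraction)
```

Calling `Compiler()._compile_node(RandomSampleNode(0.1), child_ir)` would route to `compile_random_sample`, mirroring how the registered handler above receives the already-lowered child `SQLGlotIR`.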

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 129 additions & 63 deletions
@@ -25,7 +25,7 @@
 import sqlglot.expressions as sge
 
 from bigframes import dtypes
-from bigframes.core import guid
+from bigframes.core import guid, utils
 from bigframes.core.compile.sqlglot.expressions import typed_expr
 import bigframes.core.compile.sqlglot.sqlglot_types as sgt
 import bigframes.core.local_data as local_data
@@ -71,7 +71,10 @@ def from_pyarrow(
         schema: bf_schema.ArraySchema,
         uid_gen: guid.SequentialUIDGenerator,
     ) -> SQLGlotIR:
-        """Builds SQLGlot expression from pyarrow table."""
+        """Builds SQLGlot expression from a pyarrow table.
+
+        This is used to represent in-memory data as a SQL query.
+        """
         dtype_expr = sge.DataType(
             this=sge.DataType.Type.STRUCT,
             expressions=[
@@ -117,6 +120,16 @@ def from_table(
         alias_names: typing.Sequence[str],
         uid_gen: guid.SequentialUIDGenerator,
     ) -> SQLGlotIR:
+        """Builds a SQLGlotIR expression from a BigQuery table.
+
+        Args:
+            project_id (str): The project ID of the BigQuery table.
+            dataset_id (str): The dataset ID of the BigQuery table.
+            table_id (str): The table ID of the BigQuery table.
+            col_names (typing.Sequence[str]): The names of the columns to select.
+            alias_names (typing.Sequence[str]): The aliases for the selected columns.
+            uid_gen (guid.SequentialUIDGenerator): A generator for unique identifiers.
+        """
         selections = [
             sge.Alias(
                 this=sge.to_identifier(col_name, quoted=cls.quoted),
@@ -137,7 +150,7 @@ def from_query_string(
         cls,
         query_string: str,
     ) -> SQLGlotIR:
-        """Builds SQLGlot expression from a query string"""
+        """Builds a SQLGlot expression from a query string"""
         uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
         cte_name = sge.to_identifier(
             next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
@@ -157,7 +170,7 @@ def from_union(
         output_ids: typing.Sequence[str],
         uid_gen: guid.SequentialUIDGenerator,
     ) -> SQLGlotIR:
-        """Builds SQLGlot expression by union of multiple select expressions."""
+        """Builds a SQLGlot expression by unioning multiple select expressions."""
         assert (
             len(list(selects)) >= 2
         ), f"At least two select expressions must be provided, but got {selects}."
@@ -205,6 +218,7 @@ def select(
         self,
         selected_cols: tuple[tuple[str, sge.Expression], ...],
     ) -> SQLGlotIR:
+        """Replaces the selected columns of the current SELECT clause."""
         selections = [
             sge.Alias(
                 this=expr,
@@ -213,15 +227,41 @@
             for id, expr in selected_cols
         ]
 
-        new_expr, _ = self._encapsulate_as_cte()
+        new_expr = _select_to_cte(
+            self.expr,
+            sge.to_identifier(
+                next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+            ),
+        )
         new_expr = new_expr.select(*selections, append=False)
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
+    def project(
+        self,
+        projected_cols: tuple[tuple[str, sge.Expression], ...],
+    ) -> SQLGlotIR:
+        """Adds new columns to the SELECT clause."""
+        projected_cols_expr = [
+            sge.Alias(
+                this=expr,
+                alias=sge.to_identifier(id, quoted=self.quoted),
+            )
+            for id, expr in projected_cols
+        ]
+        new_expr = _select_to_cte(
+            self.expr,
+            sge.to_identifier(
+                next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+            ),
+        )
+        new_expr = new_expr.select(*projected_cols_expr, append=True)
+        return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
+
     def order_by(
         self,
         ordering: tuple[sge.Ordered, ...],
     ) -> SQLGlotIR:
-        """Adds ORDER BY clause to the query."""
+        """Adds an ORDER BY clause to the query."""
         if len(ordering) == 0:
             return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen)
         new_expr = self.expr.order_by(*ordering)
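The relocated `project` differs from `select` only in the `append` flag handed to SQLGlot's builder: `append=False` replaces the projection list, while `append=True` adds to it (in both cases the current query is first pushed into a fresh CTE). A small standalone sqlglot sketch of that difference, with made-up table and column names:

```python
import sqlglot
import sqlglot.expressions as sge

base = sqlglot.select("a", "b").from_("t")

# append=False replaces the projection list (the behavior SQLGlotIR.select relies on).
replaced = base.select(sge.alias_(sge.func("LOWER", sge.column("b")), "b_lower"), append=False)
print(replaced.sql())  # SELECT LOWER(b) AS b_lower FROM t

# append=True keeps the existing columns and adds new ones (SQLGlotIR.project).
appended = base.select(sge.alias_(sge.func("LOWER", sge.column("b")), "b_lower"), append=True)
print(appended.sql())  # SELECT a, b, LOWER(b) AS b_lower FROM t
```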
@@ -231,34 +271,24 @@ def limit(
         self,
         limit: int | None,
     ) -> SQLGlotIR:
-        """Adds LIMIT clause to the query."""
+        """Adds a LIMIT clause to the query."""
         if limit is not None:
             new_expr = self.expr.limit(limit)
         else:
             new_expr = self.expr.copy()
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
-    def project(
-        self,
-        projected_cols: tuple[tuple[str, sge.Expression], ...],
-    ) -> SQLGlotIR:
-        projected_cols_expr = [
-            sge.Alias(
-                this=expr,
-                alias=sge.to_identifier(id, quoted=self.quoted),
-            )
-            for id, expr in projected_cols
-        ]
-        new_expr, _ = self._encapsulate_as_cte()
-        new_expr = new_expr.select(*projected_cols_expr, append=True)
-        return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
-
     def filter(
         self,
         condition: sge.Expression,
     ) -> SQLGlotIR:
-        """Filters the query with the given condition."""
-        new_expr, _ = self._encapsulate_as_cte()
+        """Filters the query by adding a WHERE clause."""
+        new_expr = _select_to_cte(
+            self.expr,
+            sge.to_identifier(
+                next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+            ),
+        )
         return SQLGlotIR(
             expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen
         )
@@ -272,8 +302,15 @@ def join(
         joins_nulls: bool = True,
     ) -> SQLGlotIR:
         """Joins the current query with another SQLGlotIR instance."""
-        left_select, left_table = self._encapsulate_as_cte()
-        right_select, right_table = right._encapsulate_as_cte()
+        left_cte_name = sge.to_identifier(
+            next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+        )
+        right_cte_name = sge.to_identifier(
+            next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+        )
+
+        left_select = _select_to_cte(self.expr, left_cte_name)
+        right_select = _select_to_cte(right.expr, right_cte_name)
 
         left_ctes = left_select.args.pop("with", [])
         right_ctes = right_select.args.pop("with", [])
@@ -288,17 +325,50 @@
         new_expr = (
             sge.Select()
             .select(sge.Star())
-            .from_(left_table)
-            .join(right_table, on=join_on, join_type=join_type_str)
+            .from_(sge.Table(this=left_cte_name))
+            .join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str)
         )
         new_expr.set("with", sge.With(expressions=merged_ctes))
 
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
+    def explode(
+        self,
+        column_names: tuple[str, ...],
+        offsets_col: typing.Optional[str],
+    ) -> SQLGlotIR:
+        """Unnests one or more array columns."""
+        num_columns = len(list(column_names))
+        assert num_columns > 0, "At least one column must be provided for explode."
+        if num_columns == 1:
+            return self._explode_single_column(column_names[0], offsets_col)
+        else:
+            return self._explode_multiple_columns(column_names, offsets_col)
+
+    def sample(self, fraction: float) -> SQLGlotIR:
+        """Uniformly samples a fraction of the rows."""
+        uuid_col = sge.to_identifier(
+            next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted
+        )
+        uuid_expr = sge.Alias(this=sge.func("RAND"), alias=uuid_col)
+        condition = sge.LT(
+            this=uuid_col,
+            expression=_literal(fraction, dtypes.FLOAT_DTYPE),
+        )
+
+        new_cte_name = sge.to_identifier(
+            next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+        )
+        new_expr = _select_to_cte(
+            self.expr.select(uuid_expr, append=True), new_cte_name
+        ).where(condition, append=False)
+        return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
+
     def insert(
         self,
         destination: bigquery.TableReference,
     ) -> str:
+        """Generates an INSERT INTO SQL statement from the current SELECT clause."""
         return sge.insert(self.expr.subquery(), _table(destination)).sql(
             dialect=self.dialect, pretty=self.pretty
         )
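For intuition, `sample(fraction)` produces a query of roughly this shape: the current SELECT is pushed into a CTE with an extra `RAND()` column, and the outer query keeps only rows where that column falls below the fraction. A standalone sqlglot sketch of that shape (the `bfcte_0`/`bfcol_0` identifiers and the table name are illustrative, not generated by the real UID stream):

```python
import sqlglot
import sqlglot.expressions as sge

fraction = 0.1

# Child query with an appended RAND() column, as SQLGlotIR.sample builds it.
child = sqlglot.select("*").from_("my_table")
child_with_rand = child.select(sge.alias_(sge.func("RAND"), "bfcol_0"), append=True)

# Outer query: wrap the child in a CTE and filter on the random column.
sampled = (
    sqlglot.select("*")
    .from_("bfcte_0")
    .where(sge.LT(this=sge.to_identifier("bfcol_0"), expression=sge.convert(fraction)))
    .with_("bfcte_0", as_=child_with_rand)
)

print(sampled.sql(dialect="bigquery", pretty=True))
# Roughly:
# WITH bfcte_0 AS (SELECT *, RAND() AS bfcol_0 FROM my_table)
# SELECT * FROM bfcte_0 WHERE bfcol_0 < 0.1
```

Because `RAND()` is evaluated once per row, each row survives independently with probability `fraction`, which is what makes the sample uniform.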
@@ -307,6 +377,9 @@ def replace(
         self,
         destination: bigquery.TableReference,
     ) -> str:
+        """Generates a MERGE statement to replace the destination table's contents
+        with the current SELECT clause.
+        """
         # Workaround for SQLGlot breaking change:
         # https://github.com/tobymao/sqlglot/pull/4495
         whens_expr = [
@@ -325,23 +398,10 @@ def replace(
         ).sql(dialect=self.dialect, pretty=self.pretty)
         return f"{merge_str}\n{whens_str}"
 
-    def explode(
-        self,
-        column_names: tuple[str, ...],
-        offsets_col: typing.Optional[str],
-    ) -> SQLGlotIR:
-        num_columns = len(list(column_names))
-        assert num_columns > 0, "At least one column must be provided for explode."
-        if num_columns == 1:
-            return self._explode_single_column(column_names[0], offsets_col)
-        else:
-            return self._explode_multiple_columns(column_names, offsets_col)
-
     def _explode_single_column(
         self, column_name: str, offsets_col: typing.Optional[str]
     ) -> SQLGlotIR:
         """Helper method to handle the case of exploding a single column."""
-
         offset = (
             sge.to_identifier(offsets_col, quoted=self.quoted) if offsets_col else None
         )
@@ -358,7 +418,12 @@ def _explode_single_column(
 
         # TODO: "CROSS" if not keep_empty else "LEFT"
        # TODO: overlaps_with_parent to replace existing column.
-        new_expr, _ = self._encapsulate_as_cte()
+        new_expr = _select_to_cte(
+            self.expr,
+            sge.to_identifier(
+                next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+            ),
+        )
         new_expr = new_expr.select(selection, append=False).join(
             unnest_expr, join_type="CROSS"
         )
@@ -408,33 +473,32 @@ def _explode_multiple_columns(
                 for column in columns
             ]
         )
-        new_expr, _ = self._encapsulate_as_cte()
+        new_expr = _select_to_cte(
+            self.expr,
+            sge.to_identifier(
+                next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+            ),
+        )
         new_expr = new_expr.select(selection, append=False).join(
             unnest_expr, join_type="CROSS"
         )
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
-    def _encapsulate_as_cte(
-        self,
-    ) -> typing.Tuple[sge.Select, sge.Table]:
-        """Transforms a given sge.Select query by pushing its main SELECT statement
-        into a new CTE and then generates a 'SELECT * FROM new_cte_name'
-        for the new query."""
-        select_expr = self.expr.copy()
 
-        existing_ctes = select_expr.args.pop("with", [])
-        new_cte_name = sge.to_identifier(
-            next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
-        )
-        new_cte = sge.CTE(
-            this=select_expr,
-            alias=new_cte_name,
-        )
-        new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
-        new_table_expr = sge.Table(this=new_cte_name)
-        new_select_expr = sge.Select().select(sge.Star()).from_(new_table_expr)
-        new_select_expr.set("with", new_with_clause)
-        return new_select_expr, new_table_expr
+def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
+    """Transforms a given sge.Select query by pushing its main SELECT statement
+    into a new CTE and then generates a 'SELECT * FROM new_cte_name'
+    for the new query."""
+    select_expr = expr.copy()
+    existing_ctes = select_expr.args.pop("with", [])
+    new_cte = sge.CTE(
+        this=select_expr,
+        alias=cte_name,
+    )
+    new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
+    new_select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
+    new_select_expr.set("with", new_with_clause)
+    return new_select_expr
 
 
 def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
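The net effect of `_select_to_cte` (like the `_encapsulate_as_cte` method it replaces) is that each transformation layers a new CTE on top of the previous ones, with any existing WITH clause hoisted onto the new outer query. A standalone sqlglot sketch of that wrapping step, mirroring the committed helper but parameterized on a plain string name (the `bfcte_*` names are illustrative):

```python
import sqlglot
import sqlglot.expressions as sge


def select_to_cte(expr: sge.Select, cte_name: str) -> sge.Select:
    # Mirror of the committed helper: push the whole query into a CTE and
    # hoist any prior WITH clause onto the new outer SELECT.
    select_expr = expr.copy()
    existing_ctes = select_expr.args.pop("with", [])
    new_cte = sge.CTE(this=select_expr, alias=sge.to_identifier(cte_name))
    outer = sge.Select().select(sge.Star()).from_(sge.Table(this=sge.to_identifier(cte_name)))
    outer.set("with", sge.With(expressions=[*existing_ctes, new_cte]))
    return outer


inner = sqlglot.select("a").from_("t").where("a > 0")
step1 = select_to_cte(inner, "bfcte_0")
step2 = select_to_cte(step1, "bfcte_1")
print(step2.sql())
# Roughly: WITH bfcte_0 AS (SELECT a FROM t WHERE a > 0),
#          bfcte_1 AS (SELECT * FROM bfcte_0)
#          SELECT * FROM bfcte_1
```

Making it a module-level function (taking the expression and the CTE name as arguments) is what lets `join` name the left and right CTEs up front and reference them directly in the new FROM/JOIN clauses.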
@@ -454,6 +518,8 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
         return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
     elif dtype == dtypes.JSON_DTYPE:
         return sge.ParseJSON(this=sge.convert(str(value)))
+    elif dtype == dtypes.TIMEDELTA_DTYPE:
+        return sge.convert(utils.timedelta_to_micros(value))
     elif dtypes.is_struct_like(dtype):
         items = [
             _literal(value=value[field_name], dtype=field_dtype).as_(
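TIMEDELTA values are represented as INT64 microseconds (see the matching `sqlglot_types.py` change below), so the new literal branch converts the timedelta to an integer before handing it to SQLGlot. A hedged sketch of that conversion, assuming `utils.timedelta_to_micros` behaves like a plain nanoseconds-to-microseconds division for pandas Timedeltas (the helper itself is not part of this diff):

```python
import pandas as pd
import sqlglot.expressions as sge


def timedelta_to_micros(value: pd.Timedelta) -> int:
    # Assumed behavior: pandas stores Timedelta as integer nanoseconds.
    return value.value // 1_000


literal = sge.convert(timedelta_to_micros(pd.Timedelta("1h")))
print(literal.sql())  # 3600000000
```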

bigframes/core/compile/sqlglot/sqlglot_types.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@ def from_bigframes_dtype(
             return "JSON"
         elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
             return "GEOGRAPHY"
+        elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
+            return "INT64"
         elif isinstance(bigframes_dtype, pd.ArrowDtype):
             if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
                 inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(

tests/unit/core/compile/sqlglot/conftest.py

Lines changed: 11 additions & 0 deletions
@@ -21,6 +21,7 @@
 import pytest
 
 from bigframes import dtypes
+import bigframes.core as core
 import bigframes.pandas as bpd
 import bigframes.testing.mocks as mocks
 import bigframes.testing.utils
@@ -115,6 +116,16 @@ def scalar_types_pandas_df() -> pd.DataFrame:
     return df
 
 
+@pytest.fixture(scope="module")
+def scalar_types_array_value(
+    scalar_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session
+) -> core.ArrayValue:
+    managed_data_source = core.local_data.ManagedArrowTable.from_pandas(
+        scalar_types_pandas_df
+    )
+    return core.ArrayValue.from_managed(managed_data_source, compiler_session)
+
+
 @pytest.fixture(scope="session")
 def nested_structs_types_table_schema() -> typing.Sequence[bigquery.SchemaField]:
     return [
