Commit c89e92e

refactor: Simplify query executor interface (#1015)
1 parent 1a38063 commit c89e92e


18 files changed: +337 −216 lines changed


bigframes/core/blocks.py

Lines changed: 62 additions & 90 deletions
@@ -25,11 +25,20 @@
 import dataclasses
 import functools
 import itertools
-import os
 import random
 import textwrap
 import typing
-from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
+from typing import (
+    Iterable,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 import warnings

 import bigframes_vendored.constants as constants
@@ -56,7 +65,10 @@
 import bigframes.features
 import bigframes.operations as ops
 import bigframes.operations.aggregations as agg_ops
-import bigframes.session._io.pandas
+import bigframes.session._io.pandas as io_pandas
+
+if TYPE_CHECKING:
+    import bigframes.session.executor

 # Type constraint for wherever column labels are used
 Label = typing.Hashable
@@ -450,46 +462,14 @@ def reorder_levels(self, ids: typing.Sequence[str]):
         level_names = [self.col_id_to_index_name[index_id] for index_id in ids]
         return Block(self.expr, ids, self.column_labels, level_names)

-    def _to_dataframe(self, result) -> pd.DataFrame:
-        """Convert BigQuery data to pandas DataFrame with specific dtypes."""
-        result_dataframe = self.session._rows_to_dataframe(result)
-        # Runs strict validations to ensure internal type predictions and ibis are completely in sync
-        # Do not execute these validations outside of testing suite.
-        if "PYTEST_CURRENT_TEST" in os.environ:
-            self._validate_result_schema(result.schema)
-        return result_dataframe
-
-    def _validate_result_schema(
-        self, bq_result_schema: list[bigquery.schema.SchemaField]
-    ):
-        actual_schema = tuple(bq_result_schema)
-        ibis_schema = self.expr._compiled_schema
-        internal_schema = self.expr.schema
-        if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable:
-            return
-        if internal_schema.to_bigquery() != actual_schema:
-            raise ValueError(
-                f"This error should only occur while testing. BigFrames internal schema: {internal_schema.to_bigquery()} does not match actual schema: {actual_schema}"
-            )
-        if ibis_schema.to_bigquery() != actual_schema:
-            raise ValueError(
-                f"This error should only occur while testing. Ibis schema: {ibis_schema.to_bigquery()} does not match actual schema: {actual_schema}"
-            )
-
     def to_arrow(
         self,
         *,
         ordered: bool = True,
     ) -> Tuple[pa.Table, bigquery.QueryJob]:
         """Run query and download results as a pyarrow Table."""
-        # pa.Table.from_pandas puts index columns last, so update the expression to match.
-        expr = self.expr.select_columns(
-            list(self.value_columns) + list(self.index_columns)
-        )
-
-        _, query_job = self.session._execute(expr, ordered=ordered)
-        results_iterator = query_job.result()
-        pa_table = results_iterator.to_arrow()
+        execute_result = self.session._executor.execute(self.expr, ordered=ordered)
+        pa_table = execute_result.to_arrow_table()

         pa_index_labels = []
         for index_level, index_label in enumerate(self._index_labels):
@@ -498,8 +478,10 @@ def to_arrow(
             else:
                 pa_index_labels.append(f"__index_level_{index_level}__")

+        # pa.Table.from_pandas puts index columns last, so update to match.
+        pa_table = pa_table.select([*self.value_columns, *self.index_columns])
         pa_table = pa_table.rename_columns(list(self.column_labels) + pa_index_labels)
-        return pa_table, query_job
+        return pa_table, execute_result.query_job

     def to_pandas(
         self,
@@ -508,7 +490,7 @@
         random_state: Optional[int] = None,
         *,
         ordered: bool = True,
-    ) -> Tuple[pd.DataFrame, bigquery.QueryJob]:
+    ) -> Tuple[pd.DataFrame, Optional[bigquery.QueryJob]]:
         """Run query and download results as a pandas DataFrame.

         Args:
@@ -560,8 +542,8 @@ def try_peek(
         self, n: int = 20, force: bool = False
     ) -> typing.Optional[pd.DataFrame]:
         if force or self.expr.supports_fast_peek:
-            iterator, _ = self.session._peek(self.expr, n)
-            df = self._to_dataframe(iterator)
+            result = self.session._executor.peek(self.expr, n)
+            df = io_pandas.arrow_to_pandas(result.to_arrow_table(), self.expr.schema)
             self._copy_index_to_pandas(df)
             return df
         else:
@@ -574,18 +556,15 @@ def to_pandas_batches(

         page_size and max_results determine the size and number of batches,
         see https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.QueryJob#google_cloud_bigquery_job_QueryJob_result"""
-        dtypes = dict(zip(self.index_columns, self.index.dtypes))
-        dtypes.update(zip(self.value_columns, self.dtypes))
-        _, query_job = self.session._executor.execute(
-            self.expr, ordered=True, use_explicit_destination=True
-        )
-        results_iterator = query_job.result(
-            page_size=page_size, max_results=max_results
-        )
-        for arrow_table in results_iterator.to_arrow_iterable(
-            bqstorage_client=self.session.bqstoragereadclient
-        ):
-            df = bigframes.session._io.pandas.arrow_to_pandas(arrow_table, dtypes)
+        execute_result = self.session._executor.execute(
+            self.expr,
+            ordered=True,
+            use_explicit_destination=True,
+            page_size=page_size,
+            max_results=max_results,
+        )
+        for record_batch in execute_result.arrow_batches():
+            df = io_pandas.arrow_to_pandas(record_batch, self.expr.schema)
             self._copy_index_to_pandas(df)
             yield df

@@ -605,22 +584,19 @@ def _copy_index_to_pandas(self, df: pd.DataFrame):

     def _materialize_local(
         self, materialize_options: MaterializationOptions = MaterializationOptions()
-    ) -> Tuple[pd.DataFrame, bigquery.QueryJob]:
+    ) -> Tuple[pd.DataFrame, Optional[bigquery.QueryJob]]:
         """Run query and download results as a pandas DataFrame. Return the total number of results as well."""
         # TODO(swast): Allow for dry run and timeout.
-        _, query_job = self.session._execute(
-            self.expr, ordered=materialize_options.ordered
-        )
-        results_iterator = query_job.result()
-
-        table_size = (
-            self.session._get_table_size(query_job.destination) / _BYTES_TO_MEGABYTES
+        execute_result = self.session._executor.execute(
+            self.expr, ordered=materialize_options.ordered, get_size_bytes=True
         )
+        assert execute_result.total_bytes is not None
+        table_mb = execute_result.total_bytes / _BYTES_TO_MEGABYTES
         sample_config = materialize_options.downsampling
         max_download_size = sample_config.max_download_size
         fraction = (
-            max_download_size / table_size
-            if (max_download_size is not None) and (table_size != 0)
+            max_download_size / table_mb
+            if (max_download_size is not None) and (table_mb != 0)
             else 2
         )

@@ -629,7 +605,7 @@ def _materialize_local(
         if fraction < 1:
            if not sample_config.enable_downsampling:
                raise RuntimeError(
-                    f"The data size ({table_size:.2f} MB) exceeds the maximum download limit of "
+                    f"The data size ({table_mb:.2f} MB) exceeds the maximum download limit of "
                     f"{max_download_size} MB. You can:\n\t* Enable downsampling in global options:\n"
                     "\t\t`bigframes.options.sampling.enable_downsampling = True`\n"
                     "\t* Update the global `max_download_size` option. Please make sure "
@@ -640,12 +616,12 @@ def _materialize_local(
                )

            warnings.warn(
-                f"The data size ({table_size:.2f} MB) exceeds the maximum download limit of"
+                f"The data size ({table_mb:.2f} MB) exceeds the maximum download limit of"
                f"({max_download_size} MB). It will be downsampled to {max_download_size} MB for download."
                "\nPlease refer to the documentation for configuring the downloading limit.",
                UserWarning,
            )
-            total_rows = results_iterator.total_rows
+            total_rows = execute_result.total_rows
            # Remove downsampling config from subsequent invocations, as otherwise could result in many
            # iterations if downsampling undershoots
            return self._downsample(
@@ -657,11 +633,12 @@ def _materialize_local(
                MaterializationOptions(ordered=materialize_options.ordered)
            )
        else:
-            total_rows = results_iterator.total_rows
-            df = self._to_dataframe(results_iterator)
+            total_rows = execute_result.total_rows
+            arrow = self.session._executor.execute(self.expr).to_arrow_table()
+            df = io_pandas.arrow_to_pandas(arrow, schema=self.expr.schema)
        self._copy_index_to_pandas(df)

-        return df, query_job
+        return df, execute_result.query_job

    def _downsample(
        self, total_rows: int, sampling_method: str, fraction: float, random_state
@@ -680,7 +657,7 @@ def _downsample(
            )
            return block
        elif sampling_method == _UNIFORM:
-            block = self._split(
+            block = self.split(
                fracs=(fraction,),
                random_state=random_state,
                sort=False,
@@ -693,7 +670,7 @@ def _downsample(
                f"please choose from {','.join(_SAMPLING_METHODS)}."
            )

-    def _split(
+    def split(
        self,
        ns: Iterable[int] = (),
        fracs: Iterable[float] = (),
@@ -785,7 +762,7 @@ def _compute_dry_run(
        self, value_keys: Optional[Iterable[str]] = None
    ) -> bigquery.QueryJob:
        expr = self._apply_value_keys_to_expr(value_keys=value_keys)
-        _, query_job = self.session._dry_run(expr)
+        query_job = self.session._executor.dry_run(expr)
        return query_job

    def _apply_value_keys_to_expr(self, value_keys: Optional[Iterable[str]] = None):
@@ -1567,20 +1544,21 @@ def _forward_slice(self, start: int = 0, stop=None, step: int = 1):
    @functools.cache
    def retrieve_repr_request_results(
        self, max_results: int
-    ) -> Tuple[pd.DataFrame, int, bigquery.QueryJob]:
+    ) -> Tuple[pd.DataFrame, int, Optional[bigquery.QueryJob]]:
        """
        Retrieves a pandas dataframe containing only max_results many rows for use
        with printing methods.

        Returns a tuple of the dataframe and the overall number of rows of the query.
        """

-        results, query_job = self.session._executor.head(self.expr, max_results)
+        head_result = self.session._executor.head(self.expr, max_results)
        count = self.session._executor.get_row_count(self.expr)

-        computed_df = self._to_dataframe(results)
-        self._copy_index_to_pandas(computed_df)
-        return computed_df, count, query_job
+        arrow = self.session._executor.execute(self.expr).to_arrow_table()
+        df = io_pandas.arrow_to_pandas(arrow, schema=self.expr.schema)
+        self._copy_index_to_pandas(df)
+        return df, count, head_result.query_job

    def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
        expr, result_id = self._expr.promote_offsets()
@@ -2330,7 +2308,10 @@ def to_sql_query(
            # the BigQuery unicode column name feature?
            substitutions[old_id] = new_id

-        sql = self.session._to_sql(
+        # Note: this uses the sql from the executor, so is coupled tightly to execution
+        # implementaton. It will reference cached tables instead of original data sources.
+        # Maybe should just compile raw BFET? Depends on user intent.
+        sql = self.session._executor.to_sql(
            array_value, col_id_overrides=substitutions, enable_cache=enable_cache
        )
        return (
@@ -2424,7 +2405,7 @@ def _get_rows_as_json_values(self) -> Block:
        # TODO(shobs): Replace direct SQL manipulation by structured expression
        # manipulation
        expr, ordering_column_name = self.expr.promote_offsets()
-        expr_sql = self.session._to_sql(expr)
+        expr_sql = self.session._executor.to_sql(expr)

        # Names of the columns to serialize for the row.
        # We will use the repr-eval pattern to serialize a value here and
@@ -2578,17 +2559,8 @@ def to_pandas(self, *, ordered: Optional[bool] = None) -> pd.Index:
            raise bigframes.exceptions.NullIndexError(
                "Cannot materialize index, as this object does not have an index. Set index column(s) using set_index."
            )
-        # Project down to only the index column. So the query can be cached to visualize other data.
-        index_columns = list(self._block.index_columns)
-        expr = self._expr.select_columns(index_columns)
-        results, _ = self.session._execute(
-            expr, ordered=ordered if ordered is not None else True
-        )
-        df = expr.session._rows_to_dataframe(results)
-        df = df.set_index(index_columns)
-        index = df.index
-        index.names = list(self._block._index_labels) # type:ignore
-        return index
+        ordered = ordered if ordered is not None else True
+        return self._block.select_columns([]).to_pandas(ordered=ordered)[0].index

    def resolve_level(self, level: LevelsType) -> typing.Sequence[str]:
        if utils.is_list_like(level):
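
Note: throughout this diff, Block methods consume a single result object returned by session._executor.execute() (and peek()/head()) instead of unpacking a (rows, QueryJob) tuple. The sketch below is only an illustration of the surface that blocks.py relies on (to_arrow_table(), arrow_batches(), query_job, total_bytes, total_rows), inferred from the call sites above; the actual class lives in bigframes.session.executor and may differ.

# Hypothetical stand-in for the executor's result object; not the bigframes class.
import dataclasses
from typing import Iterator, List, Optional

import pyarrow as pa


@dataclasses.dataclass
class ExecuteResultSketch:
    batches: List[pa.RecordBatch]          # record batches streamed back by the query
    query_job: Optional[object] = None     # bigquery.QueryJob when a job was actually run
    total_bytes: Optional[int] = None      # populated when size information was requested
    total_rows: Optional[int] = None

    def to_arrow_table(self) -> pa.Table:
        # Assemble the streamed batches into a single pyarrow Table.
        return pa.Table.from_batches(self.batches)

    def arrow_batches(self) -> Iterator[pa.RecordBatch]:
        # Yield batches one at a time, matching how to_pandas_batches() iterates.
        yield from self.batches


batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
result = ExecuteResultSketch(batches=[batch], total_rows=3)
print(result.to_arrow_table().num_rows)  # 3

The Optional[bigquery.QueryJob] return types introduced above suggest that a result does not always carry a query job, which is why query_job is optional in this sketch.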

bigframes/core/compile/ibis_types.py

Lines changed: 5 additions & 1 deletion
@@ -330,7 +330,11 @@ def _ibis_dtype_to_arrow_dtype(ibis_dtype: ibis_dtypes.DataType) -> pa.DataType:
     if isinstance(ibis_dtype, ibis_dtypes.Struct):
         return pa.struct(
             [
-                (name, _ibis_dtype_to_arrow_dtype(dtype))
+                pa.field(
+                    name,
+                    _ibis_dtype_to_arrow_dtype(dtype),
+                    nullable=not pa.types.is_list(_ibis_dtype_to_arrow_dtype(dtype)),
+                )
                 for name, dtype in ibis_dtype.fields.items()
             ]
         )
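
For reference, the change above replaces plain (name, type) tuples with pa.field(...) so that list-typed struct children can be marked non-nullable. A standalone pyarrow illustration of that rule (not bigframes code):

# Build a struct type where list-typed children are non-nullable and others keep
# the default nullable=True, mirroring the rule applied in the diff above.
import pyarrow as pa

child_types = {"scores": pa.list_(pa.int64()), "name": pa.string()}
struct_type = pa.struct(
    [
        pa.field(name, dtype, nullable=not pa.types.is_list(dtype))
        for name, dtype in child_types.items()
    ]
)
for child in struct_type:
    print(child.name, child.nullable)  # scores False / name True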

bigframes/core/schema.py

Lines changed: 22 additions & 0 deletions
@@ -19,6 +19,7 @@
 import typing

 import google.cloud.bigquery
+import pyarrow

 import bigframes.core.guid
 import bigframes.dtypes
@@ -64,6 +65,19 @@ def to_bigquery(self) -> typing.Tuple[google.cloud.bigquery.SchemaField, ...]:
             for item in self.items
         )

+    def to_pyarrow(self) -> pyarrow.Schema:
+        fields = []
+        for item in self.items:
+            pa_type = bigframes.dtypes.bigframes_dtype_to_arrow_dtype(item.dtype)
+            fields.append(
+                pyarrow.field(
+                    item.column,
+                    pa_type,
+                    nullable=not pyarrow.types.is_list(pa_type),
+                )
+            )
+        return pyarrow.schema(fields)
+
     def drop(self, columns: typing.Iterable[str]) -> ArraySchema:
         return ArraySchema(
             tuple(item for item in self.items if item.column not in columns)
@@ -74,6 +88,14 @@ def select(self, columns: typing.Iterable[str]) -> ArraySchema:
             tuple(SchemaItem(name, self.get_type(name)) for name in columns)
         )

+    def rename(self, mapping: typing.Mapping[str, str]) -> ArraySchema:
+        return ArraySchema(
+            tuple(
+                SchemaItem(mapping.get(item.column, item.column), item.dtype)
+                for item in self.items
+            )
+        )
+
     def append(self, item: SchemaItem):
         return ArraySchema(tuple([*self.items, item]))

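
As a standalone sketch only, the rename() semantics added above can be mimicked with a plain pyarrow schema: columns found in the mapping take the new name, all other columns keep theirs.

# Rename schema columns via a mapping, keeping unmapped columns unchanged.
import pyarrow as pa

schema = pa.schema([("col_a", pa.int64()), ("col_b", pa.string())])
mapping = {"col_a": "id"}
renamed = pa.schema(
    [pa.field(mapping.get(f.name, f.name), f.type, nullable=f.nullable) for f in schema]
)
print(renamed.names)  # ['id', 'col_b']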
