Skip to content

Commit 02f7ab6

Browse files
refactor: move query execution from ArrayValue to Session (#255)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 7cbbb7d commit 02f7ab6

File tree

9 files changed

+155
-111
lines changed

9 files changed

+155
-111
lines changed

bigframes/core/__init__.py

Lines changed: 15 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616
from dataclasses import dataclass
1717
import io
1818
import typing
19-
from typing import Iterable, Literal, Optional, Sequence, Tuple
19+
from typing import Iterable, Literal, Sequence
2020

21-
from google.cloud import bigquery
22-
import ibis
2321
import ibis.expr.types as ibis_types
2422
import pandas
2523

@@ -86,7 +84,17 @@ def session(self) -> Session:
8684
required_session = self.node.session
8785
from bigframes import get_global_session
8886

89-
return self.node.session[0] if required_session else get_global_session()
87+
return (
88+
required_session if (required_session is not None) else get_global_session()
89+
)
90+
91+
def _try_evaluate_local(self):
92+
"""Use only for unit testing paths - not fully featured. Will throw exception if fails."""
93+
import ibis
94+
95+
return ibis.pandas.connect({}).execute(
96+
self._compile_ordered()._to_ibis_expr(ordering_mode="unordered")
97+
)
9098

9199
def get_column_type(self, key: str) -> bigframes.dtypes.Dtype:
92100
return self._compile_ordered().get_column_type(key)
@@ -97,97 +105,9 @@ def _compile_ordered(self) -> compiled.OrderedIR:
97105
def _compile_unordered(self) -> compiled.UnorderedIR:
98106
return compiler.compile_unordered(self.node)
99107

100-
def shape(self) -> typing.Tuple[int, int]:
101-
"""Returns dimensions as (length, width) tuple."""
102-
width = len(self._compile_unordered().columns)
103-
count_expr = self._compile_unordered()._to_ibis_expr().count()
104-
105-
# Support in-memory engines for hermetic unit tests.
106-
if not self.node.session:
107-
try:
108-
length = ibis.pandas.connect({}).execute(count_expr)
109-
return (length, width)
110-
except Exception:
111-
# Not all cases can be handled by pandas engine
112-
pass
113-
114-
sql = self.session.ibis_client.compile(count_expr)
115-
row_iterator, _ = self.session._start_query(
116-
sql=sql,
117-
max_results=1,
118-
)
119-
length = next(row_iterator)[0]
120-
return (length, width)
121-
122-
def to_sql(
123-
self,
124-
offset_column: typing.Optional[str] = None,
125-
col_id_overrides: typing.Mapping[str, str] = {},
126-
sorted: bool = False,
127-
) -> str:
128-
array_value = self
129-
if offset_column:
130-
array_value = self.promote_offsets(offset_column)
131-
if sorted:
132-
return array_value._compile_ordered().to_sql(
133-
col_id_overrides=col_id_overrides,
134-
sorted=sorted,
135-
)
136-
else:
137-
return array_value._compile_unordered().to_sql(
138-
col_id_overrides=col_id_overrides
139-
)
140-
141-
def start_query(
142-
self,
143-
job_config: Optional[bigquery.job.QueryJobConfig] = None,
144-
max_results: Optional[int] = None,
145-
*,
146-
sorted: bool = True,
147-
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
148-
"""Execute a query and return metadata about the results."""
149-
# TODO(swast): Cache the job ID so we can look it up again if they ask
150-
# for the results? We'd need a way to invalidate the cache if DataFrame
151-
# becomes mutable, though. Or move this method to the immutable
152-
# expression class.
153-
# TODO(swast): We might want to move this method to Session and/or
154-
# provide our own minimal metadata class. Tight coupling to the
155-
# BigQuery client library isn't ideal, especially if we want to support
156-
# a LocalSession for unit testing.
157-
# TODO(swast): Add a timeout here? If the query is taking a long time,
158-
# maybe we just print the job metadata that we have so far?
159-
sql = self.to_sql(sorted=sorted) # type:ignore
160-
return self.session._start_query(
161-
sql=sql,
162-
job_config=job_config,
163-
max_results=max_results,
164-
)
165-
166-
def cached(self, cluster_cols: typing.Sequence[str]) -> ArrayValue:
167-
"""Write the ArrayValue to a session table and create a new block object that references it."""
168-
compiled_value = self._compile_ordered()
169-
ibis_expr = compiled_value._to_ibis_expr(
170-
ordering_mode="unordered", expose_hidden_cols=True
171-
)
172-
tmp_table = self.session._ibis_to_temp_table(
173-
ibis_expr, cluster_cols=cluster_cols, api_name="cached"
174-
)
175-
176-
table_expression = self.session.ibis_client.table(
177-
f"{tmp_table.project}.{tmp_table.dataset_id}.{tmp_table.table_id}"
178-
)
179-
new_columns = [table_expression[column] for column in compiled_value.column_ids]
180-
new_hidden_columns = [
181-
table_expression[column]
182-
for column in compiled_value._hidden_ordering_column_names
183-
]
184-
return ArrayValue.from_ibis(
185-
self.session,
186-
table_expression,
187-
columns=new_columns,
188-
hidden_ordering_columns=new_hidden_columns,
189-
ordering=compiled_value._ordering,
190-
)
108+
def row_count(self) -> ArrayValue:
109+
"""Get number of rows in ArrayValue as a single-entry ArrayValue."""
110+
return ArrayValue(nodes.RowCountNode(child=self.node))
191111

192112
# Operations
193113

bigframes/core/blocks.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,19 @@ def index(self) -> indexes.IndexValue:
137137
@functools.cached_property
138138
def shape(self) -> typing.Tuple[int, int]:
139139
"""Returns dimensions as (length, width) tuple."""
140-
impl_length, _ = self._expr.shape()
141-
return (impl_length, len(self.value_columns))
140+
row_count_expr = self.expr.row_count()
141+
142+
# Support in-memory engines for hermetic unit tests.
143+
if self.expr.node.session is None:
144+
try:
145+
row_count = row_count_expr._try_evaluate_local().squeeze()
146+
return (row_count, len(self.value_columns))
147+
except Exception:
148+
pass
149+
150+
iter, _ = self.session._execute(row_count_expr, sorted=False)
151+
row_count = next(iter)[0]
152+
return (row_count, len(self.value_columns))
142153

143154
@property
144155
def index_columns(self) -> Sequence[str]:
@@ -182,6 +193,10 @@ def index_dtypes(
182193
"""Returns the dtypes of the index columns."""
183194
return [self.expr.get_column_type(col) for col in self.index_columns]
184195

196+
@property
197+
def session(self) -> core.Session:
198+
return self._expr.session
199+
185200
@functools.cached_property
186201
def col_id_to_label(self) -> typing.Mapping[str, Label]:
187202
"""Get column label for value columns, or index name for index columns"""
@@ -376,7 +391,7 @@ def _to_dataframe(self, result) -> pd.DataFrame:
376391
"""Convert BigQuery data to pandas DataFrame with specific dtypes."""
377392
dtypes = dict(zip(self.index_columns, self.index_dtypes))
378393
dtypes.update(zip(self.value_columns, self.dtypes))
379-
return self._expr.session._rows_to_dataframe(result, dtypes)
394+
return self.session._rows_to_dataframe(result, dtypes)
380395

381396
def to_pandas(
382397
self,
@@ -404,9 +419,9 @@ def to_pandas_batches(self):
404419
"""Download results one message at a time."""
405420
dtypes = dict(zip(self.index_columns, self.index_dtypes))
406421
dtypes.update(zip(self.value_columns, self.dtypes))
407-
results_iterator, _ = self._expr.start_query()
422+
results_iterator, _ = self.session._execute(self.expr, sorted=True)
408423
for arrow_table in results_iterator.to_arrow_iterable(
409-
bqstorage_client=self._expr.session.bqstoragereadclient
424+
bqstorage_client=self.session.bqstoragereadclient
410425
):
411426
df = bigframes.session._io.pandas.arrow_to_pandas(arrow_table, dtypes)
412427
self._copy_index_to_pandas(df)
@@ -460,12 +475,12 @@ def _compute_and_count(
460475

461476
expr = self._apply_value_keys_to_expr(value_keys=value_keys)
462477

463-
results_iterator, query_job = expr.start_query(
464-
max_results=max_results, sorted=ordered
478+
results_iterator, query_job = self.session._execute(
479+
expr, max_results=max_results, sorted=ordered
465480
)
466481

467482
table_size = (
468-
expr.session._get_table_size(query_job.destination) / _BYTES_TO_MEGABYTES
483+
self.session._get_table_size(query_job.destination) / _BYTES_TO_MEGABYTES
469484
)
470485
fraction = (
471486
max_download_size / table_size
@@ -607,7 +622,7 @@ def _compute_dry_run(
607622
) -> bigquery.QueryJob:
608623
expr = self._apply_value_keys_to_expr(value_keys=value_keys)
609624
job_config = bigquery.QueryJobConfig(dry_run=True)
610-
_, query_job = expr.start_query(job_config=job_config)
625+
_, query_job = self.session._execute(expr, job_config=job_config, dry_run=True)
611626
return query_job
612627

613628
def _apply_value_keys_to_expr(self, value_keys: Optional[Iterable[str]] = None):
@@ -1668,7 +1683,7 @@ def to_sql_query(
16681683
# the BigQuery unicode column name feature?
16691684
substitutions[old_id] = new_id
16701685

1671-
sql = array_value.to_sql(col_id_overrides=substitutions)
1686+
sql = self.session._to_sql(array_value, col_id_overrides=substitutions)
16721687
return (
16731688
sql,
16741689
new_ids[: len(idx_labels)],
@@ -1678,7 +1693,7 @@ def to_sql_query(
16781693
def cached(self) -> Block:
16791694
"""Write the block to a session table and create a new block object that references it."""
16801695
return Block(
1681-
self.expr.cached(cluster_cols=self.index_columns),
1696+
self.session._execute_and_cache(self.expr, cluster_cols=self.index_columns),
16821697
index_columns=self.index_columns,
16831698
column_labels=self.column_labels,
16841699
index_labels=self.index_labels,

bigframes/core/compile/compiled.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,22 @@ def to_sql(
268268
)
269269
return typing.cast(str, sql)
270270

271+
def row_count(self) -> OrderedIR:
272+
original_table = self._to_ibis_expr()
273+
ibis_table = original_table.agg(
274+
[
275+
original_table.count().name("count"),
276+
]
277+
)
278+
return OrderedIR(
279+
ibis_table,
280+
(ibis_table["count"],),
281+
ordering=ExpressionOrdering(
282+
ordering_value_columns=(OrderingColumnReference("count"),),
283+
total_ordering_columns=frozenset(["count"]),
284+
),
285+
)
286+
271287
def _to_ibis_expr(
272288
self,
273289
*,

bigframes/core/compile/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,12 @@ def compile_concat(node: nodes.ConcatNode, ordered: bool = True):
173173
return concat_impl.concat_unordered(compiled_unordered)
174174

175175

176+
@_compile_node.register
177+
def compile_rowcount(node: nodes.RowCountNode, ordered: bool = True):
178+
result = compile_unordered(node.child).row_count()
179+
return result if ordered else result.to_unordered()
180+
181+
176182
@_compile_node.register
177183
def compile_aggregate(node: nodes.AggregateNode, ordered: bool = True):
178184
result = compile_unordered(node.child).aggregate(

bigframes/core/indexes/index.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,10 @@ def dtypes(
396396
) -> typing.Sequence[typing.Union[bf_dtypes.Dtype, np.dtype[typing.Any]]]:
397397
return self._block.index_dtypes
398398

399+
@property
400+
def session(self) -> core.Session:
401+
return self._expr.session
402+
399403
def __repr__(self) -> str:
400404
"""Converts an Index to a string."""
401405
# TODO(swast): Add a timeout here? If the query is taking a long time,
@@ -411,7 +415,7 @@ def to_pandas(self) -> pandas.Index:
411415
index_columns = list(self._block.index_columns)
412416
dtypes = dict(zip(index_columns, self.dtypes))
413417
expr = self._expr.select_columns(index_columns)
414-
results, _ = expr.start_query()
418+
results, _ = self.session._execute(expr)
415419
df = expr.session._rows_to_dataframe(results, dtypes)
416420
df = df.set_index(index_columns)
417421
index = df.index

bigframes/core/nodes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class ReadGbqNode(BigFrameNode):
139139

140140
@property
141141
def session(self):
142-
return (self.table_session,)
142+
return self.table_session
143143

144144
def __hash__(self):
145145
return self._node_hash
@@ -229,6 +229,12 @@ def __hash__(self):
229229
return self._node_hash
230230

231231

232+
# TODO: Merge RowCount and Corr into Aggregate Node
233+
@dataclass(frozen=True)
234+
class RowCountNode(UnaryNode):
235+
pass
236+
237+
232238
@dataclass(frozen=True)
233239
class AggregateNode(UnaryNode):
234240
aggregations: typing.Tuple[typing.Tuple[str, agg_ops.AggregateOp, str], ...]

bigframes/dataframe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2701,7 +2701,8 @@ def _create_io_query(self, index: bool, ordering_id: Optional[str]) -> str:
27012701

27022702
if ordering_id is not None:
27032703
array_value = array_value.promote_offsets(ordering_id)
2704-
return array_value.to_sql(
2704+
return self._block.session._to_sql(
2705+
array_value=array_value,
27052706
col_id_overrides=id_overrides,
27062707
)
27072708

0 commit comments

Comments
 (0)