Commit e1ebc53
fix: Used query row count metadata instead of table metadata (#1893)
1 parent: e3f5e65

3 files changed: +32, -9 lines changed

bigframes/core/nodes.py

Lines changed: 1 addition & 1 deletion

@@ -161,7 +161,7 @@ def is_noop(self) -> bool:
         return (
             ((not self.start) or (self.start == 0))
             and (self.step == 1)
-            and ((self.stop is None) or (self.stop == self.row_count))
+            and ((self.stop is None) or (self.stop == self.child.row_count))
         )
 
     @property
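
For context on the nodes.py change: the predicate decides whether a [start:stop:step] slice can be dropped entirely. Comparing stop against the slice's own output row count is circular, because an empty slice such as [0:0] produces 0 rows, so stop == row_count holds trivially and the slice is wrongly treated as a no-op; the corrected check compares stop against the child's (input's) row count. Below is a minimal standalone sketch of the corrected logic, using plain integers rather than the BigFrames node classes (all names are illustrative, not library code):

from typing import Optional

def slice_is_noop(start: Optional[int], stop: Optional[int], step: int,
                  input_row_count: Optional[int]) -> bool:
    # A slice keeps the data unchanged only if it starts at the beginning,
    # keeps every row (step == 1), and does not cut the input short: stop is
    # either open-ended or equal to the row count of the *input*.
    return (
        ((not start) or (start == 0))
        and (step == 1)
        and ((stop is None) or (stop == input_row_count))
    )

print(slice_is_noop(0, 0, 1, 100))         # False: [0:0] keeps nothing
print(slice_is_noop(0, 100, 1, 100))       # True: [0:100] over 100 rows keeps everything
print(slice_is_noop(None, None, 1, None))  # True: fully open slice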

bigframes/session/bq_caching_executor.py

Lines changed: 13 additions & 8 deletions

@@ -100,6 +100,7 @@ def cache_results_table(
         original_root: nodes.BigFrameNode,
         table: bigquery.Table,
         ordering: order.RowOrdering,
+        num_rows: Optional[int] = None,
     ):
         # Assumption: GBQ cached table uses field name as bq column name
         scan_list = nodes.ScanList(
@@ -112,7 +113,7 @@ def cache_results_table(
             source=nodes.BigqueryDataSource(
                 nodes.GbqTable.from_table(table),
                 ordering=ordering,
-                n_rows=table.num_rows,
+                n_rows=num_rows,
             ),
             scan_list=scan_list,
             table_session=original_root.session,
@@ -468,14 +469,16 @@ def _cache_with_cluster_cols(
                 plan, sort_rows=False, materialize_all_order_keys=True
             )
         )
-        tmp_table_ref = self._sql_as_cached_temp_table(
+        tmp_table_ref, num_rows = self._sql_as_cached_temp_table(
             compiled.sql,
             compiled.sql_schema,
             cluster_cols=bq_io.select_cluster_cols(compiled.sql_schema, cluster_cols),
         )
         tmp_table = self.bqclient.get_table(tmp_table_ref)
         assert compiled.row_order is not None
-        self.cache.cache_results_table(array_value.node, tmp_table, compiled.row_order)
+        self.cache.cache_results_table(
+            array_value.node, tmp_table, compiled.row_order, num_rows=num_rows
+        )
 
     def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
         """Executes the query and uses the resulting table to rewrite future executions."""
@@ -487,14 +490,16 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
                 sort_rows=False,
             )
         )
-        tmp_table_ref = self._sql_as_cached_temp_table(
+        tmp_table_ref, num_rows = self._sql_as_cached_temp_table(
             compiled.sql,
             compiled.sql_schema,
             cluster_cols=[offset_column],
        )
         tmp_table = self.bqclient.get_table(tmp_table_ref)
         assert compiled.row_order is not None
-        self.cache.cache_results_table(array_value.node, tmp_table, compiled.row_order)
+        self.cache.cache_results_table(
+            array_value.node, tmp_table, compiled.row_order, num_rows=num_rows
+        )
 
     def _cache_with_session_awareness(
         self,
@@ -552,7 +557,7 @@ def _sql_as_cached_temp_table(
         sql: str,
         schema: Sequence[bigquery.SchemaField],
         cluster_cols: Sequence[str],
-    ) -> bigquery.TableReference:
+    ) -> tuple[bigquery.TableReference, Optional[int]]:
         assert len(cluster_cols) <= _MAX_CLUSTER_COLUMNS
         temp_table = self.storage_manager.create_temp_table(schema, cluster_cols)
 
@@ -567,8 +572,8 @@ def _sql_as_cached_temp_table(
             job_config=job_config,
         )
         assert query_job is not None
-        query_job.result()
-        return query_job.destination
+        iter = query_job.result()
+        return query_job.destination, iter.total_rows
 
     def _validate_result_schema(
         self,
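
The executor change hinges on where the cached table's row count comes from. The following is a hedged sketch with the google-cloud-bigquery client, contrasting the two sources the diff swaps between: the RowIterator returned by QueryJob.result() reports total_rows for exactly the rows the job produced, while Table.num_rows is catalog metadata that can lag behind (or be unset) for a table that was only just written, which appears to be why the commit threads total_rows through to cache_results_table(). The query and client setup below are illustrative assumptions, not code from this repository:

from google.cloud import bigquery

client = bigquery.Client()  # assumes default credentials and project

query_job = client.query(
    "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013` LIMIT 10"
)

# Row count from the query result itself: exact for what this job produced.
row_iterator = query_job.result()
rows_from_result = row_iterator.total_rows

# Row count from table metadata of the destination table: may not yet reflect
# the rows the job just wrote.
destination_table = client.get_table(query_job.destination)
rows_from_metadata = destination_table.num_rows

print(rows_from_result, rows_from_metadata)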

tests/system/small/test_dataframe.py

Lines changed: 18 additions & 0 deletions

@@ -3458,6 +3458,24 @@ def test_iloc_slice(scalars_df_index, scalars_pandas_df_index, start, stop, step
     )
 
 
+@pytest.mark.parametrize(
+    ("start", "stop", "step"),
+    [
+        (0, 0, None),
+    ],
+)
+def test_iloc_slice_after_cache(
+    scalars_df_index, scalars_pandas_df_index, start, stop, step
+):
+    scalars_df_index.cache()
+    bf_result = scalars_df_index.iloc[start:stop:step].to_pandas()
+    pd_result = scalars_pandas_df_index.iloc[start:stop:step]
+    pd.testing.assert_frame_equal(
+        bf_result,
+        pd_result,
+    )
+
+
 def test_iloc_slice_zero_step(scalars_df_index):
     with pytest.raises(ValueError):
         scalars_df_index.iloc[0:0:0]
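
The new test mirrors the user-visible scenario: an empty iloc slice taken after cache() must still return an empty frame rather than being dropped as a no-op. A hedged, user-level sketch of that scenario with bigframes.pandas (the public table below is an arbitrary placeholder, not from this commit):

import bigframes.pandas as bpd

df = bpd.read_gbq("bigquery-public-data.usa_names.usa_1910_2013")
df.cache()  # materialize to a temp table; row-count metadata now comes from the query result

empty = df.iloc[0:0]  # stop == 0: the slice must keep no rows
assert len(empty.to_pandas()) == 0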
