@@ -100,6 +100,7 @@ def cache_results_table(
         original_root: nodes.BigFrameNode,
         table: bigquery.Table,
         ordering: order.RowOrdering,
+        num_rows: Optional[int] = None,
     ):
         # Assumption: GBQ cached table uses field name as bq column name
         scan_list = nodes.ScanList(
@@ -112,7 +113,7 @@ def cache_results_table(
             source=nodes.BigqueryDataSource(
                 nodes.GbqTable.from_table(table),
                 ordering=ordering,
-                n_rows=table.num_rows,
+                n_rows=num_rows,
             ),
             scan_list=scan_list,
             table_session=original_root.session,
@@ -468,14 +469,16 @@ def _cache_with_cluster_cols(
                 plan, sort_rows=False, materialize_all_order_keys=True
             )
         )
-        tmp_table_ref = self._sql_as_cached_temp_table(
+        tmp_table_ref, num_rows = self._sql_as_cached_temp_table(
             compiled.sql,
             compiled.sql_schema,
             cluster_cols=bq_io.select_cluster_cols(compiled.sql_schema, cluster_cols),
         )
         tmp_table = self.bqclient.get_table(tmp_table_ref)
         assert compiled.row_order is not None
-        self.cache.cache_results_table(array_value.node, tmp_table, compiled.row_order)
+        self.cache.cache_results_table(
+            array_value.node, tmp_table, compiled.row_order, num_rows=num_rows
+        )

     def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
         """Executes the query and uses the resulting table to rewrite future executions."""
@@ -487,14 +490,16 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
                 sort_rows=False,
             )
         )
-        tmp_table_ref = self._sql_as_cached_temp_table(
+        tmp_table_ref, num_rows = self._sql_as_cached_temp_table(
             compiled.sql,
             compiled.sql_schema,
             cluster_cols=[offset_column],
         )
         tmp_table = self.bqclient.get_table(tmp_table_ref)
         assert compiled.row_order is not None
-        self.cache.cache_results_table(array_value.node, tmp_table, compiled.row_order)
+        self.cache.cache_results_table(
+            array_value.node, tmp_table, compiled.row_order, num_rows=num_rows
+        )

     def _cache_with_session_awareness(
         self,
@@ -552,7 +557,7 @@ def _sql_as_cached_temp_table(
         sql: str,
         schema: Sequence[bigquery.SchemaField],
         cluster_cols: Sequence[str],
-    ) -> bigquery.TableReference:
+    ) -> tuple[bigquery.TableReference, Optional[int]]:
         assert len(cluster_cols) <= _MAX_CLUSTER_COLUMNS
         temp_table = self.storage_manager.create_temp_table(schema, cluster_cols)

@@ -567,8 +572,8 @@ def _sql_as_cached_temp_table(
             job_config=job_config,
         )
         assert query_job is not None
-        query_job.result()
-        return query_job.destination
+        iter = query_job.result()
+        return query_job.destination, iter.total_rows

     def _validate_result_schema(
         self,
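The change threads the row count from the query result iterator through to the cache instead of reading it back from the temp table's metadata: `QueryJob.result()` returns a `RowIterator` whose `total_rows` reflects the rows the query actually produced. A minimal sketch of that pattern using the public google-cloud-bigquery client, outside the bigframes codebase (the destination table name and query string are placeholders, not part of this change):

```python
from google.cloud import bigquery

client = bigquery.Client()  # assumes application-default credentials

# Write query results into an explicit destination table, roughly how the
# executor materializes its cached temp tables; the table ID is a placeholder.
job_config = bigquery.QueryJobConfig(
    destination="my-project.my_dataset.tmp_cache",
    write_disposition="WRITE_TRUNCATE",
)
query_job = client.query("SELECT 1 AS x UNION ALL SELECT 2", job_config=job_config)

row_iterator = query_job.result()  # blocks until the job finishes
exact_rows = row_iterator.total_rows  # what the new code passes along as num_rows

# The replaced code instead read Table.num_rows from the fetched table metadata.
table = client.get_table(query_job.destination)
metadata_rows = table.num_rows
```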