2121
2222from __future__ import annotations
2323
24+ import dataclasses
2425import functools
2526import itertools
2627import random
3132import google .cloud .bigquery as bigquery
3233import pandas as pd
3334
35+ import bigframes ._config .sampling_options as sampling_options
3436import bigframes .constants as constants
3537import bigframes .core as core
3638import bigframes .core .guid as guid
@@ -80,6 +82,14 @@ def _get_block(self) -> Block:
8082 """Get the underlying block value of the object"""
8183
8284
@dataclasses.dataclass
class MaterializationOptions:
    """Options bundle consumed when a block is materialized locally."""

    # Downsampling configuration; each instance gets its own freshly
    # constructed SamplingOptions when the caller supplies none.
    downsampling: sampling_options.SamplingOptions = dataclasses.field(
        default_factory=sampling_options.SamplingOptions
    )
    # Whether the materialized result should be ordered (defaults to True).
    ordered: bool = True
91+
92+
8393class Block :
8494 """An immutable 2D data structure."""
8595
@@ -395,23 +405,31 @@ def _to_dataframe(self, result) -> pd.DataFrame:
395405
def to_pandas(
    self,
    max_download_size: Optional[int] = None,
    sampling_method: Optional[str] = None,
    random_state: Optional[int] = None,
    *,
    ordered: bool = True,
) -> Tuple[pd.DataFrame, bigquery.QueryJob]:
    """Execute the query and fetch the result as a pandas DataFrame.

    Args:
        max_download_size: Download-size limit applied on top of the global
            sampling options.
        sampling_method: Name of the downsampling method. When ``None``,
            downsampling is disabled for this call.
        random_state: Seed forwarded to the sampling options; only applied
            when ``sampling_method`` is given.
        ordered: Forwarded to the materialization options.

    Returns:
        The downloaded DataFrame together with the query job that produced it.

    Raises:
        NotImplementedError: If ``sampling_method`` is given but is not one of
            ``_SAMPLING_METHODS``.
    """
    method_requested = sampling_method is not None
    if method_requested and sampling_method not in _SAMPLING_METHODS:
        raise NotImplementedError(
            f"The downsampling method { sampling_method } is not implemented, "
            f"please choose from { ',' .join (_SAMPLING_METHODS )} ."
        )

    # Start from the global sampling options, overriding only the size limit.
    base_sampling = bigframes.options.sampling.with_max_download_size(
        max_download_size
    )
    if method_requested:
        method_sampling = base_sampling.with_method(sampling_method)  # type: ignore
        effective_sampling = method_sampling.with_random_state(random_state)  # type: ignore
    else:
        # No explicit method means no downsampling for this call.
        effective_sampling = base_sampling.with_disabled()

    options = MaterializationOptions(downsampling=effective_sampling, ordered=ordered)
    frame, job = self._materialize_local(materialize_options=options)
    return frame, job
417435
@@ -439,57 +457,29 @@ def _copy_index_to_pandas(self, df: pd.DataFrame):
439457 # See: https://github.com/pandas-dev/pandas-stubs/issues/804
440458 df .index .names = self .index .names # type: ignore
441459
442- def _compute_and_count (
443- self ,
444- value_keys : Optional [Iterable [str ]] = None ,
445- max_results : Optional [int ] = None ,
446- max_download_size : Optional [int ] = None ,
447- sampling_method : Optional [str ] = None ,
448- random_state : Optional [int ] = None ,
449- * ,
450- ordered : bool = True ,
451- ) -> Tuple [pd .DataFrame , int , bigquery .QueryJob ]:
460+ def _materialize_local (
461+ self , materialize_options : MaterializationOptions = MaterializationOptions ()
462+ ) -> Tuple [pd .DataFrame , bigquery .QueryJob ]:
452463 """Run query and download results as a pandas DataFrame."""
453464 # TODO(swast): Allow for dry run and timeout.
454- enable_downsampling = (
455- True
456- if sampling_method is not None
457- else bigframes .options .sampling .enable_downsampling
458- )
459-
460- max_download_size = (
461- max_download_size or bigframes .options .sampling .max_download_size
462- )
463-
464- random_state = random_state or bigframes .options .sampling .random_state
465-
466- if sampling_method is None :
467- sampling_method = bigframes .options .sampling .sampling_method or _UNIFORM
468- sampling_method = sampling_method .lower ()
469-
470- if sampling_method not in _SAMPLING_METHODS :
471- raise NotImplementedError (
472- f"The downsampling method { sampling_method } is not implemented, "
473- f"please choose from { ',' .join (_SAMPLING_METHODS )} ."
474- )
475-
476- expr = self ._apply_value_keys_to_expr (value_keys = value_keys )
477-
478465 results_iterator , query_job = self .session ._execute (
479- expr , max_results = max_results , sorted = ordered
466+ self . expr , sorted = materialize_options . ordered
480467 )
481-
482468 table_size = (
483469 self .session ._get_table_size (query_job .destination ) / _BYTES_TO_MEGABYTES
484470 )
471+ sample_config = materialize_options .downsampling
472+ max_download_size = sample_config .max_download_size
485473 fraction = (
486474 max_download_size / table_size
487475 if (max_download_size is not None ) and (table_size != 0 )
488476 else 2
489477 )
490478
479+ # TODO: Maybe materialize before downsampling
480+ # Some downsampling methods
491481 if fraction < 1 :
492- if not enable_downsampling :
482+ if not sample_config . enable_downsampling :
493483 raise RuntimeError (
494484 f"The data size ({ table_size :.2f} MB) exceeds the maximum download limit of "
495485 f"{ max_download_size } MB. You can:\n \t * Enable downsampling in global options:\n "
@@ -507,42 +497,53 @@ def _compute_and_count(
507497 "\n Please refer to the documentation for configuring the downloading limit." ,
508498 UserWarning ,
509499 )
510- if sampling_method == _HEAD :
511- total_rows = int (results_iterator .total_rows * fraction )
512- results_iterator .max_results = total_rows
513- df = self ._to_dataframe (results_iterator )
514-
515- if self .index_columns :
516- df .set_index (list (self .index_columns ), inplace = True )
517- df .index .names = self .index .names # type: ignore
518- elif (sampling_method == _UNIFORM ) and (random_state is None ):
519- filtered_expr = self .expr ._uniform_sampling (fraction )
520- block = Block (
521- filtered_expr ,
522- index_columns = self .index_columns ,
523- column_labels = self .column_labels ,
524- index_labels = self .index .names ,
525- )
526- df , total_rows , _ = block ._compute_and_count (max_download_size = None )
527- elif sampling_method == _UNIFORM :
528- block = self ._split (
529- fracs = (max_download_size / table_size ,),
530- random_state = random_state ,
531- preserve_order = True ,
532- )[0 ]
533- df , total_rows , _ = block ._compute_and_count (max_download_size = None )
534- else :
535- # This part should never be called, just in case.
536- raise NotImplementedError (
537- f"The downsampling method { sampling_method } is not implemented, "
538- f"please choose from { ',' .join (_SAMPLING_METHODS )} ."
539- )
500+ total_rows = results_iterator .total_rows
501+ # Remove downsampling config from subsequent invocations, as otherwise could result in many
502+ # iterations if downsampling undershoots
503+ return self ._downsample (
504+ total_rows = total_rows ,
505+ sampling_method = sample_config .sampling_method ,
506+ fraction = fraction ,
507+ random_state = sample_config .random_state ,
508+ )._materialize_local (
509+ MaterializationOptions (ordered = materialize_options .ordered )
510+ )
540511 else :
541512 total_rows = results_iterator .total_rows
542513 df = self ._to_dataframe (results_iterator )
543514 self ._copy_index_to_pandas (df )
544515
545- return df , total_rows , query_job
516+ return df , query_job
517+
def _downsample(
    self,
    total_rows: int,
    sampling_method: Optional[str],
    fraction: float,
    random_state: Optional[int],
) -> Block:
    """Return a new block holding roughly ``fraction`` of this block's rows.

    Args:
        total_rows: Row count of the full result; used by the head method to
            turn ``fraction`` into an absolute row limit.
        sampling_method: Name of the downsampling method. ``None`` (or an
            empty string) falls back to uniform sampling, and matching is
            case-insensitive.
        fraction: Share of the rows to keep.
        random_state: Optional seed. With uniform sampling, a seed selects the
            ``_split``-based path; without one, the expression-level
            ``_uniform_sampling`` is used.

    Returns:
        The downsampled block.

    Raises:
        NotImplementedError: If the method is not a supported sampling method.
    """
    # An unset method must not reach the "not implemented" branch: default it
    # to uniform sampling and normalize the case before comparing.
    method = (sampling_method or _UNIFORM).lower()

    if method == _HEAD:
        # Keep only the leading rows, scaled down by the requested fraction.
        return self.slice(stop=int(total_rows * fraction))

    if method == _UNIFORM:
        if random_state is None:
            sampled_expr = self.expr._uniform_sampling(fraction)
            return Block(
                sampled_expr,
                index_columns=self.index_columns,
                column_labels=self.column_labels,
                index_labels=self.index.names,
            )
        # Seeded sampling goes through _split, which accepts the seed.
        return self._split(
            fracs=(fraction,),
            random_state=random_state,
            preserve_order=True,
        )[0]

    # This part should never be called, just in case.
    raise NotImplementedError(
        f"The downsampling method { sampling_method } is not implemented, "
        f"please choose from { ',' .join (_SAMPLING_METHODS )} ."
    )
546547
547548 def _split (
548549 self ,
@@ -1209,10 +1210,9 @@ def retrieve_repr_request_results(
12091210 count = self .shape [0 ]
12101211 if count > max_results :
12111212 head_block = self .slice (0 , max_results )
1212- computed_df , query_job = head_block .to_pandas (max_results = max_results )
12131213 else :
12141214 head_block = self
1215- computed_df , query_job = head_block .to_pandas ()
1215+ computed_df , query_job = head_block .to_pandas ()
12161216 formatted_df = computed_df .set_axis (self .column_labels , axis = 1 )
12171217 # we reset the axis and substitute the bf index name for the default
12181218 formatted_df .index .name = self .index .name
0 commit comments