21
21
22
22
from __future__ import annotations
23
23
24
+ import dataclasses
24
25
import functools
25
26
import itertools
26
27
import random
31
32
import google .cloud .bigquery as bigquery
32
33
import pandas as pd
33
34
35
+ import bigframes ._config .sampling_options as sampling_options
34
36
import bigframes .constants as constants
35
37
import bigframes .core as core
36
38
import bigframes .core .guid as guid
@@ -80,6 +82,14 @@ def _get_block(self) -> Block:
80
82
"""Get the underlying block value of the object"""
81
83
82
84
85
+ @dataclasses .dataclass ()
86
+ class MaterializationOptions :
87
+ downsampling : sampling_options .SamplingOptions = dataclasses .field (
88
+ default_factory = sampling_options .SamplingOptions
89
+ )
90
+ ordered : bool = True
91
+
92
+
83
93
class Block :
84
94
"""A immutable 2D data structure."""
85
95
@@ -395,23 +405,31 @@ def _to_dataframe(self, result) -> pd.DataFrame:
395
405
396
406
def to_pandas (
397
407
self ,
398
- value_keys : Optional [Iterable [str ]] = None ,
399
- max_results : Optional [int ] = None ,
400
408
max_download_size : Optional [int ] = None ,
401
409
sampling_method : Optional [str ] = None ,
402
410
random_state : Optional [int ] = None ,
403
411
* ,
404
412
ordered : bool = True ,
405
413
) -> Tuple [pd .DataFrame , bigquery .QueryJob ]:
406
414
"""Run query and download results as a pandas DataFrame."""
415
+ if (sampling_method is not None ) and (sampling_method not in _SAMPLING_METHODS ):
416
+ raise NotImplementedError (
417
+ f"The downsampling method { sampling_method } is not implemented, "
418
+ f"please choose from { ',' .join (_SAMPLING_METHODS )} ."
419
+ )
407
420
408
- df , _ , query_job = self ._compute_and_count (
409
- value_keys = value_keys ,
410
- max_results = max_results ,
411
- max_download_size = max_download_size ,
412
- sampling_method = sampling_method ,
413
- random_state = random_state ,
414
- ordered = ordered ,
421
+ sampling = bigframes .options .sampling .with_max_download_size (max_download_size )
422
+ if sampling_method is not None :
423
+ sampling = sampling .with_method (sampling_method ).with_random_state ( # type: ignore
424
+ random_state
425
+ )
426
+ else :
427
+ sampling = sampling .with_disabled ()
428
+
429
+ df , query_job = self ._materialize_local (
430
+ materialize_options = MaterializationOptions (
431
+ downsampling = sampling , ordered = ordered
432
+ )
415
433
)
416
434
return df , query_job
417
435
@@ -439,57 +457,29 @@ def _copy_index_to_pandas(self, df: pd.DataFrame):
439
457
# See: https://github.com/pandas-dev/pandas-stubs/issues/804
440
458
df .index .names = self .index .names # type: ignore
441
459
442
- def _compute_and_count (
443
- self ,
444
- value_keys : Optional [Iterable [str ]] = None ,
445
- max_results : Optional [int ] = None ,
446
- max_download_size : Optional [int ] = None ,
447
- sampling_method : Optional [str ] = None ,
448
- random_state : Optional [int ] = None ,
449
- * ,
450
- ordered : bool = True ,
451
- ) -> Tuple [pd .DataFrame , int , bigquery .QueryJob ]:
460
+ def _materialize_local (
461
+ self , materialize_options : MaterializationOptions = MaterializationOptions ()
462
+ ) -> Tuple [pd .DataFrame , bigquery .QueryJob ]:
452
463
"""Run query and download results as a pandas DataFrame. Return the total number of results as well."""
453
464
# TODO(swast): Allow for dry run and timeout.
454
- enable_downsampling = (
455
- True
456
- if sampling_method is not None
457
- else bigframes .options .sampling .enable_downsampling
458
- )
459
-
460
- max_download_size = (
461
- max_download_size or bigframes .options .sampling .max_download_size
462
- )
463
-
464
- random_state = random_state or bigframes .options .sampling .random_state
465
-
466
- if sampling_method is None :
467
- sampling_method = bigframes .options .sampling .sampling_method or _UNIFORM
468
- sampling_method = sampling_method .lower ()
469
-
470
- if sampling_method not in _SAMPLING_METHODS :
471
- raise NotImplementedError (
472
- f"The downsampling method { sampling_method } is not implemented, "
473
- f"please choose from { ',' .join (_SAMPLING_METHODS )} ."
474
- )
475
-
476
- expr = self ._apply_value_keys_to_expr (value_keys = value_keys )
477
-
478
465
results_iterator , query_job = self .session ._execute (
479
- expr , max_results = max_results , sorted = ordered
466
+ self . expr , sorted = materialize_options . ordered
480
467
)
481
-
482
468
table_size = (
483
469
self .session ._get_table_size (query_job .destination ) / _BYTES_TO_MEGABYTES
484
470
)
471
+ sample_config = materialize_options .downsampling
472
+ max_download_size = sample_config .max_download_size
485
473
fraction = (
486
474
max_download_size / table_size
487
475
if (max_download_size is not None ) and (table_size != 0 )
488
476
else 2
489
477
)
490
478
479
+ # TODO: Maybe materialize before downsampling
480
+ # Some downsampling methods
491
481
if fraction < 1 :
492
- if not enable_downsampling :
482
+ if not sample_config . enable_downsampling :
493
483
raise RuntimeError (
494
484
f"The data size ({ table_size :.2f} MB) exceeds the maximum download limit of "
495
485
f"{ max_download_size } MB. You can:\n \t * Enable downsampling in global options:\n "
@@ -507,42 +497,53 @@ def _compute_and_count(
507
497
"\n Please refer to the documentation for configuring the downloading limit." ,
508
498
UserWarning ,
509
499
)
510
- if sampling_method == _HEAD :
511
- total_rows = int (results_iterator .total_rows * fraction )
512
- results_iterator .max_results = total_rows
513
- df = self ._to_dataframe (results_iterator )
514
-
515
- if self .index_columns :
516
- df .set_index (list (self .index_columns ), inplace = True )
517
- df .index .names = self .index .names # type: ignore
518
- elif (sampling_method == _UNIFORM ) and (random_state is None ):
519
- filtered_expr = self .expr ._uniform_sampling (fraction )
520
- block = Block (
521
- filtered_expr ,
522
- index_columns = self .index_columns ,
523
- column_labels = self .column_labels ,
524
- index_labels = self .index .names ,
525
- )
526
- df , total_rows , _ = block ._compute_and_count (max_download_size = None )
527
- elif sampling_method == _UNIFORM :
528
- block = self ._split (
529
- fracs = (max_download_size / table_size ,),
530
- random_state = random_state ,
531
- preserve_order = True ,
532
- )[0 ]
533
- df , total_rows , _ = block ._compute_and_count (max_download_size = None )
534
- else :
535
- # This part should never be called, just in case.
536
- raise NotImplementedError (
537
- f"The downsampling method { sampling_method } is not implemented, "
538
- f"please choose from { ',' .join (_SAMPLING_METHODS )} ."
539
- )
500
+ total_rows = results_iterator .total_rows
501
+ # Remove downsampling config from subsequent invocations, as otherwise could result in many
502
+ # iterations if downsampling undershoots
503
+ return self ._downsample (
504
+ total_rows = total_rows ,
505
+ sampling_method = sample_config .sampling_method ,
506
+ fraction = fraction ,
507
+ random_state = sample_config .random_state ,
508
+ )._materialize_local (
509
+ MaterializationOptions (ordered = materialize_options .ordered )
510
+ )
540
511
else :
541
512
total_rows = results_iterator .total_rows
542
513
df = self ._to_dataframe (results_iterator )
543
514
self ._copy_index_to_pandas (df )
544
515
545
- return df , total_rows , query_job
516
+ return df , query_job
517
+
518
+ def _downsample (
519
+ self , total_rows : int , sampling_method : str , fraction : float , random_state
520
+ ) -> Block :
521
+ # either selecting fraction or number of rows
522
+ if sampling_method == _HEAD :
523
+ filtered_block = self .slice (stop = int (total_rows * fraction ))
524
+ return filtered_block
525
+ elif (sampling_method == _UNIFORM ) and (random_state is None ):
526
+ filtered_expr = self .expr ._uniform_sampling (fraction )
527
+ block = Block (
528
+ filtered_expr ,
529
+ index_columns = self .index_columns ,
530
+ column_labels = self .column_labels ,
531
+ index_labels = self .index .names ,
532
+ )
533
+ return block
534
+ elif sampling_method == _UNIFORM :
535
+ block = self ._split (
536
+ fracs = (fraction ,),
537
+ random_state = random_state ,
538
+ preserve_order = True ,
539
+ )[0 ]
540
+ return block
541
+ else :
542
+ # This part should never be called, just in case.
543
+ raise NotImplementedError (
544
+ f"The downsampling method { sampling_method } is not implemented, "
545
+ f"please choose from { ',' .join (_SAMPLING_METHODS )} ."
546
+ )
546
547
547
548
def _split (
548
549
self ,
@@ -1209,10 +1210,9 @@ def retrieve_repr_request_results(
1209
1210
count = self .shape [0 ]
1210
1211
if count > max_results :
1211
1212
head_block = self .slice (0 , max_results )
1212
- computed_df , query_job = head_block .to_pandas (max_results = max_results )
1213
1213
else :
1214
1214
head_block = self
1215
- computed_df , query_job = head_block .to_pandas ()
1215
+ computed_df , query_job = head_block .to_pandas ()
1216
1216
formatted_df = computed_df .set_axis (self .column_labels , axis = 1 )
1217
1217
# we reset the axis and substitute the bf index name for the default
1218
1218
formatted_df .index .name = self .index .name
0 commit comments