@@ -387,6 +387,31 @@ def test_merge_series(scalars_dfs, merge_how):
387387 assert_pandas_df_equal (bf_result , pd_result , ignore_order = True )
388388
389389
390+ def _convert_pandas_category (pd_s : pd .Series ):
391+ if not isinstance (pd_s .dtype , pd .CategoricalDtype ):
392+ raise ValueError ("Input must be a pandas Series with categorical data." )
393+
394+ if len (pd_s .dtype .categories ) == 0 :
395+ return pd .Series ([pd .NA ] * len (pd_s ), name = pd_s .name )
396+
397+ pd_interval : pd .IntervalIndex = pd_s .cat .categories [pd_s .cat .codes ] # type: ignore
398+ if pd_interval .closed == "left" :
399+ left_key = "left_inclusive"
400+ right_key = "right_exclusive"
401+ else :
402+ left_key = "left_exclusive"
403+ right_key = "right_inclusive"
404+ return pd .Series (
405+ [
406+ {left_key : interval .left , right_key : interval .right }
407+ if pd .notna (val )
408+ else pd .NA
409+ for val , interval in zip (pd_s , pd_interval )
410+ ],
411+ name = pd_s .name ,
412+ )
413+
414+
390415@pytest .mark .parametrize (
391416 ("right" ),
392417 [
@@ -420,23 +445,7 @@ def test_cut_default_labels(scalars_dfs, right):
420445 bf_result = bpd .cut (scalars_df ["float64_col" ], 5 , right = right ).to_pandas ()
421446
422447 # Convert to match data format
423- pd_interval = pd_result .cat .categories [pd_result .cat .codes ]
424- if pd_interval .closed == "left" :
425- left_key = "left_inclusive"
426- right_key = "right_exclusive"
427- else :
428- left_key = "left_exclusive"
429- right_key = "right_inclusive"
430- pd_result_converted = pd .Series (
431- [
432- {left_key : interval .left , right_key : interval .right }
433- if pd .notna (val )
434- else pd .NA
435- for val , interval in zip (pd_result , pd_interval )
436- ],
437- name = pd_result .name ,
438- )
439-
448+ pd_result_converted = _convert_pandas_category (pd_result )
440449 pd .testing .assert_series_equal (
441450 bf_result , pd_result_converted , check_index = False , check_dtype = False
442451 )
@@ -458,47 +467,36 @@ def test_cut_numeric_breaks(scalars_dfs, breaks, right):
458467 bf_result = bpd .cut (scalars_df ["float64_col" ], breaks , right = right ).to_pandas ()
459468
460469 # Convert to match data format
461- pd_interval = pd_result .cat .categories [pd_result .cat .codes ]
462- if pd_interval .closed == "left" :
463- left_key = "left_inclusive"
464- right_key = "right_exclusive"
465- else :
466- left_key = "left_exclusive"
467- right_key = "right_inclusive"
468-
469- pd_result_converted = pd .Series (
470- [
471- {left_key : interval .left , right_key : interval .right }
472- if pd .notna (val )
473- else pd .NA
474- for val , interval in zip (pd_result , pd_interval )
475- ],
476- name = pd_result .name ,
477- )
470+ pd_result_converted = _convert_pandas_category (pd_result )
478471
479472 pd .testing .assert_series_equal (
480473 bf_result , pd_result_converted , check_index = False , check_dtype = False
481474 )
482475
483476
@pytest.mark.parametrize(
    "bins",
    [
        pytest.param([], id="empty_list"),
        pytest.param(
            [1],
            id="single_int_list",
            marks=pytest.mark.skip(reason="b/404338651"),
        ),
        pytest.param(pd.IntervalIndex.from_tuples([]), id="empty_interval_index"),
    ],
)
def test_cut_w_edge_cases(scalars_dfs, bins):
    """Compare bpd.cut with pd.cut for degenerate ``bins`` arguments."""
    bf_df, pandas_df = scalars_dfs

    # Run the BigFrames side first, with the bins exactly as parametrized.
    actual = bpd.cut(bf_df["int64_too"], bins, labels=False).to_pandas()

    # For the pandas side, list bins are converted to an IntervalIndex first.
    pandas_bins = pd.IntervalIndex.from_tuples(bins) if isinstance(bins, list) else bins
    pandas_cut = pd.cut(pandas_df["int64_too"], pandas_bins, labels=False)

    # Convert to match data format
    expected = _convert_pandas_category(pandas_cut)

    pd.testing.assert_series_equal(
        actual, expected, check_index=False, check_dtype=False
    )
502500
503501
504502@pytest .mark .parametrize (
@@ -529,23 +527,7 @@ def test_cut_with_interval(scalars_dfs, bins, right):
529527 pd_result = pd .cut (scalars_pandas_df ["int64_too" ], bins , labels = False , right = right )
530528
531529 # Convert to match data format
532- pd_interval = pd_result .cat .categories [pd_result .cat .codes ]
533- if pd_interval .closed == "left" :
534- left_key = "left_inclusive"
535- right_key = "right_exclusive"
536- else :
537- left_key = "left_exclusive"
538- right_key = "right_inclusive"
539-
540- pd_result_converted = pd .Series (
541- [
542- {left_key : interval .left , right_key : interval .right }
543- if pd .notna (val )
544- else pd .NA
545- for val , interval in zip (pd_result , pd_interval )
546- ],
547- name = pd_result .name ,
548- )
530+ pd_result_converted = _convert_pandas_category (pd_result )
549531
550532 pd .testing .assert_series_equal (
551533 bf_result , pd_result_converted , check_index = False , check_dtype = False
0 commit comments