@@ -514,6 +514,50 @@ def test_where_dataframe_cond_dataframe_other(
514
514
pandas .testing .assert_frame_equal (bf_result , pd_result )
515
515
516
516
517
+ def test_where_callable_cond_constant_other (scalars_df_index , scalars_pandas_df_index ):
518
+ # Condition is callable, other is a constant.
519
+ columns = ["int64_col" , "float64_col" ]
520
+ dataframe_bf = scalars_df_index [columns ]
521
+ dataframe_pd = scalars_pandas_df_index [columns ]
522
+
523
+ other = 10
524
+
525
+ bf_result = dataframe_bf .where (lambda x : x > 0 , other ).to_pandas ()
526
+ pd_result = dataframe_pd .where (lambda x : x > 0 , other )
527
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
528
+
529
+
530
+ def test_where_dataframe_cond_callable_other (scalars_df_index , scalars_pandas_df_index ):
531
+ # Condition is a dataframe, other is callable.
532
+ columns = ["int64_col" , "float64_col" ]
533
+ dataframe_bf = scalars_df_index [columns ]
534
+ dataframe_pd = scalars_pandas_df_index [columns ]
535
+
536
+ cond_bf = dataframe_bf > 0
537
+ cond_pd = dataframe_pd > 0
538
+
539
+ def func (x ):
540
+ return x * 2
541
+
542
+ bf_result = dataframe_bf .where (cond_bf , func ).to_pandas ()
543
+ pd_result = dataframe_pd .where (cond_pd , func )
544
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
545
+
546
+
547
+ def test_where_callable_cond_callable_other (scalars_df_index , scalars_pandas_df_index ):
548
+ # Condition is callable, other is callable too.
549
+ columns = ["int64_col" , "float64_col" ]
550
+ dataframe_bf = scalars_df_index [columns ]
551
+ dataframe_pd = scalars_pandas_df_index [columns ]
552
+
553
+ def func (x ):
554
+ return x ["int64_col" ] > 0
555
+
556
+ bf_result = dataframe_bf .where (func , lambda x : x * 2 ).to_pandas ()
557
+ pd_result = dataframe_pd .where (func , lambda x : x * 2 )
558
+ pandas .testing .assert_frame_equal (bf_result , pd_result )
559
+
560
+
517
561
def test_drop_column (scalars_dfs ):
518
562
scalars_df , scalars_pandas_df = scalars_dfs
519
563
col_name = "int64_col"
0 commit comments