@@ -1937,6 +1937,100 @@ def float_parser(row):
1937
1937
)
1938
1938
1939
1939
1940
@pytest.mark.flaky(retries=2, delay=120)
def test_df_apply_axis_1_args(session, scalars_dfs):
    """Row-wise ``DataFrame.apply`` forwarding positional ``args`` to a remote function.

    Verifies that column-count and dtype mismatches are rejected up front, and
    that the happy path matches pandas' built-in ``sum(row, 1)`` baseline.
    """
    cols = ["int64_col", "int64_too"]
    scalars_df, scalars_pandas_df = scalars_dfs

    try:
        # Deploy the row aggregator as a BigFrames remote function.
        @session.remote_function(
            input_types=[int, int, int],
            output_type=int,
            reuse=False,
            cloud_function_service_account="default",
        )
        def the_sum(s1, s2, x):
            return s1 + s2 + x

        extra = (1,)

        # A frame with one column too many must be rejected before execution.
        with pytest.raises(
            ValueError,
            match="^Column count mismatch: BigFrames BigQuery function expected 2 columns from DataFrame but received 3\\.$",
        ):
            scalars_df[cols + ["float64_col"]].apply(the_sum, axis=1, args=extra)

        # A frame whose dtypes disagree with the function signature must also fail.
        with pytest.raises(
            ValueError,
            match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
        ):
            scalars_df[cols].assign(
                int64_col=lambda df: df["int64_col"].astype("Float64")
            ).apply(the_sum, axis=1, args=extra)

        # Happy path: pandas' builtin sum(row, 1) is the reference for s1 + s2 + 1.
        bigframes_result = (
            scalars_df[cols].dropna().apply(the_sum, axis=1, args=extra).to_pandas()
        )
        pandas_result = scalars_pandas_df[cols].dropna().apply(sum, axis=1, args=extra)

        pandas.testing.assert_series_equal(
            pandas_result, bigframes_result, check_dtype=False
        )

    finally:
        # clean up the gcp assets created for the remote function.
        cleanup_function_assets(the_sum, session.bqclient, ignore_failures=False)
1990
@pytest.mark.flaky(retries=2, delay=120)
def test_df_apply_axis_1_series_args(session, scalars_dfs):
    """Row-wise apply of a Series-consuming remote function with extra scalar args.

    Exercises both the freshly deployed function and a handle re-read from
    BigQuery via ``read_gbq_function``, each against a local pandas baseline.
    """
    cols = ["int64_col", "float64_col"]
    scalars_df, scalars_pandas_df = scalars_dfs

    try:

        @session.remote_function(
            input_types=[bigframes.series.Series, float, str, bool],
            output_type=list[str],
            reuse=False,
            cloud_function_service_account="default",
        )
        def foo_list(x, y0: float, y1, y2) -> list[str]:
            # Branch on the boolean arg so both code paths are exercised.
            if y2:
                return [str(x["int64_col"]), str(y0), str(y1), str(y2)]
            return [str(x["float64_col"])]

        first_args = (12.34, "hello world", True)
        bf_result = (
            scalars_df[cols].apply(foo_list, axis=1, args=first_args).to_pandas()
        )
        pd_result = scalars_pandas_df[cols].apply(foo_list, axis=1, args=first_args)

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

        # Re-read the deployed routine from BigQuery and apply with new args.
        second_args = (43.21, "xxx3yyy", False)
        foo_list_ref = session.read_gbq_function(
            foo_list.bigframes_bigquery_function, is_row_processor=True
        )
        bf_result = (
            scalars_df[cols].apply(foo_list_ref, axis=1, args=second_args).to_pandas()
        )
        pd_result = scalars_pandas_df[cols].apply(foo_list, axis=1, args=second_args)

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the remote function.
        cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False)
1940
2034
@pytest .mark .parametrize (
1941
2035
("memory_mib_args" , "expected_memory" ),
1942
2036
[
0 commit comments