2929from ....tensor import arange , tensor
3030from ....tensor .random import rand
3131from ....tests .core import require_cudf
32- from ....utils import lazy_import
32+ from ....utils import lazy_import , pd_release_version
3333from ... import eval as mars_eval , cut , qcut , get_dummies
3434from ...datasource .dataframe import from_pandas as from_pandas_df
3535from ...datasource .series import from_pandas as from_pandas_series
3838from ..to_numeric import to_numeric
3939from ..rebalance import DataFrameRebalance
4040
41+ pytestmark = pytest .mark .pd_compat
42+
4143cudf = lazy_import ("cudf" , globals = globals ())
4244
45+ _explode_with_ignore_index = pd_release_version [:2 ] >= (1 , 1 )
46+
4347
4448@require_cudf
4549def test_to_gpu_execution (setup_gpu ):
@@ -1968,7 +1972,12 @@ def test_stack_execution(setup):
19681972 assert_method (result , expected )
19691973
19701974
1971- def test_explode_execution (setup ):
1975+ @pytest .mark .parametrize (
1976+ "ignore_index" , [False , True ] if _explode_with_ignore_index else [False ]
1977+ )
1978+ def test_explode_execution (setup , ignore_index ):
1979+ explode_kw = {"ignore_index" : True } if ignore_index else {}
1980+
19721981 raw = pd .DataFrame (
19731982 {
19741983 "a" : np .random .rand (10 ),
@@ -1978,20 +1987,12 @@ def test_explode_execution(setup):
19781987 }
19791988 )
19801989 df = from_pandas_df (raw , chunk_size = (4 , 2 ))
1981-
1982- for ignore_index in [False , True ]:
1983- r = df .explode ("b" , ignore_index = ignore_index )
1984- pd .testing .assert_frame_equal (
1985- r .execute ().fetch (), raw .explode ("b" , ignore_index = ignore_index )
1986- )
1990+ r = df .explode ("b" , ignore_index = ignore_index )
1991+ pd .testing .assert_frame_equal (r .execute ().fetch (), raw .explode ("b" , ** explode_kw ))
19871992
19881993 series = from_pandas_series (raw .b , chunk_size = 4 )
1989-
1990- for ignore_index in [False , True ]:
1991- r = series .explode (ignore_index = ignore_index )
1992- pd .testing .assert_series_equal (
1993- r .execute ().fetch (), raw .b .explode (ignore_index = ignore_index )
1994- )
1994+ r = series .explode (ignore_index = ignore_index )
1995+ pd .testing .assert_series_equal (r .execute ().fetch (), raw .b .explode (** explode_kw ))
19951996
19961997
19971998def test_eval_query_execution (setup ):
0 commit comments