diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index bc30c7740e391..2397bc9c79622 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -565,6 +565,7 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state", "pyspark.sql.tests.pandas.test_pandas_map", "pyspark.sql.tests.pandas.test_pandas_transform_with_state", + "pyspark.sql.tests.pandas.test_pandas_transform_with_state_checkpoint_v2", "pyspark.sql.tests.pandas.test_pandas_udf", "pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg", "pyspark.sql.tests.pandas.test_pandas_udf_scalar", @@ -1125,6 +1126,7 @@ def __hash__(self): "pyspark.sql.tests.connect.streaming.test_parity_listener", "pyspark.sql.tests.connect.streaming.test_parity_foreach", "pyspark.sql.tests.connect.streaming.test_parity_foreach_batch", + "pyspark.sql.tests.connect.streaming.test_parity_transform_with_state_pyspark", "pyspark.sql.tests.connect.test_resources", "pyspark.sql.tests.connect.shell.test_progress", "pyspark.sql.tests.connect.test_df_debug", diff --git a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py index e772c2139326f..334031ec362f1 100644 --- a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py @@ -18,7 +18,6 @@ from pyspark.sql.tests.pandas.test_pandas_transform_with_state import ( TransformWithStateInPandasTestsMixin, - TransformWithStateInPySparkTestsMixin, ) from pyspark import SparkConf from pyspark.testing.connectutils import ReusedConnectTestCase @@ -54,36 +53,6 @@ def test_schema_evolution_scenarios(self): pass -class TransformWithStateInPySparkParityTests( - TransformWithStateInPySparkTestsMixin, ReusedConnectTestCase -): - """ - Spark connect parity tests for TransformWithStateInPySpark. Run every test case in - `TransformWithStateInPySparkTestsMixin` in spark connect mode. - """ - - @classmethod - def conf(cls): - # Due to multiple inheritance from the same level, we need to explicitly setting configs in - # both TransformWithStateInPySparkTestsMixin and ReusedConnectTestCase here - cfg = SparkConf(loadDefaults=False) - for base in cls.__bases__: - if hasattr(base, "conf"): - parent_cfg = base.conf() - for k, v in parent_cfg.getAll(): - cfg.set(k, v) - - # Extra removing config for connect suites - if cfg._jconf is not None: - cfg._jconf.remove("spark.master") - - return cfg - - @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.") - def test_schema_evolution_scenarios(self): - pass - - if __name__ == "__main__": from pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state import * # noqa: F401,E501 diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_transform_with_state_pyspark.py b/python/pyspark/sql/tests/connect/streaming/test_parity_transform_with_state_pyspark.py new file mode 100644 index 0000000000000..7209b94a50608 --- /dev/null +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_transform_with_state_pyspark.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from pyspark.sql.tests.pandas.test_pandas_transform_with_state import ( + TransformWithStateInPySparkTestsMixin, +) +from pyspark import SparkConf +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class TransformWithStateInPySparkParityTests( + TransformWithStateInPySparkTestsMixin, ReusedConnectTestCase +): + """ + Spark connect parity tests for TransformWithStateInPySpark. Run every test case in + `TransformWithStateInPySparkTestsMixin` in spark connect mode. + """ + + @classmethod + def conf(cls): + # Due to multiple inheritance from the same level, we need to explicitly setting configs in + # both TransformWithStateInPySparkTestsMixin and ReusedConnectTestCase here + cfg = SparkConf(loadDefaults=False) + for base in cls.__bases__: + if hasattr(base, "conf"): + parent_cfg = base.conf() + for k, v in parent_cfg.getAll(): + cfg.set(k, v) + + # Extra removing config for connect suites + if cfg._jconf is not None: + cfg._jconf.remove("spark.master") + + return cfg + + @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.") + def test_schema_evolution_scenarios(self): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.streaming.test_parity_transform_with_state_pyspark import * # noqa: F401,E501 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index af44093c512df..10755f13344f3 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -1916,22 +1916,6 @@ def conf(cls): return cfg -class TransformWithStateInPandasWithCheckpointV2TestsMixin(TransformWithStateInPandasTestsMixin): - @classmethod - def conf(cls): - cfg = super().conf() - cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") - return cfg - - -class TransformWithStateInPySparkWithCheckpointV2TestsMixin(TransformWithStateInPySparkTestsMixin): - @classmethod - def conf(cls): - cfg = super().conf() - cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") - return cfg - - class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): pass @@ -1940,18 +1924,6 @@ class TransformWithStateInPySparkTests(TransformWithStateInPySparkTestsMixin, Re pass -class TransformWithStateInPandasWithCheckpointV2Tests( - TransformWithStateInPandasWithCheckpointV2TestsMixin, ReusedSQLTestCase -): - pass - - -class TransformWithStateInPySparkWithCheckpointV2Tests( - TransformWithStateInPySparkWithCheckpointV2TestsMixin, ReusedSQLTestCase -): - pass - - if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_transform_with_state import * # noqa: F401 diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state_checkpoint_v2.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state_checkpoint_v2.py new file mode 100644 index 0000000000000..d6609c44db622 --- /dev/null +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state_checkpoint_v2.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.sql.tests.pandas.test_pandas_transform_with_state import ( + TransformWithStateInPandasTestsMixin, + TransformWithStateInPySparkTestsMixin, +) + + +class TransformWithStateInPandasWithCheckpointV2TestsMixin(TransformWithStateInPandasTestsMixin): + @classmethod + def conf(cls): + cfg = super().conf() + cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") + return cfg + + +class TransformWithStateInPySparkWithCheckpointV2TestsMixin(TransformWithStateInPySparkTestsMixin): + @classmethod + def conf(cls): + cfg = super().conf() + cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") + return cfg + + +class TransformWithStateInPandasWithCheckpointV2Tests( + TransformWithStateInPandasWithCheckpointV2TestsMixin, ReusedSQLTestCase +): + pass + + +class TransformWithStateInPySparkWithCheckpointV2Tests( + TransformWithStateInPySparkWithCheckpointV2TestsMixin, ReusedSQLTestCase +): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.pandas.test_pandas_transform_with_state_checkpoint_v2 import * # noqa: F401,E501 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2)