Commit b45f071
[SPARK-54868][PYTHON][INFRA][FOLLOW-UP] Also enable faulthandler in classic tests
### What changes were proposed in this pull request?

Also enable `faulthandler` in classic tests, by introducing a base class `PySparkBaseTestCase`.

### Why are the changes needed?

`faulthandler` was only enabled in `ReusedConnectTestCase` for Spark Connect; after this change, `faulthandler` will be enabled in most classic tests. (There are still some tests that directly use `unittest.TestCase`; we can change them when they hit hanging issues.)

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

CI, plus a manual check with:

```
PYSPARK_TEST_TIMEOUT=10 python/run-tests -k --python-executables python3 --testnames 'pyspark.tests.test_util'
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53651 from zhengruifeng/super_class.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent: 5b5ff6d
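For background on what `PYSPARK_TEST_TIMEOUT` arms: the standard-library call `faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)` makes the process dump every thread's Python traceback to the real stderr when it receives SIGTERM, so a hung test leaves a stack trace behind instead of dying silently. A minimal standalone sketch of the mechanism (the sleep standing in for a hanging test is illustrative, not PySpark code):

```python
import faulthandler
import os
import signal
import sys
import time

# Mirror the gate used by the new PySparkBaseTestCase: only arm the
# handler when a test timeout is configured in the environment.
if os.environ.get("PYSPARK_TEST_TIMEOUT"):
    # On SIGTERM, dump the Python traceback of every thread to the real
    # stderr. With the default chain=False, the signal only triggers the
    # dump; it no longer terminates the process by itself.
    faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)

time.sleep(3600)  # stand-in for a hanging test; `kill -TERM <pid>` dumps all stacks
```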

File tree: 3 files changed, +36 −18 lines

python/pyspark/testing/connectutils.py

Lines changed: 9 additions & 12 deletions
@@ -17,9 +17,6 @@
 import shutil
 import tempfile
 import os
-import sys
-import signal
-import faulthandler
 import functools
 import unittest
 import uuid
@@ -45,6 +42,7 @@
     should_test_connect,
     PySparkErrorTestUtils,
 )
+from pyspark.testing.utils import PySparkBaseTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.sql.session import SparkSession as PySparkSession
 
@@ -75,7 +73,7 @@ def __getattr__(self, item):
 
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
-class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
+class PlanOnlyTestFixture(PySparkBaseTestCase, PySparkErrorTestUtils):
     if should_test_connect:
 
         class MockDF(DataFrame):
@@ -152,7 +150,7 @@ def tearDownClass(cls):
 
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
-class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUtils):
+class ReusedConnectTestCase(PySparkBaseTestCase, SQLTestUtils, PySparkErrorTestUtils):
     """
     Spark Connect version of :class:`pyspark.testing.sqlutils.ReusedSQLTestCase`.
     """
@@ -180,8 +178,7 @@ def master(cls):
 
     @classmethod
     def setUpClass(cls):
-        if os.environ.get("PYSPARK_TEST_TIMEOUT"):
-            faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)
+        super().setUpClass()
 
         # This environment variable is for interrupting hanging ML-handler and making the
         # tests fail fast.
@@ -203,11 +200,11 @@ def setUpClass(cls):
 
     @classmethod
     def tearDownClass(cls):
-        if os.environ.get("PYSPARK_TEST_TIMEOUT"):
-            faulthandler.unregister(signal.SIGTERM)
-
-        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
-        cls.spark.stop()
+        try:
+            shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+            cls.spark.stop()
+        finally:
+            super().tearDownClass()
 
     def setUp(self) -> None:
         # force to clean up the ML cache before each test
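Note the shape of the new `tearDownClass`: cleanup goes inside `try` and `super().tearDownClass()` inside `finally`, so the SIGTERM handler registered by the base class is unregistered even if `shutil.rmtree` or `spark.stop()` raises. A toy sketch of the pattern (class names here are illustrative, not PySpark code):

```python
class Base:
    @classmethod
    def tearDownClass(cls):
        print("faulthandler unregistered")  # stands in for faulthandler.unregister


class Suite(Base):
    @classmethod
    def tearDownClass(cls):
        try:
            raise RuntimeError("cleanup failed")  # e.g. spark.stop() raising
        finally:
            super().tearDownClass()  # still runs; the handler is released
```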

python/pyspark/testing/sqlutils.py

Lines changed: 6 additions & 3 deletions
@@ -209,6 +209,7 @@ class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUti
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+
         cls._legacy_sc = cls.sc
         cls.spark = SparkSession(cls.sc)
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
@@ -218,9 +219,11 @@ def setUpClass(cls):
 
     @classmethod
     def tearDownClass(cls):
-        super().tearDownClass()
-        cls.spark.stop()
-        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+        try:
+            cls.spark.stop()
+            shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+        finally:
+            super().tearDownClass()
 
     def tearDown(self):
         try:

python/pyspark/testing/utils.py

Lines changed: 21 additions & 3 deletions
@@ -20,6 +20,7 @@
 import sys
 import unittest
 import difflib
+import faulthandler
 import functools
 from decimal import Decimal
 from time import time, sleep
@@ -268,7 +269,19 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
 
 
-class PySparkTestCase(unittest.TestCase):
+class PySparkBaseTestCase(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if os.environ.get("PYSPARK_TEST_TIMEOUT"):
+            faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        if os.environ.get("PYSPARK_TEST_TIMEOUT"):
+            faulthandler.unregister(signal.SIGTERM)
+
+
+class PySparkTestCase(PySparkBaseTestCase):
     def setUp(self):
         from pyspark import SparkContext
 
@@ -281,7 +294,7 @@ def tearDown(self):
         sys.path = self._old_sys_path
 
 
-class ReusedPySparkTestCase(unittest.TestCase):
+class ReusedPySparkTestCase(PySparkBaseTestCase):
     @classmethod
     def conf(cls):
         """
@@ -291,6 +304,8 @@ def conf(cls):
 
     @classmethod
     def setUpClass(cls):
+        super().setUpClass()
+
         from pyspark import SparkContext
 
         cls.sc = SparkContext(cls.master(), cls.__name__, conf=cls.conf())
@@ -301,7 +316,10 @@ def master(cls):
 
     @classmethod
     def tearDownClass(cls):
-        cls.sc.stop()
+        try:
+            cls.sc.stop()
+        finally:
+            super().tearDownClass()
 
     def test_assert_classic_mode(self):
         from pyspark.sql import is_remote
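With the base class in place, any classic suite inherits the watchdog simply by extending `PySparkBaseTestCase`. A hypothetical example (the test class and assertion are illustrative, not from this commit):

```python
import os
import unittest

from pyspark.testing.utils import PySparkBaseTestCase


class MyClassicTests(PySparkBaseTestCase):  # hypothetical suite for illustration
    # If PYSPARK_TEST_TIMEOUT is set, the inherited setUpClass has already
    # registered the SIGTERM faulthandler, so a hang here would be reported
    # with per-thread tracebacks instead of a silent kill.
    def test_timeout_env_is_optional(self):
        self.assertIsInstance(os.environ.get("PYSPARK_TEST_TIMEOUT", ""), str)


if __name__ == "__main__":
    unittest.main()
```

Running such a suite with `PYSPARK_TEST_TIMEOUT` set, as in the manual check above, exercises the registration path.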
