diff --git a/dev/requirements.txt b/dev/requirements.txt
index 7eb157352408a..a64f9c4cc50a8 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -77,6 +77,8 @@ graphviz==0.20.3
 flameprof==0.4
 viztracer
 debugpy
+pystack>=1.5.1; python_version!='3.13' and sys_platform=='linux' # no 3.13t wheels
+psutil
 
 # TorchDistributor dependencies
 torch
diff --git a/dev/spark-test-image/python-311/Dockerfile b/dev/spark-test-image/python-311/Dockerfile
index 50c042c9da1a5..0db52262fa841 100644
--- a/dev/spark-test-image/python-311/Dockerfile
+++ b/dev/spark-test-image/python-311/Dockerfile
@@ -68,7 +68,7 @@ RUN apt-get update && apt-get install -y \
     && rm -rf /var/lib/apt/lists/*
 
 
-ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 pystack psutil"
 
 # Python deps for Spark Connect
 ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
diff --git a/python/pyspark/threaddump.py b/python/pyspark/threaddump.py
new file mode 100644
index 0000000000000..0cf0feb810962
--- /dev/null
+++ b/python/pyspark/threaddump.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import sys
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Dump threads of a process and its children")
+    parser.add_argument("-p", "--pid", type=int, required=True, help="The PID to dump")
+    return parser
+
+
+def main() -> int:
+    try:
+        import psutil
+        from pystack.__main__ import main as pystack_main  # type: ignore
+    except ImportError:
+        print("pystack and psutil are required but not installed")
+        return 1
+
+    parser = build_parser()
+    args = parser.parse_args()
+
+    try:
+        pids = [args.pid] + [
+            child.pid
+            for child in psutil.Process(args.pid).children(recursive=True)
+            if "python" in child.exe()
+        ]
+    except Exception as e:
+        print(f"Error getting children of process {args.pid}: {e}")
+        return 2
+
+    for pid in pids:
+        sys.argv = ["pystack", "remote", str(pid)]
+        try:
+            print(f"Dumping threads for process {pid}")
+            pystack_main()
+        except Exception:
+            # We might have tried to dump a process that is not a Python process
+            pass
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/python/run-tests.py b/python/run-tests.py
index b3522a13df4af..2676a67d96879 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -113,7 +113,7 @@ def run(self):
         try:
             asyncio.run(self.handle_inout())
         except subprocess.TimeoutExpired:
-            LOGGER.error(f"Test {self.test_name} timed out")
+            LOGGER.error(f"Test {self.test_name} timed out after {self.timeout} seconds")
             try:
                 return self.p.wait(timeout=30)
             except subprocess.TimeoutExpired:
@@ -204,9 +204,22 @@ async def check_timeout(self):
             # We don't want to kill the process if it's in pdb mode
             return
         if self.p.poll() is None:
+            if sys.platform == "linux":
+                self.thread_dump(self.p.pid)
             self.p.terminate()
             raise subprocess.TimeoutExpired(self.cmd, self.timeout)
 
+    def thread_dump(self, pid):
+        pyspark_python = self.env["PYSPARK_PYTHON"]
+        p = subprocess.run(
+            [pyspark_python, "-m", "pyspark.threaddump", "-p", str(pid)],
+            env={**self.env, "PYTHONPATH": f"{os.path.join(SPARK_HOME, 'python')}:{os.environ.get('PYTHONPATH', '')}"},
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+        )
+        if p.returncode == 0:
+            LOGGER.error(f"Thread dump:\n{p.stdout.decode('utf-8')}")
+
 
 def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_output):
     """
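
Usage note (not part of the patch): on a test timeout, check_timeout() now shells
out to the new pyspark.threaddump module, which dumps the stuck runner and every
Python child it spawned via pystack's "remote" command. A minimal sketch of
exercising the helper by hand on Linux, assuming pystack and psutil are installed
and $SPARK_HOME/python is on PYTHONPATH (the sleeping child below is a stand-in
for a hung test):

    import subprocess
    import sys
    import time

    # Throwaway Python child that just sleeps, standing in for a hung test.
    child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
    time.sleep(1)  # give the child interpreter a moment to start

    # The same invocation run-tests.py uses on timeout: dump the process
    # tree rooted at the given PID.
    dump = subprocess.run(
        [sys.executable, "-m", "pyspark.threaddump", "-p", str(child.pid)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    print(dump.stdout.decode("utf-8"))
    child.kill()

Attaching to a live process needs ptrace permission, so the dump can fail under a
restrictive kernel.yama.ptrace_scope even when the packages are installed.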
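
An aside on the dev/requirements.txt change: the clause after ";" is a standard
PEP 508 environment marker, so pip only installs pystack on Linux and skips it on
Python 3.13 (the inline comment points at missing 3.13t wheels). A sketch of
evaluating such a marker with the packaging library, shown for illustration only
(pip evaluates markers itself):

    from packaging.markers import Marker

    # The marker guarding the pystack requirement; evaluate() checks it
    # against the running interpreter's environment.
    marker = Marker("python_version != '3.13' and sys_platform == 'linux'")
    print(marker.evaluate())  # True only on a non-3.13 Python running on Linux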