diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 81c7bfc656885..65aef1ea3b306 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - fix progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016)) +- fix `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809)) --- diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index 41681fbd239f3..c0b4b9953cc33 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -19,6 +19,7 @@ import os import pstats import tempfile +from collections import defaultdict from pathlib import Path from typing import Optional, Union @@ -66,14 +67,15 @@ def __init__( If you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.profiled_actions: dict[str, cProfile.Profile] = {} + self.profiled_actions: dict[str, cProfile.Profile] = defaultdict(cProfile.Profile) self.line_count_restriction = line_count_restriction self.dump_stats = dump_stats @override def start(self, action_name: str) -> None: - if action_name not in self.profiled_actions: - self.profiled_actions[action_name] = cProfile.Profile() + # Disable all profilers before starting a new one + for pr in self.profiled_actions.values(): + pr.disable() self.profiled_actions[action_name].enable() @override @@ -114,7 +116,7 @@ def summary(self) -> str: @override def teardown(self, stage: Optional[str]) -> None: super().teardown(stage=stage) - self.profiled_actions = {} + self.profiled_actions.clear() def __reduce__(self) -> tuple: # avoids `TypeError: cannot pickle 'cProfile.Profile' object` diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index d0221d12e317f..52112769702c0 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -336,6 +336,12 @@ def test_advanced_profiler_deepcopy(advanced_profiler): assert deepcopy(advanced_profiler) +def test_advanced_profiler_nested(advanced_profiler): + """Ensure AdvancedProfiler does not raise ValueError for nested profiling actions (Python 3.12+ compatibility).""" + with advanced_profiler.profile("outer"), advanced_profiler.profile("inner"): + pass # Should not raise ValueError + + @pytest.fixture def pytorch_profiler(tmp_path): return PyTorchProfiler(dirpath=tmp_path, filename="profiler")