Skip to content

Commit 591d617

Browse files
awaelchlipre-commit-ci[bot]carmocca
authored andcommitted
[bugfix] Properly name PyTorchProfiler traces (#8009)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 7c58c4d commit 591d617

File tree

5 files changed

+21
-12
lines changed

5 files changed

+21
-12
lines changed

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
3131
from pytorch_lightning.utilities import AMPType
3232
from pytorch_lightning.utilities.apply_func import apply_to_collection
33-
from pytorch_lightning.utilities.distributed import _warn, rank_zero_info, rank_zero_only
33+
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
3434
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3535
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
3636

pytorch_lightning/profiler/profilers.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,19 @@ def _rank_zero_info(self, *args, **kwargs) -> None:
120120
if self._local_rank in (None, 0):
121121
log.info(*args, **kwargs)
122122

123-
def _prepare_filename(self, extension: str = ".txt") -> str:
124-
filename = ""
123+
def _prepare_filename(
124+
self, action_name: Optional[str] = None, extension: str = ".txt", split_token: str = "-"
125+
) -> str:
126+
args = []
125127
if self._stage is not None:
126-
filename += f"{self._stage}-"
127-
filename += str(self.filename)
128+
args.append(self._stage)
129+
if self.filename:
130+
args.append(self.filename)
128131
if self._local_rank is not None:
129-
filename += f"-{self._local_rank}"
130-
filename += extension
132+
args.append(str(self._local_rank))
133+
if action_name is not None:
134+
args.append(action_name)
135+
filename = split_token.join(args) + extension
131136
return filename
132137

133138
def _prepare_streams(self) -> None:

pytorch_lightning/profiler/pytorch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,15 @@ def stop(self, action_name: str) -> None:
427427
def on_trace_ready(profiler):
428428
if self.dirpath is not None:
429429
if self._export_to_chrome:
430-
handler = tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension=""))
430+
handler = tensorboard_trace_handler(
431+
self.dirpath, self._prepare_filename(action_name=action_name, extension="")
432+
)
431433
handler(profiler)
432434

433435
if self._export_to_flame_graph:
434-
path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack"))
436+
path = os.path.join(
437+
self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack")
438+
)
435439
profiler.export_stacks(path, metric=self._metric)
436440
else:
437441
rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")

tests/deprecated_api/test_remove_1-5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pytorch_lightning.utilities import device_parser
3131
from pytorch_lightning.utilities.imports import _compare_version
3232
from tests.deprecated_api import no_deprecated_call
33-
from tests.helpers import BoringDataModule, BoringModel
33+
from tests.helpers import BoringModel
3434
from tests.helpers.utils import no_warning_call
3535

3636

tests/test_profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
330330
files = [file for file in files if file.endswith('.json')]
331331
assert len(files) == 2, files
332332
local_rank = trainer.local_rank
333-
assert any(f'training_step_{local_rank}' in f for f in files)
334-
assert any(f'validation_step_{local_rank}' in f for f in files)
333+
assert any(f'{local_rank}-training_step_and_backward' in f for f in files)
334+
assert any(f'{local_rank}-validation_step' in f for f in files)
335335

336336

337337
def test_pytorch_profiler_trainer_test(tmpdir):

0 commit comments

Comments
 (0)