diff --git a/docs/source-pytorch/tuning/profiler.rst b/docs/source-pytorch/tuning/profiler.rst
index 1ff7c24ff7dbb..792386e4846a2 100644
--- a/docs/source-pytorch/tuning/profiler.rst
+++ b/docs/source-pytorch/tuning/profiler.rst
@@ -4,6 +4,29 @@ Find bottlenecks in your code
 #############################
 
+.. warning::
+
+   **Do not wrap** ``Trainer.fit()``, ``Trainer.validate()``, or other Trainer methods
+   inside a manual ``torch.profiler.profile`` context manager.
+   Doing so causes unexpected crashes and cryptic errors due to an incompatibility
+   between the PyTorch Profiler's context management and Lightning's internal training loop.
+   Instead, always use the ``profiler`` argument of the ``Trainer`` constructor.
+
+   Example (correct usage):
+
+   .. code-block:: python
+
+       import pytorch_lightning as pl
+
+       trainer = pl.Trainer(
+           profiler="pytorch",  # <- This enables built-in profiling safely!
+           ...
+       )
+       trainer.fit(model, train_dataloaders=...)
+
+   **References:**
+
+   - https://github.com/pytorch/pytorch/issues/88472
 .. raw:: html
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 8e4e2de97fd6a..26ef2c1ccc164 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -264,6 +264,14 @@ def __init__(
         profiler: To profile individual steps during training and assist in identifying
             bottlenecks. Default: ``None``.
 
+            .. note::
+                Do **not** use a manual ``torch.profiler.profile`` context manager around
+                ``Trainer.fit()``, ``Trainer.validate()``, etc.
+                Doing so leads to internal errors and cryptic crashes due to an incompatibility
+                between the PyTorch Profiler and Lightning's training loop.
+                Always use this ``profiler`` argument to enable profiling in Lightning.
+
+
         detect_anomaly: Enable anomaly detection for the autograd engine.
             Default: ``False``.
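For context on the distinction the added note draws: a manual ``torch.profiler.profile`` context manager is perfectly fine around a training loop *you* control; the conflict described above arises only when it wraps Lightning's ``Trainer`` methods, which manage their own profiler lifecycle. A minimal sketch of the supported manual pattern (plain PyTorch, no Lightning; the model and loop here are illustrative, not from the patch):

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Toy model and input, purely for illustration.
model = torch.nn.Linear(16, 4)
x = torch.randn(8, 16)

# Wrapping a hand-written loop in torch.profiler.profile is the intended
# usage. Wrapping Trainer.fit() the same way is what the note warns against.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(3):
        model(x)

# Summarize recorded ops; with Lightning, the equivalent report comes from
# Trainer(profiler="pytorch") instead.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```

With Lightning, the same information is obtained by passing ``profiler="pytorch"`` to the ``Trainer`` constructor, as shown in the documentation hunk above.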