
Commit a5ad966

justusschock authored and carmocca committed
[bugfix] Resolve PyTorch Profiling for Manual Optimization (#9316)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 4cf6be9 commit a5ad966

File tree

6 files changed: +117 / -64 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))
 
 
+- Fixed PyTorch Profiler not enabled for manual optimization ([#9316](https://github.com/PyTorchLightning/pytorch-lightning/pull/9316))
+
+
 - Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125))
 
 ## [1.4.5] - 2021-08-31

pl_examples/basic_examples/profiler_example.py

Lines changed: 24 additions & 5 deletions
@@ -31,6 +31,7 @@
 
 from pl_examples import _DATASETS_PATH, cli_lightning_logo
 from pytorch_lightning import LightningDataModule, LightningModule
+from pytorch_lightning.profiler.pytorch import PyTorchProfiler
 from pytorch_lightning.utilities.cli import LightningCLI
 
 DEFAULT_CMD_LINE = (
@@ -43,18 +44,34 @@
 
 
 class ModelToProfile(LightningModule):
-    def __init__(self, name: str = "resnet50"):
+    def __init__(self, name: str = "resnet18", automatic_optimization: bool = True):
         super().__init__()
         self.model = getattr(models, name)(pretrained=True)
         self.criterion = torch.nn.CrossEntropyLoss()
+        self.automatic_optimization = automatic_optimization
+        self.training_step = (
+            self.automatic_optimization_training_step
+            if automatic_optimization
+            else self.manual_optimization_training_step
+        )
 
-    def training_step(self, batch, batch_idx):
+    def automatic_optimization_training_step(self, batch, batch_idx):
         inputs, labels = batch
         outputs = self.model(inputs)
         loss = self.criterion(outputs, labels)
         self.log("train_loss", loss)
         return loss
 
+    def manual_optimization_training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        opt.zero_grad()
+        inputs, labels = batch
+        outputs = self.model(inputs)
+        loss = self.criterion(outputs, labels)
+        self.log("train_loss", loss)
+        self.manual_backward(loss)
+        opt.step()
+
     def validation_step(self, batch, batch_idx):
         inputs, labels = batch
         outputs = self.model(inputs)
@@ -77,18 +94,20 @@ def train_dataloader(self, *args, **kwargs):
         trainset = torchvision.datasets.CIFAR10(
             root=_DATASETS_PATH, train=True, download=True, transform=self.transform
         )
-        return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0)
+        return torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)
 
     def val_dataloader(self, *args, **kwargs):
         valset = torchvision.datasets.CIFAR10(root=_DATASETS_PATH, train=False, download=True, transform=self.transform)
-        return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0)
+        return torch.utils.data.DataLoader(valset, batch_size=2, shuffle=True, num_workers=0)
 
 
 def cli_main():
     if len(sys.argv) == 1:
         sys.argv += DEFAULT_CMD_LINE
 
-    LightningCLI(ModelToProfile, CIFAR10DataModule)
+    LightningCLI(
+        ModelToProfile, CIFAR10DataModule, save_config_overwrite=True, trainer_defaults={"profiler": PyTorchProfiler()}
+    )
 
 
 if __name__ == "__main__":
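
For reference, a minimal sketch of driving the updated example under manual optimization without the CLI. The import path is assumed from the repo layout above, and at least 5 training batches are run so the profiler's default schedule stays valid:

# Minimal sketch (not part of this commit): profile the example model with
# manual optimization. `ModelToProfile` and `CIFAR10DataModule` come from the
# example module above; the import path is an assumption.
from pl_examples.basic_examples.profiler_example import CIFAR10DataModule, ModelToProfile
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler

model = ModelToProfile(automatic_optimization=False)
# >= 5 train batches keeps the default schedule (wait=1, warmup=1, active=3) usable
trainer = Trainer(max_epochs=1, limit_train_batches=5, profiler=PyTorchProfiler())
trainer.fit(model, datamodule=CIFAR10DataModule())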

pytorch_lightning/profiler/pytorch.py

Lines changed: 41 additions & 28 deletions
@@ -15,7 +15,7 @@
 import inspect
 import logging
 import os
-from functools import partial
+from functools import lru_cache, partial
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union
 
@@ -24,9 +24,10 @@
 from torch.autograd.profiler import record_function
 
 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
+from pytorch_lightning.utilities.warnings import WarningCache
 
 if TYPE_CHECKING:
     from torch.autograd.profiler import EventList
@@ -38,6 +39,7 @@
     from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
 
 log = logging.getLogger(__name__)
+warning_cache = WarningCache()
 
 _PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]
 
@@ -116,6 +118,7 @@ def pre_step(self, current_action: str) -> None:
         self._current_action = current_action
 
     def reset(self):
+        # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
         self._num_optimizer_step_and_closure = 0
         self._num_validation_step = 0
         self._num_test_step = 0
@@ -128,9 +131,15 @@ def reset(self):
         self._current_action: Optional[str] = None
         self._start_action_name: Optional[str] = None
 
+    @property
+    def is_training(self) -> bool:
+        return self._current_action is not None and (
+            self._current_action.startswith("optimizer_step_and_closure_") or self._current_action == "training_step"
+        )
+
     @property
     def num_step(self) -> int:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             return self._num_optimizer_step_and_closure
         if self._current_action == "validation_step":
             return self._num_validation_step
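
Note: under manual optimization the trainer never emits `optimizer_step_and_closure_*` actions; the raw `training_step` action is the step boundary instead, which is why the new `is_training` property matches both. A tiny standalone sketch of the matching rule (mirroring the property above; the free function is hypothetical):

# Standalone sketch of the rule behind `ScheduleWrapper.is_training`
def is_training(current_action):
    return current_action is not None and (
        current_action.startswith("optimizer_step_and_closure_") or current_action == "training_step"
    )

assert is_training("optimizer_step_and_closure_0")  # automatic optimization
assert is_training("training_step")                 # manual optimization (new)
assert not is_training("validation_step")
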
@@ -141,7 +150,7 @@ def num_step(self) -> int:
         return 0
 
     def _step(self) -> None:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             self._num_optimizer_step_and_closure += 1
         elif self._current_action == "validation_step":
             if self._start_action_name == "on_fit_start":
@@ -156,7 +165,7 @@ def _step(self) -> None:
 
     @property
     def has_finished(self) -> bool:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             return self._optimizer_step_and_closure_reached_end
         if self._current_action == "validation_step":
             return self._validation_step_reached_end
@@ -172,9 +181,9 @@ def __call__(self, num_step: int) -> "ProfilerAction":
             return ProfilerAction.NONE
 
         self._step()
-        action = self._schedule(self.num_step)
+        action = self._schedule(max(self.num_step, 0))
         if action == ProfilerAction.RECORD_AND_SAVE:
-            if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+            if self.is_training:
                 self._optimizer_step_and_closure_reached_end = True
             elif self._current_action == "validation_step":
                 self._validation_step_reached_end = True
@@ -196,7 +205,7 @@ class PyTorchProfiler(BaseProfiler):
         "predict_step",
     }
     RECORD_FUNCTION_PREFIX = "optimizer_step_and_closure_"
-    STEP_FUNCTIONS = {"validation_step", "test_step", "predict_step"}
+    STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"}
     STEP_FUNCTION_PREFIX = "optimizer_step_and_closure_"
     AVAILABLE_SORT_KEYS = {
         "cpu_time",
@@ -320,6 +329,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
                 raise MisconfigurationException(
                     f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}"
                 )
+        self._default_schedule()
         schedule = schedule if has_schedule else self._default_schedule()
         self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule
         self._profiler_kwargs["schedule"] = self._schedule
@@ -331,28 +341,13 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
         with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
         self._profiler_kwargs["with_stack"] = with_stack
 
-    def __deprecation_check(
-        self, profiled_functions: Optional[List[str]], record_functions: Optional[Set[str]]
-    ) -> Set[str]:
-        if record_functions is None:
-            record_functions = set()
-
-        if profiled_functions is not None:
-            rank_zero_deprecation(
-                "`PyTorchProfiler.profiled_functions` has been renamed to"
-                " `record_functions` in v1.3 and will be removed in v1.5"
-            )
-            if not record_functions:
-                record_functions |= set(profiled_functions)
-            else:
-                raise MisconfigurationException(
-                    "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
-                    " Please use only the later."
-                )
-
-        return record_functions
+    def _should_override_schedule(self) -> bool:
+        return (self._lightning_module is not None and self._lightning_module.trainer.limit_train_batches < 5) and (
+            self._schedule is not None and self._schedule._schedule == self._default_schedule()
+        )
 
     @staticmethod
+    @lru_cache(1)
     def _default_schedule() -> Optional[callable]:
         if _KINETO_AVAILABLE:
             # Those schedule defaults allow the profiling overhead to be negligible over training time.
@@ -393,11 +388,18 @@ def start(self, action_name: str) -> None:
         if self._register is not None:
             self._register.__enter__()
 
+        if self._lightning_module is not None:
+            # when the model is used in automatic optimization,
+            # we use `optimizer_step_and_closure` to step the model.
+            if self._lightning_module.automatic_optimization and "training_step" in self.STEP_FUNCTIONS:
+                self.STEP_FUNCTIONS.remove("training_step")
+
         if (
             self.profiler is not None
            and (action_name in self._record_functions or action_name.startswith(self.RECORD_FUNCTION_PREFIX))
             and action_name not in self._recording_map
         ):
+
             recording = record_function(action_name)
             recording.__enter__()
             self._recording_map[action_name] = recording
@@ -413,6 +415,17 @@ def stop(self, action_name: str) -> None:
         if self.profiler is not None and (
             action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX)
         ):
+
+            # the default schedule requires a minimum of 5 steps to properly work: `wait=1, warmup=1, active=3`.
+            # otherwise, this will raise a `segmentation fault`.
+            if self._should_override_schedule():
+                warning_cache.warn(
+                    "The PyTorch Profiler default schedule will be overridden as there is not enough "
+                    "steps to properly record traces."
+                )
+                self._schedule = None
+                self.profiler.schedule = torch.profiler.profiler._default_schedule_fn
+
             if self._schedule is not None:
                 self._schedule.pre_step(action_name)
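
The override in `stop` exists because the default kineto schedule (`wait=1, warmup=1, active=3`) only reaches `RECORD_AND_SAVE` on its fifth step; with fewer training batches the trace is never saved and the profiler can fail. A standalone sketch of that behaviour, assuming torch >= 1.8.1 with kineto support:

# Standalone sketch (assumes torch >= 1.8.1 with kineto): the default schedule
# needs 5 steps before it emits RECORD_AND_SAVE, hence the `< 5` guard above.
import torch

schedule = torch.profiler.schedule(wait=1, warmup=1, active=3)
for step in range(5):
    print(step, schedule(step))
# 0 ProfilerAction.NONE    (wait)
# 1 ProfilerAction.WARMUP
# 2 ProfilerAction.RECORD
# 3 ProfilerAction.RECORD
# 4 ProfilerAction.RECORD_AND_SAVE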

tests/helpers/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -1,2 +1,7 @@
-from tests.helpers.boring_model import BoringDataModule, BoringModel, RandomDataset  # noqa: F401
+from tests.helpers.boring_model import (  # noqa: F401
+    BoringDataModule,
+    BoringModel,
+    ManualOptimBoringModel,
+    RandomDataset,
+)
 from tests.helpers.datasets import TrialMNIST  # noqa: F401

tests/helpers/boring_model.py

Lines changed: 15 additions & 0 deletions
@@ -184,3 +184,18 @@ def test_dataloader(self):
 
     def predict_dataloader(self):
         return DataLoader(self.random_predict)
+
+
+class ManualOptimBoringModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        output = self(batch)
+        loss = self.loss(batch, output)
+        opt.zero_grad()
+        self.manual_backward(loss)
+        opt.step()
+        return loss
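
A quick way to exercise the new helper outside the test suite, as a sketch assuming the repo's `tests` package is importable and this fix is applied:

# Sketch (not part of this commit): with this fix, `training_step` of a
# manual-optimization model is recorded by the PyTorch profiler during `fit`.
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler
from tests.helpers import ManualOptimBoringModel

profiler = PyTorchProfiler(schedule=None)  # no schedule: record every step
trainer = Trainer(fast_dev_run=2, profiler=profiler)
trainer.fit(ManualOptimBoringModel())
assert any(e.name == "training_step" for e in profiler.function_events)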

tests/profiler/test_profiler.py

Lines changed: 28 additions & 30 deletions
@@ -29,7 +29,7 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
-from tests.helpers import BoringModel
+from tests.helpers import BoringModel, ManualOptimBoringModel
 from tests.helpers.runif import RunIf
 
 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
@@ -309,50 +309,48 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
     assert any(f"{local_rank}-validation_step" in f for f in files)
 
 
-def test_pytorch_profiler_trainer_test(tmpdir):
+@pytest.mark.parametrize("fast_dev_run", [1, 2, 3, 4, 5])
+@pytest.mark.parametrize("boring_model_cls", [ManualOptimBoringModel, BoringModel])
+def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
     """Ensure that the profiler can be given to the trainer and test step are properly recorded."""
-    pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
-    model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
-    trainer.test(model)
+    pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile")
+    model = boring_model_cls()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run, profiler=pytorch_profiler)
+    trainer.fit(model)
 
-    assert sum(e.name == "test_step" for e in pytorch_profiler.function_events)
+    assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)
 
-    path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt"
+    path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
     assert path.read_text("utf-8")
 
     if _KINETO_AVAILABLE:
         files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
-        assert any(f"test-{pytorch_profiler.filename}" in f for f in files)
-        path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt"
+        assert any(f"fit-{pytorch_profiler.filename}" in f for f in files)
+        path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
         assert path.read_text("utf-8")
 
 
-def test_pytorch_profiler_trainer_predict(tmpdir):
-    """Ensure that the profiler can be given to the trainer and predict function are properly recorded."""
+@pytest.mark.parametrize("fn, step_name", [("test", "test"), ("validate", "validation"), ("predict", "predict")])
+@pytest.mark.parametrize("boring_model_cls", [BoringModel, ManualOptimBoringModel])
+def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
+    """Ensure that the profiler can be given to the trainer and test step are properly recorded."""
     pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
-    model = BoringModel()
+    model = boring_model_cls()
     model.predict_dataloader = model.train_dataloader
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_predict_batches=2, profiler=pytorch_profiler)
-    trainer.predict(model)
-
-    assert sum(e.name == "predict_step" for e in pytorch_profiler.function_events)
-    path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt"
-    assert path.read_text("utf-8")
-
-
-def test_pytorch_profiler_trainer_validate(tmpdir):
-    """Ensure that the profiler can be given to the trainer and validate function are properly recorded."""
-    pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
-    model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=2, profiler=pytorch_profiler)
-    trainer.validate(model)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
+    getattr(trainer, fn)(model)
 
-    assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)
+    assert sum(e.name == f"{step_name}_step" for e in pytorch_profiler.function_events)
 
-    path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt"
+    path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
     assert path.read_text("utf-8")
 
+    if _KINETO_AVAILABLE:
+        files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
+        assert any(f"{fn}-{pytorch_profiler.filename}" in f for f in files)
+        path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
+        assert path.read_text("utf-8")
+
 
 def test_pytorch_profiler_nested(tmpdir):
     """Ensure that the profiler handles nested context"""
@@ -467,7 +465,7 @@ def on_fit_end(self, trainer, *args, **kwargs) -> None:
 
     profiler = cls(dirpath=tmpdir, filename="profiler")
     model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1, profiler=profiler, callbacks=[TestCallback()])
     trainer.fit(model)
 
     assert profiler._output_file is None
