 import inspect
 import logging
 import os
-from functools import partial
+from functools import lru_cache, partial
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union

 from torch.autograd.profiler import record_function

 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
+from pytorch_lightning.utilities.warnings import WarningCache

 if TYPE_CHECKING:
     from torch.autograd.profiler import EventList
     from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler

 log = logging.getLogger(__name__)
+warning_cache = WarningCache()

 _PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]

@@ -116,6 +118,7 @@ def pre_step(self, current_action: str) -> None:
         self._current_action = current_action

     def reset(self):
+        # Handle `fast_dev_run` properly; the PyTorch Profiler will fail otherwise.
         self._num_optimizer_step_and_closure = 0
         self._num_validation_step = 0
         self._num_test_step = 0
@@ -128,9 +131,15 @@ def reset(self):
         self._current_action: Optional[str] = None
         self._start_action_name: Optional[str] = None

+    @property
+    def is_training(self) -> bool:
+        return self._current_action is not None and (
+            self._current_action.startswith("optimizer_step_and_closure_") or self._current_action == "training_step"
+        )
+
     @property
     def num_step(self) -> int:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             return self._num_optimizer_step_and_closure
         if self._current_action == "validation_step":
             return self._num_validation_step
@@ -141,7 +150,7 @@ def num_step(self) -> int:
         return 0

     def _step(self) -> None:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             self._num_optimizer_step_and_closure += 1
         elif self._current_action == "validation_step":
             if self._start_action_name == "on_fit_start":
@@ -156,7 +165,7 @@ def _step(self) -> None:

     @property
     def has_finished(self) -> bool:
-        if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+        if self.is_training:
             return self._optimizer_step_and_closure_reached_end
         if self._current_action == "validation_step":
             return self._validation_step_reached_end
@@ -172,9 +181,9 @@ def __call__(self, num_step: int) -> "ProfilerAction":
             return ProfilerAction.NONE

         self._step()
-        action = self._schedule(self.num_step)
+        action = self._schedule(max(self.num_step, 0))
         if action == ProfilerAction.RECORD_AND_SAVE:
-            if self._current_action is not None and self._current_action.startswith("optimizer_step_and_closure_"):
+            if self.is_training:
                 self._optimizer_step_and_closure_reached_end = True
             elif self._current_action == "validation_step":
                 self._validation_step_reached_end = True
@@ -196,7 +205,7 @@ class PyTorchProfiler(BaseProfiler):
         "predict_step",
     }
     RECORD_FUNCTION_PREFIX = "optimizer_step_and_closure_"
-    STEP_FUNCTIONS = {"validation_step", "test_step", "predict_step"}
+    STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"}
     STEP_FUNCTION_PREFIX = "optimizer_step_and_closure_"
     AVAILABLE_SORT_KEYS = {
         "cpu_time",
@@ -320,6 +329,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
                 raise MisconfigurationException(
                     f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}"
                 )
+        self._default_schedule()
         schedule = schedule if has_schedule else self._default_schedule()
         self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule
         self._profiler_kwargs["schedule"] = self._schedule
@@ -331,28 +341,13 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
         with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
         self._profiler_kwargs["with_stack"] = with_stack

-    def __deprecation_check(
-        self, profiled_functions: Optional[List[str]], record_functions: Optional[Set[str]]
-    ) -> Set[str]:
-        if record_functions is None:
-            record_functions = set()
-
-        if profiled_functions is not None:
-            rank_zero_deprecation(
-                "`PyTorchProfiler.profiled_functions` has been renamed to"
-                " `record_functions` in v1.3 and will be removed in v1.5"
-            )
-            if not record_functions:
-                record_functions |= set(profiled_functions)
-            else:
-                raise MisconfigurationException(
-                    "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
-                    " Please use only the later."
-                )
-
-        return record_functions
+    def _should_override_schedule(self) -> bool:
+        return (self._lightning_module is not None and self._lightning_module.trainer.limit_train_batches < 5) and (
+            self._schedule is not None and self._schedule._schedule == self._default_schedule()
+        )

     @staticmethod
+    @lru_cache(1)
     def _default_schedule() -> Optional[callable]:
         if _KINETO_AVAILABLE:
             # Those schedule defaults allow the profiling overhead to be negligible over training time.
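A side note on the `@lru_cache(1)` added above: `torch.profiler.schedule(...)` builds a new closure on every call, and plain functions compare by identity, so without the cache the `self._schedule._schedule == self._default_schedule()` check in `_should_override_schedule` could never match. A minimal sketch of that behaviour (assuming a Kineto-enabled PyTorch build; `default_schedule` is an illustrative stand-in, not the Lightning method):

```python
from functools import lru_cache

from torch.profiler import schedule


@lru_cache(1)
def default_schedule():
    # Same defaults as quoted in the diff: wait 1 step, warm up 1 step, record 3 steps.
    return schedule(wait=1, warmup=1, active=3)


# Two bare calls build two different closures, so they never compare equal...
assert schedule(wait=1, warmup=1, active=3) != schedule(wait=1, warmup=1, active=3)
# ...whereas the cached helper returns the same object each time, so `==` holds.
assert default_schedule() == default_schedule()
```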
@@ -393,11 +388,18 @@ def start(self, action_name: str) -> None:
         if self._register is not None:
             self._register.__enter__()

+        if self._lightning_module is not None:
+            # when the model is used in automatic optimization,
+            # we use `optimizer_step_and_closure` to step the model.
+            if self._lightning_module.automatic_optimization and "training_step" in self.STEP_FUNCTIONS:
+                self.STEP_FUNCTIONS.remove("training_step")
+
         if (
             self.profiler is not None
             and (action_name in self._record_functions or action_name.startswith(self.RECORD_FUNCTION_PREFIX))
             and action_name not in self._recording_map
         ):
+
             recording = record_function(action_name)
             recording.__enter__()
             self._recording_map[action_name] = recording
@@ -413,6 +415,17 @@ def stop(self, action_name: str) -> None:
         if self.profiler is not None and (
             action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX)
         ):
+
+            # the default schedule requires a minimum of 5 steps to work properly: `wait=1, warmup=1, active=3`.
+            # otherwise, this will cause a `segmentation fault`.
+            if self._should_override_schedule():
+                warning_cache.warn(
+                    "The PyTorch Profiler default schedule will be overridden as there are not enough "
+                    "steps to properly record traces."
+                )
+                self._schedule = None
+                self.profiler.schedule = torch.profiler.profiler._default_schedule_fn
+
             if self._schedule is not None:
                 self._schedule.pre_step(action_name)

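For reference, a minimal sketch (not part of the diff, assuming a Kineto-enabled PyTorch build) of how the default `wait=1, warmup=1, active=3` schedule consumes steps: the trace is only saved on the fifth step, which is why runs shorter than five training batches fall back to `torch.profiler.profiler._default_schedule_fn` above.

```python
from torch.profiler import schedule

# Defaults quoted in the comment above: wait 1 step, warm up 1 step, record 3 steps.
default = schedule(wait=1, warmup=1, active=3)

for step in range(5):
    print(step, default(step))
# Expected output:
# 0 ProfilerAction.NONE
# 1 ProfilerAction.WARMUP
# 2 ProfilerAction.RECORD
# 3 ProfilerAction.RECORD
# 4 ProfilerAction.RECORD_AND_SAVE   <- a trace is only flushed here, on the 5th step
```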