 import os
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Union
 from weakref import proxy

 import torch
     LRSchedulerConfig,
 )
 from lightning.pytorch.utilities.warnings import PossibleUserWarning
+from lightning.fabric.wrappers import (
+    _unwrap_compiled,
+    _to_compiled,
+)

 log = logging.getLogger(__name__)

@@ -127,7 +131,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
-        compile_fn: Optional[Callable] = None,
+        reapply_compile: bool = False,
     ) -> None:
         r"""Customize every aspect of training via flags.

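With this change, a ``torch.compile``-wrapped module can be passed to ``fit()`` and the Trainer asked to redo the compilation after strategy setup. A minimal usage sketch follows; everything except the ``reapply_compile`` flag itself (the model, data, and other Trainer arguments) is an illustrative placeholder:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# The user compiles the module up front; with reapply_compile=True the Trainer is
# expected to strip the wrapper before strategy setup and recompile afterwards.
model = torch.compile(BoringModel(), mode="default")
trainer = pl.Trainer(max_epochs=1, reapply_compile=True, logger=False, enable_checkpointing=False)
train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer.fit(model, train_loader)
```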
@@ -290,9 +294,6 @@ def __init__(
                Default: ``os.getcwd()``.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

-            compile_fn: Provide torch.compile function to be applied after configuring strategy
-                Default: ``None``.
-
        Raises:
            TypeError:
                If ``gradient_clip_val`` is not an int or float.
@@ -307,6 +308,8 @@ def __init__(
         if default_root_dir is not None:
             default_root_dir = os.fspath(default_root_dir)

+        self._reapply_compile = reapply_compile
+
         self.barebones = barebones
         if barebones:
             # opt-outs
@@ -472,8 +475,6 @@ def __init__(
         self.should_stop = False
         self.state = TrainerState()

-        self.compile_fn = compile_fn
-
         # configure profiler
         setup._init_profiler(self, profiler)

@@ -535,19 +536,20 @@ def fit(
        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        """
-        model = _maybe_unwrap_optimized(model)
+        model, compile_kwargs = _unwrap_compiled(model) if self._reapply_compile else (_maybe_unwrap_optimized(model), None)
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(model, self.strategy)
         self.state.fn = TrainerFn.FITTING
         self.state.status = TrainerStatus.RUNNING
         self.training = True
         call._call_and_handle_interrupt(
-            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+            self, self._fit_impl, model, compile_kwargs, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )

     def _fit_impl(
         self,
         model: "pl.LightningModule",
+        compile_kwargs: Optional[Dict[str, Any]],
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
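The round-trip in ``fit()`` relies on two private helpers from ``lightning.fabric.wrappers``. A rough conceptual sketch of the unwrapping side, assuming the compile kwargs were recorded on the wrapper when ``torch.compile`` was called (the actual helper may capture them differently):

```python
from typing import Any, Dict, Optional, Tuple

from torch._dynamo import OptimizedModule


def unwrap_compiled_sketch(model: Any) -> Tuple[Any, Optional[Dict[str, Any]]]:
    # If the module was wrapped by torch.compile, hand back the original module together
    # with the kwargs used to compile it; otherwise report that no compilation was applied.
    if isinstance(model, OptimizedModule):
        compile_kwargs = getattr(model, "_compile_kwargs", {})  # assumed attribute, see lead-in
        return model._orig_mod, compile_kwargs
    return model, None
```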
@@ -577,7 +579,7 @@ def _fit_impl(
             model_provided=True,
             model_connected=self.lightning_module is not None,
         )
-        self._run(model, ckpt_path=ckpt_path)
+        self._run(model, compile_kwargs, ckpt_path=ckpt_path)

         assert self.state.stopped
         self.training = False
@@ -908,7 +910,7 @@ def _predict_impl(
         return results

     def _run(
-        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
+        self, model: "pl.LightningModule", compile_kwargs: Optional[Dict[str, Any]] = None, ckpt_path: Optional[_PATH] = None
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         if self.state.fn == TrainerFn.FITTING:
             min_epochs, max_epochs = _parse_loop_limits(
@@ -962,10 +964,9 @@ def _run(
         # strategy will configure model and move it to the device
         self.strategy.setup(self)

-        # compile if compile_fn provided after configured strategy
-        if self.compile_fn is not None:
-            self.strategy.model = self.compile_fn(self.strategy.model)
-
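+        # reapply torch.compile with the captured kwargs now that the strategy has set up the model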
+        if compile_kwargs is not None:
+            self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs)
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             call._call_callback_hooks(self, "on_fit_start")
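And the re-apply side, sketched under the assumption that ``_to_compiled`` simply calls ``torch.compile`` again with the recovered kwargs once the strategy has wrapped and placed the model:

```python
from typing import Any, Dict

import torch


def to_compiled_sketch(module: torch.nn.Module, compile_kwargs: Dict[str, Any]) -> torch.nn.Module:
    # Recompile the (now strategy-wrapped) module with the kwargs captured before unwrapping.
    return torch.compile(module, **compile_kwargs)
```

The point of the round-trip, as the ``fit()`` change above suggests, is that compilation then happens after the strategy has wrapped and moved the model rather than before.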