@@ -127,8 +127,7 @@ def __init__(
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
-        default_root_dir: Optional[_PATH] = None,
-        reapply_compile=False,
+        default_root_dir: Optional[_PATH] = None
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -305,8 +304,6 @@ def __init__(
         if default_root_dir is not None:
             default_root_dir = os.fspath(default_root_dir)
 
-        self._reapply_compile = reapply_compile
-
         self.barebones = barebones
         if barebones:
             # opt-outs
@@ -533,8 +530,9 @@ def fit(
         For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
 
         """
+        # when given a compiled model, unwrap it here and re-apply compile after the strategy has been set up
         model, compile_kwargs = (
-            _unwrap_compiled(model) if self._reapply_compile else (_maybe_unwrap_optimized(model), None)
+            _unwrap_compiled(model) if isinstance(model, torch._dynamo.OptimizedModule) else (_maybe_unwrap_optimized(model), None)
         )
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(model, self.strategy)
@@ -548,7 +546,7 @@ def fit(
     def _fit_impl(
         self,
         model: "pl.LightningModule",
-        compile_kwargs,
+        compile_kwargs: Optional[Dict[str, Any]] = None,
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
@@ -909,7 +907,7 @@ def _predict_impl(
         return results
 
     def _run(
-        self, model: "pl.LightningModule", compile_kwargs, ckpt_path: Optional[_PATH] = None
+        self, model: "pl.LightningModule", compile_kwargs: Optional[Dict[str, Any]] = None, ckpt_path: Optional[_PATH] = None
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         if self.state.fn == TrainerFn.FITTING:
             min_epochs, max_epochs = _parse_loop_limits(
@@ -963,6 +961,7 @@ def _run(
         # strategy will configure model and move it to the device
         self.strategy.setup(self)
 
+        # the compiled model was unwrapped in `fit`; re-apply compile now that the strategy has set up the model
        if compile_kwargs is not None:
            self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs)
 
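With this change, `fit` no longer relies on a `reapply_compile` flag: it detects a `torch._dynamo.OptimizedModule` directly, unwraps it, lets the strategy set up the raw module, and then re-applies `torch.compile` with the captured kwargs. A minimal usage sketch of the resulting behavior (the `BoringModel` and dataloader below are illustrative, not part of the diff):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl


class BoringModel(pl.LightningModule):
    # illustrative LightningModule, not taken from the diff
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# user compiles the model up front
model = torch.compile(BoringModel())
assert isinstance(model, torch._dynamo.OptimizedModule)

train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

# no `reapply_compile` argument anymore; fit() unwraps the OptimizedModule,
# applies the strategy to the raw module, then re-compiles it
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_dataloaders=train_loader)
```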