Skip to content

Commit f0f0a57

Browse files
authored
Remove reapply_compile flag
1 parent 8db1a6f commit f0f0a57

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/lightning/pytorch/trainer/trainer.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ def __init__(
127127
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
128128
sync_batchnorm: bool = False,
129129
reload_dataloaders_every_n_epochs: int = 0,
130-
default_root_dir: Optional[_PATH] = None,
131-
reapply_compile=False,
130+
default_root_dir: Optional[_PATH] = None
132131
) -> None:
133132
r"""Customize every aspect of training via flags.
134133
@@ -305,8 +304,6 @@ def __init__(
305304
if default_root_dir is not None:
306305
default_root_dir = os.fspath(default_root_dir)
307306

308-
self._reapply_compile = reapply_compile
309-
310307
self.barebones = barebones
311308
if barebones:
312309
# opt-outs
@@ -533,8 +530,9 @@ def fit(
533530
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
534531
535532
"""
533+
# when provided compiled model, unwrap and re-do after applied strategy
536534
model, compile_kwargs = (
537-
_unwrap_compiled(model) if self._reapply_compile else (_maybe_unwrap_optimized(model), None)
535+
_unwrap_compiled(model) if isinstance(model, torch._dynamo.OptimizedModule) else (_maybe_unwrap_optimized(model), None)
538536
)
539537
self.strategy._lightning_module = model
540538
_verify_strategy_supports_compile(model, self.strategy)
@@ -548,7 +546,7 @@ def fit(
548546
def _fit_impl(
549547
self,
550548
model: "pl.LightningModule",
551-
compile_kwargs,
549+
compile_kwargs: Optional[Dict[str, Any]] = None,
552550
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
553551
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
554552
datamodule: Optional[LightningDataModule] = None,
@@ -909,7 +907,7 @@ def _predict_impl(
909907
return results
910908

911909
def _run(
912-
self, model: "pl.LightningModule", compile_kwargs, ckpt_path: Optional[_PATH] = None
910+
self, model: "pl.LightningModule", compile_kwargs: Optional[Dict[str, Any]] = None, ckpt_path: Optional[_PATH] = None
913911
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
914912
if self.state.fn == TrainerFn.FITTING:
915913
min_epochs, max_epochs = _parse_loop_limits(
@@ -963,6 +961,7 @@ def _run(
963961
# strategy will configure model and move it to the device
964962
self.strategy.setup(self)
965963

964+
# when provided compiled model, unwrap is done in fit method, re-apply compile after applying strategy
966965
if compile_kwargs is not None:
967966
self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs)
968967

0 commit comments

Comments
 (0)