Skip to content

Commit 86d2c70

Browse files
authored
Test reapply_compile for trainer
1 parent 1946070 commit 86d2c70

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

src/lightning/pytorch/trainer/trainer.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import os
2626
from contextlib import contextmanager
2727
from datetime import timedelta
28-
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
28+
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
2929
from weakref import proxy
3030

3131
import torch
@@ -79,6 +79,10 @@
7979
LRSchedulerConfig,
8080
)
8181
from lightning.pytorch.utilities.warnings import PossibleUserWarning
82+
from lightning.fabric.wrappers import (
83+
_unwrap_compiled,
84+
_to_compiled
85+
)
8286

8387
log = logging.getLogger(__name__)
8488

@@ -127,7 +131,7 @@ def __init__(
127131
sync_batchnorm: bool = False,
128132
reload_dataloaders_every_n_epochs: int = 0,
129133
default_root_dir: Optional[_PATH] = None,
130-
compile_fn: Optional[Callable] = None,
134+
reapply_compile = False
131135
) -> None:
132136
r"""Customize every aspect of training via flags.
133137
@@ -290,9 +294,6 @@ def __init__(
290294
Default: ``os.getcwd()``.
291295
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
292296
293-
compile_fn: Provide torch.compile function to be applied after configuring strategy
294-
Default: ``None``.
295-
296297
Raises:
297298
TypeError:
298299
If ``gradient_clip_val`` is not an int or float.
@@ -307,6 +308,8 @@ def __init__(
307308
if default_root_dir is not None:
308309
default_root_dir = os.fspath(default_root_dir)
309310

311+
self._reapply_compile = reapply_compile
312+
310313
self.barebones = barebones
311314
if barebones:
312315
# opt-outs
@@ -472,8 +475,6 @@ def __init__(
472475
self.should_stop = False
473476
self.state = TrainerState()
474477

475-
self.compile_fn = compile_fn
476-
477478
# configure profiler
478479
setup._init_profiler(self, profiler)
479480

@@ -535,19 +536,20 @@ def fit(
535536
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
536537
537538
"""
538-
model = _maybe_unwrap_optimized(model)
539+
model, compile_kwargs = _unwrap_compiled(model) if self._reapply_compile else (_maybe_unwrap_optimized(model), None)
539540
self.strategy._lightning_module = model
540541
_verify_strategy_supports_compile(model, self.strategy)
541542
self.state.fn = TrainerFn.FITTING
542543
self.state.status = TrainerStatus.RUNNING
543544
self.training = True
544545
call._call_and_handle_interrupt(
545-
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546+
self, self._fit_impl, model, compile_kwargs, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546547
)
547548

548549
def _fit_impl(
549550
self,
550551
model: "pl.LightningModule",
552+
compile_kwargs,
551553
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
552554
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
553555
datamodule: Optional[LightningDataModule] = None,
@@ -577,7 +579,7 @@ def _fit_impl(
577579
model_provided=True,
578580
model_connected=self.lightning_module is not None,
579581
)
580-
self._run(model, ckpt_path=ckpt_path)
582+
self._run(model, compile_kwargs, ckpt_path=ckpt_path)
581583

582584
assert self.state.stopped
583585
self.training = False
@@ -908,7 +910,7 @@ def _predict_impl(
908910
return results
909911

910912
def _run(
911-
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
913+
self, model: "pl.LightningModule", compile_kwargs, ckpt_path: Optional[_PATH] = None
912914
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
913915
if self.state.fn == TrainerFn.FITTING:
914916
min_epochs, max_epochs = _parse_loop_limits(
@@ -962,10 +964,9 @@ def _run(
962964
# strategy will configure model and move it to the device
963965
self.strategy.setup(self)
964966

965-
# compile if compile_fn provided after configured strategy
966-
if self.compile_fn is not None:
967-
self.strategy.model = self.compile_fn(self.strategy.model)
968-
967+
if compile_kwargs is not None:
968+
self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs)
969+
969970
# hook
970971
if self.state.fn == TrainerFn.FITTING:
971972
call._call_callback_hooks(self, "on_fit_start")

0 commit comments

Comments
 (0)