Skip to content

Commit 1bc2ce7

Browse files
committed
Add compile_fn for Trainer
1 parent 06a8d5b commit 1bc2ce7

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

src/lightning/pytorch/trainer/trainer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import os
2626
from contextlib import contextmanager
2727
from datetime import timedelta
28-
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
28+
from typing import Any, Dict, Generator, Iterable, List, Optional, Union, Callable
2929
from weakref import proxy
3030

3131
import torch
@@ -127,6 +127,7 @@ def __init__(
127127
sync_batchnorm: bool = False,
128128
reload_dataloaders_every_n_epochs: int = 0,
129129
default_root_dir: Optional[_PATH] = None,
130+
compile_fn: Optional[Callable] = None
130131
) -> None:
131132
r"""Customize every aspect of training via flags.
132133
@@ -468,6 +469,8 @@ def __init__(
468469
self.should_stop = False
469470
self.state = TrainerState()
470471

472+
self.compile_fn = compile_fn
473+
471474
# configure profiler
472475
setup._init_profiler(self, profiler)
473476

@@ -956,6 +959,10 @@ def _run(
956959
# strategy will configure model and move it to the device
957960
self.strategy.setup(self)
958961

962+
# compile if compile_fn provided after configured strategy
963+
if self.compile_fn is not None:
964+
self.strategy.model = self.compile_fn(self.strategy.model)
965+
959966
# hook
960967
if self.state.fn == TrainerFn.FITTING:
961968
call._call_callback_hooks(self, "on_fit_start")

0 commit comments

Comments
 (0)