|
25 | 25 | import os
|
26 | 26 | from contextlib import contextmanager
|
27 | 27 | from datetime import timedelta
|
28 |
| -from typing import Any, Dict, Generator, Iterable, List, Optional, Union |
| 28 | +from typing import Any, Dict, Generator, Iterable, List, Optional, Union, Callable |
29 | 29 | from weakref import proxy
|
30 | 30 |
|
31 | 31 | import torch
|
@@ -127,6 +127,7 @@ def __init__(
|
127 | 127 | sync_batchnorm: bool = False,
|
128 | 128 | reload_dataloaders_every_n_epochs: int = 0,
|
129 | 129 | default_root_dir: Optional[_PATH] = None,
|
| 130 | + compile_fn: Optional[Callable] = None |
130 | 131 | ) -> None:
|
131 | 132 | r"""Customize every aspect of training via flags.
|
132 | 133 |
|
@@ -468,6 +469,8 @@ def __init__(
|
468 | 469 | self.should_stop = False
|
469 | 470 | self.state = TrainerState()
|
470 | 471 |
|
| 472 | + self.compile_fn = compile_fn |
| 473 | + |
471 | 474 | # configure profiler
|
472 | 475 | setup._init_profiler(self, profiler)
|
473 | 476 |
|
@@ -956,6 +959,10 @@ def _run(
|
956 | 959 | # strategy will configure model and move it to the device
|
957 | 960 | self.strategy.setup(self)
|
958 | 961 |
|
| 962 | +    # compile the model via compile_fn (if provided), now that the strategy is configured |
| 963 | + if self.compile_fn is not None: |
| 964 | + self.strategy.model = self.compile_fn(self.strategy.model) |
| 965 | + |
959 | 966 | # hook
|
960 | 967 | if self.state.fn == TrainerFn.FITTING:
|
961 | 968 | call._call_callback_hooks(self, "on_fit_start")
|
|
0 commit comments