```diff
@@ -25,7 +25,7 @@
 import os
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, Iterable, List, Optional, Union, Callable
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 from weakref import proxy
 
 import torch
@@ -127,7 +127,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
-        compile_fn: Optional[Callable] = None
+        compile_fn: Optional[Callable] = None,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -470,7 +470,7 @@ def __init__(
         self.state = TrainerState()
 
         self.compile_fn = compile_fn
-
+
         # configure profiler
         setup._init_profiler(self, profiler)
 
@@ -962,7 +962,7 @@ def _run(
         # compile if compile_fn provided after configured strategy
         if self.compile_fn is not None:
             self.strategy.model = self.compile_fn(self.strategy.model)
-
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             call._call_callback_hooks(self, "on_fit_start")
```
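
For context, here is a minimal usage sketch of the new `compile_fn` parameter. It assumes the `lightning.pytorch` import path and a hypothetical `LitModel` LightningModule (neither appears in the diff); per the change in `_run`, the Trainer calls `compile_fn` on `strategy.model` after the strategy has been configured:

```python
import torch
from lightning.pytorch import Trainer

from my_project import LitModel  # hypothetical LightningModule

model = LitModel()

# compile_fn receives the strategy's model and must return the
# (possibly compiled) module; torch.compile matches this signature.
trainer = Trainer(max_epochs=1, compile_fn=torch.compile)
trainer.fit(model)
```

Because the call happens only after the strategy has configured its model, the compiled module is the one the training loop actually runs.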