|
25 | 25 | import os
|
26 | 26 | from contextlib import contextmanager
|
27 | 27 | from datetime import timedelta
|
28 |
| -from typing import Any, Dict, Generator, Iterable, List, Optional, Union, Callable |
| 28 | +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union |
29 | 29 | from weakref import proxy
|
30 | 30 |
|
31 | 31 | import torch
|
@@ -127,7 +127,7 @@ def __init__(
|
127 | 127 | sync_batchnorm: bool = False,
|
128 | 128 | reload_dataloaders_every_n_epochs: int = 0,
|
129 | 129 | default_root_dir: Optional[_PATH] = None,
|
130 |
| - compile_fn: Optional[Callable] = None |
| 130 | + compile_fn: Optional[Callable] = None, |
131 | 131 | ) -> None:
|
132 | 132 | r"""Customize every aspect of training via flags.
|
133 | 133 |
|
@@ -470,7 +470,7 @@ def __init__(
|
470 | 470 | self.state = TrainerState()
|
471 | 471 |
|
472 | 472 | self.compile_fn = compile_fn
|
473 |
| - |
| 473 | + |
474 | 474 | # configure profiler
|
475 | 475 | setup._init_profiler(self, profiler)
|
476 | 476 |
|
@@ -962,7 +962,7 @@ def _run(
|
962 | 962 | # compile if compile_fn provided after configured strategy
|
963 | 963 | if self.compile_fn is not None:
|
964 | 964 | self.strategy.model = self.compile_fn(self.strategy.model)
|
965 |
| - |
| 965 | + |
966 | 966 | # hook
|
967 | 967 | if self.state.fn == TrainerFn.FITTING:
|
968 | 968 | call._call_callback_hooks(self, "on_fit_start")
|
|
0 commit comments