@@ -14,13 +14,12 @@
 import inspect
 import os
 from collections.abc import Generator, Mapping, Sequence
-from contextlib import contextmanager, nullcontext
+from contextlib import AbstractContextManager, contextmanager, nullcontext
 from functools import partial
 from pathlib import Path
 from typing import (
     Any,
     Callable,
-    ContextManager,
     Optional,
     Union,
     cast,
@@ -484,7 +483,7 @@ def clip_gradients(
             )
         raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!")

-    def autocast(self) -> ContextManager:
+    def autocast(self) -> AbstractContextManager:
         """A context manager to automatically convert operations for the chosen precision.

         Use this only if the `forward` method of your model does not cover all operations you wish to run with the
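A minimal usage sketch for the method annotated above, assuming `fabric`, `model`, and `batch` come from a typical Fabric setup; `my_loss_fn` is a placeholder, not a Fabric API:

    # Run extra operations (not covered by the model's forward) under the chosen precision.
    with fabric.autocast():
        prediction = model(batch)
        loss = my_loss_fn(prediction)  # illustrative loss function, not part of Fabric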
@@ -634,7 +633,7 @@ def rank_zero_first(self, local: bool = False) -> Generator:
             if rank == 0:
                 barrier()

-    def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
+    def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> AbstractContextManager:
         r"""Skip gradient synchronization during backward to avoid redundant communication overhead.

         Use this context manager when performing gradient accumulation to speed up training with multiple devices.
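A sketch of how this context manager is typically combined with gradient accumulation, assuming `fabric`, `model`, `optimizer`, and `dataloader` were prepared via `fabric.setup(...)` and `fabric.setup_dataloaders(...)`; the accumulation factor of 4 and the `.sum()` loss are illustrative only:

    for i, batch in enumerate(dataloader):
        # Sync gradients only on every 4th step; skip the redundant all-reduce otherwise.
        is_accumulating = (i + 1) % 4 != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            loss = model(batch).sum()
            fabric.backward(loss)
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()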
@@ -676,7 +675,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
         forward_module, _ = _unwrap_compiled(module._forward_module)
         return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled)

-    def sharded_model(self) -> ContextManager:
+    def sharded_model(self) -> AbstractContextManager:
         r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.

         .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
@@ -688,12 +687,12 @@ def sharded_model(self) -> ContextManager:
             return self.strategy.module_sharded_context()
         return nullcontext()

-    def init_tensor(self) -> ContextManager:
+    def init_tensor(self) -> AbstractContextManager:
         """Tensors that you instantiate under this context manager will be created on the device right away and have
         the right data type depending on the precision setting in Fabric."""
         return self._strategy.tensor_init_context()

-    def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:
+    def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
         """Instantiate the model and its parameters under this context manager to reduce peak memory usage.

         The parameters get created on the device and with the right data type right away without wasting memory being
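For context, `typing.ContextManager` has been a deprecated alias of `contextlib.AbstractContextManager` since Python 3.9, which is presumably what motivates this rename throughout the file. A self-contained sketch of the annotation pattern; the helper names are illustrative, not part of Fabric:

    from contextlib import AbstractContextManager, contextmanager, nullcontext

    @contextmanager
    def _example_context():
        # A generator-based context manager is an instance of AbstractContextManager.
        yield

    def maybe_context(enabled: bool) -> AbstractContextManager:
        # Returning either a real context manager or nullcontext() satisfies the
        # annotation exactly as it did with typing.ContextManager, with no typing import.
        return _example_context() if enabled else nullcontext()

    with maybe_context(enabled=True):
        pass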