1414import inspect
1515import os
1616from collections .abc import Generator , Mapping , Sequence
17- from contextlib import contextmanager , nullcontext
17+ from contextlib import AbstractContextManager , contextmanager , nullcontext
1818from functools import partial
1919from pathlib import Path
2020from typing import (
2121 Any ,
2222 Callable ,
23- ContextManager ,
2423 Optional ,
2524 Union ,
2625 cast ,
@@ -484,7 +483,7 @@ def clip_gradients(
484483 )
485484 raise ValueError ("You have to specify either `clip_val` or `max_norm` to do gradient clipping!" )
486485
487- def autocast (self ) -> ContextManager :
486+ def autocast (self ) -> AbstractContextManager :
488487 """A context manager to automatically convert operations for the chosen precision.
489488
490489 Use this only if the `forward` method of your model does not cover all operations you wish to run with the
@@ -634,7 +633,7 @@ def rank_zero_first(self, local: bool = False) -> Generator:
634633 if rank == 0 :
635634 barrier ()
636635
637- def no_backward_sync (self , module : _FabricModule , enabled : bool = True ) -> ContextManager :
636+ def no_backward_sync (self , module : _FabricModule , enabled : bool = True ) -> AbstractContextManager :
638637 r"""Skip gradient synchronization during backward to avoid redundant communication overhead.
639638
640639 Use this context manager when performing gradient accumulation to speed up training with multiple devices.
@@ -676,7 +675,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
676675 forward_module , _ = _unwrap_compiled (module ._forward_module )
677676 return self ._strategy ._backward_sync_control .no_backward_sync (forward_module , enabled )
678677
679- def sharded_model (self ) -> ContextManager :
678+ def sharded_model (self ) -> AbstractContextManager :
680679 r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
681680
682681 .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
@@ -688,12 +687,12 @@ def sharded_model(self) -> ContextManager:
688687 return self .strategy .module_sharded_context ()
689688 return nullcontext ()
690689
691- def init_tensor (self ) -> ContextManager :
690+ def init_tensor (self ) -> AbstractContextManager :
692691 """Tensors that you instantiate under this context manager will be created on the device right away and have
693692 the right data type depending on the precision setting in Fabric."""
694693 return self ._strategy .tensor_init_context ()
695694
696- def init_module (self , empty_init : Optional [bool ] = None ) -> ContextManager :
695+ def init_module (self , empty_init : Optional [bool ] = None ) -> AbstractContextManager :
697696 """Instantiate the model and its parameters under this context manager to reduce peak memory usage.
698697
699698 The parameters get created on the device and with the right data type right away without wasting memory being
0 commit comments