Commit a35f1fe

contextlib.AbstractContextManager
1 parent b9920bd commit a35f1fe

22 files changed: +96 -91 lines
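
The change is mechanical: `typing.ContextManager` is a deprecated alias of `contextlib.AbstractContextManager`, so every return annotation that used the alias is switched to the canonical class. A minimal sketch of the resulting annotation style (not part of the diff, stdlib only):

from contextlib import AbstractContextManager, nullcontext

def make_context() -> AbstractContextManager:
    # Anything usable in a `with` statement satisfies this annotation,
    # e.g. nullcontext(), a torch.autocast(...) object, or an ExitStack.
    return nullcontext()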

src/lightning/fabric/fabric.py

Lines changed: 6 additions & 7 deletions

@@ -14,13 +14,12 @@
 import inspect
 import os
 from collections.abc import Generator, Mapping, Sequence
-from contextlib import contextmanager, nullcontext
+from contextlib import AbstractContextManager, contextmanager, nullcontext
 from functools import partial
 from pathlib import Path
 from typing import (
     Any,
     Callable,
-    ContextManager,
     Optional,
     Union,
     cast,
@@ -484,7 +483,7 @@ def clip_gradients(
         )
         raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!")

-    def autocast(self) -> ContextManager:
+    def autocast(self) -> AbstractContextManager:
         """A context manager to automatically convert operations for the chosen precision.

         Use this only if the `forward` method of your model does not cover all operations you wish to run with the
@@ -634,7 +633,7 @@ def rank_zero_first(self, local: bool = False) -> Generator:
         if rank == 0:
             barrier()

-    def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
+    def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> AbstractContextManager:
         r"""Skip gradient synchronization during backward to avoid redundant communication overhead.

         Use this context manager when performing gradient accumulation to speed up training with multiple devices.
@@ -676,7 +675,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
         forward_module, _ = _unwrap_compiled(module._forward_module)
         return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled)

-    def sharded_model(self) -> ContextManager:
+    def sharded_model(self) -> AbstractContextManager:
         r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.

         .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
@@ -688,12 +687,12 @@ def sharded_model(self) -> ContextManager:
             return self.strategy.module_sharded_context()
         return nullcontext()

-    def init_tensor(self) -> ContextManager:
+    def init_tensor(self) -> AbstractContextManager:
         """Tensors that you instantiate under this context manager will be created on the device right away and have
         the right data type depending on the precision setting in Fabric."""
         return self._strategy.tensor_init_context()

-    def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:
+    def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
         """Instantiate the model and its parameters under this context manager to reduce peak memory usage.

         The parameters get created on the device and with the right data type right away without wasting memory being
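
Call sites are unaffected by the annotation change; the Fabric methods are still used as plain context managers. A hedged usage sketch (accelerator and precision arguments are illustrative):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", precision="bf16-mixed")  # illustrative arguments

with fabric.init_module():
    model = torch.nn.Linear(4, 4)  # created directly with the configured device/dtype

with fabric.autocast():
    out = model(torch.randn(1, 4))  # forward ops run under the chosen precision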

src/lightning/fabric/plugins/precision/amp.py

Lines changed: 3 additions & 2 deletions

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, ContextManager, Literal, Optional
+from contextlib import AbstractContextManager
+from typing import Any, Literal, Optional

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -59,7 +60,7 @@ def __init__(
         self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16

     @override
-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         return torch.autocast(self.device, dtype=self._desired_input_dtype)

     @override
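
torch.autocast objects are themselves context managers, so returning one directly satisfies the new annotation. A small standalone sketch (CPU bf16 is chosen only so it runs without a GPU):

import torch

ctx = torch.autocast("cpu", dtype=torch.bfloat16)
with ctx:
    out = torch.randn(4, 4) @ torch.randn(4, 4)  # matmul runs under bf16 autocast
print(out.dtype)  # expected torch.bfloat16 under CPU autocast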

src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 5 additions & 5 deletions

@@ -17,10 +17,10 @@
 import os
 import warnings
 from collections import OrderedDict
-from contextlib import ExitStack
+from contextlib import AbstractContextManager, ExitStack
 from functools import partial
 from types import ModuleType
-from typing import Any, Callable, ContextManager, Literal, Optional, cast
+from typing import Any, Callable, Literal, Optional, cast

 import torch
 from lightning_utilities import apply_to_collection
@@ -123,11 +123,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         return module

     @override
-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(self.dtype)

     @override
-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         if self.ignore_modules:
             # cannot patch the Linear class if the user wants to skip some submodules
             raise RuntimeError(
@@ -145,7 +145,7 @@ def module_init_context(self) -> ContextManager:
         return stack

     @override
-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         return _DtypeContextManager(self.dtype)

     @override
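
For reference, the _DtypeContextManager returned here is Fabric's small helper that (roughly) swaps torch's default dtype for the duration of the block. A simplified stand-in, not the actual implementation, showing why such an object satisfies the AbstractContextManager annotation:

from contextlib import AbstractContextManager

import torch

class _SimpleDtypeContext(AbstractContextManager):  # illustrative stand-in only
    def __init__(self, dtype: torch.dtype) -> None:
        self._new_dtype = dtype
        self._previous_dtype = torch.get_default_dtype()

    def __enter__(self) -> None:
        self._previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self._new_dtype)  # tensors created inside use this dtype

    def __exit__(self, *exc_info: object) -> None:
        torch.set_default_dtype(self._previous_dtype)  # restore the previous default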

src/lightning/fabric/plugins/precision/deepspeed.py

Lines changed: 4 additions & 4 deletions

@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import nullcontext
-from typing import TYPE_CHECKING, Any, ContextManager, Literal
+from contextlib import AbstractContextManager, nullcontext
+from typing import TYPE_CHECKING, Any, Literal

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -68,13 +68,13 @@ def convert_module(self, module: Module) -> Module:
         return module

     @override
-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         if "true" not in self.precision:
             return nullcontext()
         return _DtypeContextManager(self._desired_dtype)

     @override
-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         return self.tensor_init_context()

     @override

src/lightning/fabric/plugins/precision/double.py

Lines changed: 5 additions & 4 deletions

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, ContextManager, Literal
+from contextlib import AbstractContextManager
+from typing import Any, Literal

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -33,15 +34,15 @@ def convert_module(self, module: Module) -> Module:
         return module.double()

     @override
-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(torch.double)

     @override
-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         return self.tensor_init_context()

     @override
-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         return self.tensor_init_context()

     @override

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 5 additions & 4 deletions

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional
+from contextlib import AbstractContextManager
+from typing import TYPE_CHECKING, Any, Literal, Optional

 import torch
 from lightning_utilities import apply_to_collection
@@ -100,15 +101,15 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
         )

     @override
-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(self._desired_input_dtype)

     @override
-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)

     @override
-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         if "mixed" in self.precision:
             return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
         return self.tensor_init_context()

src/lightning/fabric/plugins/precision/half.py

Lines changed: 5 additions & 4 deletions

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, ContextManager, Literal
+from contextlib import AbstractContextManager
+from typing import Any, Literal

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -42,15 +43,15 @@ def convert_module(self, module: Module) -> Module:
         return module.to(dtype=self._desired_input_dtype)

     @override
-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(self._desired_input_dtype)

     @override
-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         return self.tensor_init_context()

     @override
-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         return self.tensor_init_context()

     @override

src/lightning/fabric/plugins/precision/precision.py

Lines changed: 5 additions & 5 deletions

@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import nullcontext
-from typing import Any, ContextManager, Literal, Optional, Union
+from contextlib import AbstractContextManager, nullcontext
+from typing import Any, Literal, Optional, Union

 from torch import Tensor
 from torch.nn import Module
@@ -53,19 +53,19 @@ def convert_module(self, module: Module) -> Module:
         """
         return module

-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         """Controls how tensors get created (device, dtype)."""
         return nullcontext()

-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         """Instantiate module parameters or tensors in the precision type this plugin handles.

         This is optional and depends on the precision limitations during optimization.

         """
         return nullcontext()

-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
         return nullcontext()
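
Downstream code that subclasses Precision and overrides these hooks would adopt the same annotation. A hypothetical plugin sketch (MyPrecision is not part of this commit; the import path is assumed from the file location shown above):

from contextlib import AbstractContextManager, nullcontext

from lightning.fabric.plugins.precision import Precision

class MyPrecision(Precision):  # hypothetical example plugin
    def forward_context(self) -> AbstractContextManager:
        # no-op context; a real plugin would typically return torch.autocast(...)
        return nullcontext()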

src/lightning/fabric/plugins/precision/transformer_engine.py

Lines changed: 5 additions & 5 deletions

@@ -13,8 +13,8 @@
 # limitations under the License.
 import logging
 from collections.abc import Mapping
-from contextlib import ExitStack
-from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional, Union
+from contextlib import AbstractContextManager, ExitStack
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union

 import torch
 from lightning_utilities import apply_to_collection
@@ -107,11 +107,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         return module

     @override
-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(self.weights_dtype)

     @override
-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         dtype_ctx = self.tensor_init_context()
         stack = ExitStack()
         if self.replace_layers:
@@ -126,7 +126,7 @@ def module_init_context(self) -> ContextManager:
         return stack

     @override
-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         dtype_ctx = _DtypeContextManager(self.weights_dtype)
         fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
         import transformer_engine.pytorch as te
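
An ExitStack is itself a context manager, which is why module_init_context can compose several contexts and still match the AbstractContextManager annotation. A minimal stdlib-only sketch of the pattern with stand-in contexts:

from contextlib import AbstractContextManager, ExitStack, nullcontext

def combined_init_context() -> AbstractContextManager:  # illustrative helper
    stack = ExitStack()
    stack.enter_context(nullcontext())  # stand-in for the dtype context
    stack.enter_context(nullcontext())  # stand-in for the layer-replacement context
    return stack

with combined_init_context():
    ...  # both stacked contexts are active here and are unwound on exit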

src/lightning/fabric/strategies/ddp.py

Lines changed: 3 additions & 3 deletions

@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import nullcontext
+from contextlib import AbstractContextManager, nullcontext
 from datetime import timedelta
-from typing import Any, ContextManager, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union

 import torch
 import torch.distributed
@@ -231,7 +231,7 @@ def _determine_ddp_device_ids(self) -> Optional[list[int]]:

 class _DDPBackwardSyncControl(_BackwardSyncControl):
     @override
-    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
         """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
         wrapper."""
         if not enabled:
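
The control shown above follows a common pattern: hand back DDP's no_sync() context when synchronization should be skipped, otherwise a no-op. A minimal sketch of that pattern (not the exact Lightning implementation, whose remaining lines are outside this hunk):

from contextlib import AbstractContextManager, nullcontext

from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel

def no_backward_sync(module: Module, enabled: bool) -> AbstractContextManager:
    if enabled and isinstance(module, DistributedDataParallel):
        return module.no_sync()  # DDP skips the gradient all-reduce inside this context
    return nullcontext()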
