
Commit 3b1643c

Add @override for files in src/lightning/fabric/plugins/precision (#19158)
1 parent c985400 commit 3b1643c

File tree: 8 files changed (+59, -3 lines)

src/lightning/fabric/plugins/precision/amp.py
src/lightning/fabric/plugins/precision/bitsandbytes.py
src/lightning/fabric/plugins/precision/deepspeed.py
src/lightning/fabric/plugins/precision/double.py
src/lightning/fabric/plugins/precision/fsdp.py
src/lightning/fabric/plugins/precision/half.py
src/lightning/fabric/plugins/precision/transformer_engine.py
src/lightning/fabric/plugins/precision/xla.py

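For reviewers unfamiliar with the decorator: typing_extensions.override implements PEP 698. Marking a method with @override tells a static type checker (recent mypy or pyright) that the method is intended to override something in a base class, so a later rename or removal of the base-class hook is reported instead of silently leaving a dead override behind. The sketch below is a minimal illustration only; BasePlugin and MyPlugin are made-up names, not classes touched by this commit.

# Minimal sketch (hypothetical classes, not part of this commit).
from contextlib import nullcontext
from typing import ContextManager

from typing_extensions import override


class BasePlugin:
    def forward_context(self) -> ContextManager:
        return nullcontext()


class MyPlugin(BasePlugin):
    @override
    def forward_context(self) -> ContextManager:  # OK: overrides BasePlugin.forward_context
        return nullcontext()

    @override
    def forward_contex(self) -> ContextManager:  # flagged by the checker: overrides nothing in BasePlugin
        return nullcontext()

At runtime the decorator leaves the methods untouched; see the note after the amp.py diff below.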

src/lightning/fabric/plugins/precision/amp.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
+from typing_extensions import override

 from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
 from lightning.fabric.plugins.precision.precision import Precision
@@ -59,20 +60,25 @@ def __init__(

         self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16

+    @override
     def forward_context(self) -> ContextManager:
         return torch.autocast(self.device, dtype=self._desired_input_dtype)

+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

+    @override
     def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
         super().backward(tensor, model, *args, **kwargs)

+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -88,15 +94,18 @@ def optimizer_step(
         self.scaler.update()
         return step_output

+    @override
     def state_dict(self) -> Dict[str, Any]:
         if self.scaler is not None:
             return self.scaler.state_dict()
         return {}

+    @override
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self.scaler is not None:
             self.scaler.load_state_dict(state_dict)

+    @override
     def unscale_gradients(self, optimizer: Optimizer) -> None:
         scaler = self.scaler
         if scaler is not None:
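
A note on runtime behaviour, since these hunks decorate every public hook of the plugin: to the best of my knowledge, typing_extensions.override returns the function unchanged and only tries to tag it with an __override__ = True attribute, so the decorated methods execute exactly as before and the benefit is confined to static checking. A small sketch with made-up classes:

# Hypothetical example: the decorator does not alter behaviour at runtime.
from typing import Any, Dict

from typing_extensions import override


class Base:
    def state_dict(self) -> Dict[str, Any]:
        return {}


class Child(Base):
    @override
    def state_dict(self) -> Dict[str, Any]:
        return {"scale": 1.0}


print(Child().state_dict())                              # {'scale': 1.0} -- same call, same result
print(getattr(Child.state_dict, "__override__", False))  # True if the marker could be set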

src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 7 additions & 0 deletions
@@ -25,6 +25,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.nn.modules.module import _IncompatibleKeys
+from typing_extensions import override

 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import (
@@ -96,6 +97,7 @@ def __init__(
         self.dtype = dtype
         self.ignore_modules = ignore_modules or set()

+    @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # avoid naive users thinking they quantized their model
         if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
@@ -116,9 +118,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
                 m.compute_type_is_set = False
         return module

+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.dtype)

+    @override
     def module_init_context(self) -> ContextManager:
         if self.ignore_modules:
             # cannot patch the Linear class if the user wants to skip some submodules
@@ -136,12 +140,15 @@ def module_init_context(self) -> ContextManager:
         stack.enter_context(context_manager)
         return stack

+    @override
     def forward_context(self) -> ContextManager:
         return _DtypeContextManager(self.dtype)

+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)

+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

src/lightning/fabric/plugins/precision/deepspeed.py

Lines changed: 8 additions & 1 deletion
@@ -18,7 +18,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
-from typing_extensions import get_args
+from typing_extensions import get_args, override

 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -61,29 +61,36 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
         }
         self._desired_dtype = precision_to_type[self.precision]

+    @override
     def convert_module(self, module: Module) -> Module:
         if "true" in self.precision:
             return module.to(dtype=self._desired_dtype)
         return module

+    @override
     def tensor_init_context(self) -> ContextManager:
         if "true" not in self.precision:
             return nullcontext()
         return _DtypeContextManager(self._desired_dtype)

+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()

+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)

+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

+    @override
     def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
         """Performs back-propagation using DeepSpeed's engine."""
         model.backward(tensor, *args, **kwargs)

+    @override
     def optimizer_step(
         self,
         optimizer: Steppable,

src/lightning/fabric/plugins/precision/double.py

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override

 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -27,20 +28,26 @@ class DoublePrecision(Precision):

     precision: Literal["64-true"] = "64-true"

+    @override
     def convert_module(self, module: Module) -> Module:
         return module.double()

+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(torch.double)

+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()

+    @override
     def forward_context(self) -> ContextManager:
         return self.tensor_init_context()

+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)

+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 11 additions & 1 deletion
@@ -18,7 +18,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from typing_extensions import get_args
+from typing_extensions import get_args, override

 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
@@ -103,28 +103,35 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
             buffer_dtype=buffer_dtype,
         )

+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self._desired_input_dtype)

+    @override
     def module_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)

+    @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
             return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
         return self.tensor_init_context()

+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

+    @override
     def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = cast(Tensor, self.scaler.scale(tensor))
         super().backward(tensor, model, *args, **kwargs)

+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -138,18 +145,21 @@ def optimizer_step(
         self.scaler.update()
         return step_output

+    @override
     def unscale_gradients(self, optimizer: Optimizer) -> None:
         scaler = self.scaler
         if scaler is not None:
             if _optimizer_handles_unscaling(optimizer):
                 raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
             scaler.unscale_(optimizer)

+    @override
     def state_dict(self) -> Dict[str, Any]:
         if self.scaler is not None:
             return self.scaler.state_dict()
         return {}

+    @override
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self.scaler is not None:
             self.scaler.load_state_dict(state_dict)

src/lightning/fabric/plugins/precision/half.py

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override

 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -36,20 +37,26 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> No
         self.precision = precision
         self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16

+    @override
     def convert_module(self, module: Module) -> Module:
         return module.to(dtype=self._desired_input_dtype)

+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self._desired_input_dtype)

+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()

+    @override
     def forward_context(self) -> ContextManager:
         return self.tensor_init_context()

+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

src/lightning/fabric/plugins/precision/transformer_engine.py

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,7 @@
 from lightning_utilities import apply_to_collection
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
+from typing_extensions import override

 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import (
@@ -89,6 +90,7 @@ def __init__(
         self.replace_layers = replace_layers
         self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype

+    @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # avoid converting if any is found. assume the user took care of it
         if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
@@ -103,9 +105,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         module = module.to(dtype=self.weights_dtype)
         return module

+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.weights_dtype)

+    @override
     def module_init_context(self) -> ContextManager:
         dtype_ctx = self.tensor_init_context()
         stack = ExitStack()
@@ -122,6 +126,7 @@ def module_init_context(self) -> ContextManager:
         stack.enter_context(dtype_ctx)
         return stack

+    @override
     def forward_context(self) -> ContextManager:
         dtype_ctx = _DtypeContextManager(self.weights_dtype)
         fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
@@ -135,9 +140,11 @@ def forward_context(self) -> ContextManager:
         stack.enter_context(autocast_ctx)
         return stack

+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)

+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

src/lightning/fabric/plugins/precision/xla.py

Lines changed: 3 additions & 1 deletion
@@ -15,7 +15,7 @@
 from typing import Any, Literal

 import torch
-from typing_extensions import get_args
+from typing_extensions import get_args, override

 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins.precision.precision import Precision
@@ -56,6 +56,7 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
         else:
             self._desired_dtype = torch.float32

+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -66,6 +67,7 @@
         # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
         return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)

+    @override
     def teardown(self) -> None:
         os.environ.pop("XLA_USE_BF16", None)
         os.environ.pop("XLA_USE_F16", None)
