
Commit 89f4d18

awaelchli authored and lantiga committed
Rename PrecisionPlugin -> Precision (#18840)
1 parent 37fec5e commit 89f4d18


54 files changed: +437, -322 lines
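
For downstream code, the change is a pure rename: every `*PrecisionPlugin` class keeps its behavior under the shorter `*Precision` name, and the old names stay importable through the new `lightning.pytorch._graveyard.precision` shim shown further down. A minimal before/after migration sketch (the plugin arguments are illustrative, not mandated by this commit):

    # before this commit
    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins import MixedPrecisionPlugin

    trainer = Trainer(plugins=[MixedPrecisionPlugin("16-mixed", device="cuda")])

    # after this commit
    from lightning.pytorch.plugins import MixedPrecision

    trainer = Trainer(plugins=[MixedPrecision("16-mixed", device="cuda")])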

docs/source-pytorch/api_references.rst

Lines changed: 9 additions & 9 deletions
@@ -112,15 +112,15 @@ precision
     :nosignatures:
     :template: classtemplate.rst

-    DeepSpeedPrecisionPlugin
-    DoublePrecisionPlugin
-    HalfPrecisionPlugin
-    FSDPPrecisionPlugin
-    MixedPrecisionPlugin
-    PrecisionPlugin
-    XLAPrecisionPlugin
-    TransformerEnginePrecisionPlugin
-    BitsandbytesPrecisionPlugin
+    DeepSpeedPrecision
+    DoublePrecision
+    HalfPrecision
+    FSDPPrecision
+    MixedPrecision
+    Precision
+    XLAPrecision
+    TransformerEnginePrecision
+    BitsandbytesPrecision

 environments
 """"""""""""

docs/source-pytorch/common/precision_expert.rst

Lines changed: 3 additions & 3 deletions
@@ -12,17 +12,17 @@ N-Bit Precision (Expert)
 Precision Plugins
 *****************

-You can also customize and pass your own Precision Plugin by subclassing the :class:`~lightning.pytorch.plugins.precision.precision_plugin.PrecisionPlugin` class.
+You can also customize and pass your own Precision Plugin by subclassing the :class:`~lightning.pytorch.plugins.precision.precision.Precision` class.

 - Perform pre and post backward/optimizer step operations such as scaling gradients.
 - Provide context managers for forward, training_step, etc.

 .. code-block:: python

-    class CustomPrecisionPlugin(PrecisionPlugin):
+    class CustomPrecision(Precision):
         precision = "16-mixed"

         ...


-    trainer = Trainer(plugins=[CustomPrecisionPlugin()])
+    trainer = Trainer(plugins=[CustomPrecision()])
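
The two bullets in this doc describe the hooks a custom plugin typically overrides. A slightly fuller sketch than the docs snippet, assuming the base class's `forward_context` context manager and `pre_backward` hook; the float16 autocast and the fixed scaling factor are illustrative choices, not part of this commit:

    from contextlib import contextmanager

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins.precision import Precision


    class CustomPrecision(Precision):
        precision = "16-mixed"

        @contextmanager
        def forward_context(self):
            # run forward/loss computation under autocast
            with torch.autocast("cuda", dtype=torch.float16):
                yield

        def pre_backward(self, tensor, module):
            # scale the loss before backward (hypothetical fixed factor)
            return super().pre_backward(tensor * 1024.0, module)


    trainer = Trainer(plugins=[CustomPrecision()])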

docs/source-pytorch/common/precision_intermediate.rst

Lines changed: 4 additions & 4 deletions

@@ -183,18 +183,18 @@ If your model weights can fit on a single device with 16 bit precision, it's rec

 Quantizing the model will dramatically reduce the weight's memory requirements but may have a negative impact on the model's performance or runtime.

-The :class:`~lightning.pytorch.plugins.precision.bitsandbytes.BitsandbytesPrecisionPlugin` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives.
+The :class:`~lightning.pytorch.plugins.precision.bitsandbytes.BitsandbytesPrecision` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives.

 .. code-block:: python

-    from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin
+    from lightning.pytorch.plugins import BitsandbytesPrecision

     # this will pick out the compute dtype automatically, by default `bfloat16`
-    precision = BitsandbytesPrecisionPlugin(mode="nf4-dq")
+    precision = BitsandbytesPrecision(mode="nf4-dq")
     trainer = Trainer(plugins=precision)

     # Customize the dtype, or skip some modules
-    precision = BitsandbytesPrecisionPlugin(mode="int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
+    precision = BitsandbytesPrecision(mode="int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
     trainer = Trainer(plugins=precision)
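
For reference, the `mode` argument selects the bitsandbytes quantization scheme. As a sketch, the accepted strings at the time of this commit are believed to be the following (the class docstring is authoritative):

    BitsandbytesPrecision(mode="nf4")     # 4-bit NormalFloat
    BitsandbytesPrecision(mode="nf4-dq")  # NF4 with double quantization of the constants
    BitsandbytesPrecision(mode="fp4")     # 4-bit float
    BitsandbytesPrecision(mode="fp4-dq")  # FP4 with double quantization
    BitsandbytesPrecision(mode="int8")    # 8-bit inference
    BitsandbytesPrecision(mode="int8-training")  # 8-bit with training support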

docs/source-pytorch/extensions/plugins.rst

Lines changed: 9 additions & 9 deletions
@@ -52,15 +52,15 @@ The full list of built-in precision plugins is listed below.
     :nosignatures:
     :template: classtemplate.rst

-    DeepSpeedPrecisionPlugin
-    DoublePrecisionPlugin
-    HalfPrecisionPlugin
-    FSDPPrecisionPlugin
-    MixedPrecisionPlugin
-    PrecisionPlugin
-    XLAPrecisionPlugin
-    TransformerEnginePrecisionPlugin
-    BitsandbytesPrecisionPlugin
+    DeepSpeedPrecision
+    DoublePrecision
+    HalfPrecision
+    FSDPPrecision
+    MixedPrecision
+    Precision
+    XLAPrecision
+    TransformerEnginePrecision
+    BitsandbytesPrecision

 More information regarding precision with Lightning can be found :ref:`here <precision>`

src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ class BitsandbytesPrecision(Precision):

     # TODO: we could implement optimizer replacement with
     # - Fabric: Add `Precision.convert_optimizer` from `Strategy.setup_optimizer`
-    # - Trainer: Use `PrecisionPlugin.connect`
+    # - Trainer: Use `Precision.connect`

     def __init__(
         self,

src/lightning/pytorch/_graveyard/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,4 +14,5 @@
 import lightning.pytorch._graveyard._torchmetrics
 import lightning.pytorch._graveyard.hpu
 import lightning.pytorch._graveyard.ipu
+import lightning.pytorch._graveyard.precision
 import lightning.pytorch._graveyard.tpu  # noqa: F401
src/lightning/pytorch/_graveyard/precision.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+import sys
+from typing import TYPE_CHECKING, Literal, Optional
+
+import lightning.pytorch as pl
+from lightning.fabric.utilities.rank_zero import rank_zero_deprecation
+from lightning.pytorch.plugins.precision import (
+    BitsandbytesPrecision,
+    DeepSpeedPrecision,
+    DoublePrecision,
+    FSDPPrecision,
+    HalfPrecision,
+    MixedPrecision,
+    Precision,
+    TransformerEnginePrecision,
+    XLAPrecision,
+)
+
+if TYPE_CHECKING:
+    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+
+def _patch_sys_modules() -> None:
+    sys.modules["lightning.pytorch.plugins.precision.precision_plugin"] = sys.modules[
+        "lightning.pytorch.plugins.precision.precision"
+    ]
+
+
+class FSDPMixedPrecisionPlugin(FSDPPrecision):
+    """AMP for Fully Sharded Data Parallel (FSDP) Training.
+
+    .. deprecated:: Use :class:`FSDPPrecision` instead.
+
+    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
+
+    """
+
+    def __init__(
+        self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None
+    ) -> None:
+        rank_zero_deprecation(
+            f"The `{type(self).__name__}` is deprecated."
+            " Use `lightning.pytorch.plugins.precision.FSDPPrecision` instead."
+        )
+        super().__init__(precision=precision, scaler=scaler)
+
+
+def _patch_classes() -> None:
+    classes_map = (
+        # module name, old name, new class
+        ("bitsandbytes", "BitsandbytesPrecisionPlugin", BitsandbytesPrecision),
+        ("deepspeed", "DeepSpeedPrecisionPlugin", DeepSpeedPrecision),
+        ("double", "DoublePrecisionPlugin", DoublePrecision),
+        ("fsdp", "FSDPPrecisionPlugin", FSDPPrecision),
+        ("fsdp", "FSDPMixedPrecisionPlugin", FSDPPrecision),
+        ("half", "HalfPrecisionPlugin", HalfPrecision),
+        ("amp", "MixedPrecisionPlugin", MixedPrecision),
+        ("precision", "PrecisionPlugin", Precision),
+        ("transformer_engine", "TransformerEnginePrecisionPlugin", TransformerEnginePrecision),
+        ("xla", "XLAPrecisionPlugin", XLAPrecision),
+    )
+
+    for module_name, deprecated_name, new_class in classes_map:
+        setattr(getattr(pl.plugins.precision, module_name), deprecated_name, new_class)
+        setattr(pl.plugins.precision, deprecated_name, new_class)
+        setattr(pl.plugins, deprecated_name, new_class)
+
+    # special treatment for `FSDPMixedPrecisionPlugin` because it has a different signature
+    setattr(pl.plugins.precision.fsdp, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)
+    setattr(pl.plugins.precision, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)
+    setattr(pl.plugins, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)
+
+
+_patch_sys_modules()
+_patch_classes()
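
The net effect of this shim, sketched below: the removed module path resolves through the `sys.modules` alias, and every old class name becomes a plain alias of its new class, so existing imports keep working unchanged. Only `FSDPMixedPrecisionPlugin` is a real subclass that warns on instantiation, because its `device` argument has no counterpart in `FSDPPrecision`; the other aliases do not themselves emit a deprecation warning here.

    import lightning.pytorch as pl  # importing the package runs the _graveyard patches

    # the removed module path still imports, via the sys.modules alias
    from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
    from lightning.pytorch.plugins.precision import Precision

    assert PrecisionPlugin is Precision
    assert pl.plugins.MixedPrecisionPlugin is pl.plugins.MixedPrecision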

src/lightning/pytorch/_graveyard/tpu.py

Lines changed: 10 additions & 10 deletions
@@ -18,7 +18,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.pytorch.accelerators.xla import XLAAccelerator
-from lightning.pytorch.plugins.precision import XLAPrecisionPlugin
+from lightning.pytorch.plugins.precision import XLAPrecision
 from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy
 from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation

@@ -63,47 +63,47 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)


-class TPUPrecisionPlugin(XLAPrecisionPlugin):
+class TPUPrecisionPlugin(XLAPrecision):
     """Legacy class.

-    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead.
+    Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision` instead.

     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         rank_zero_deprecation(
-            "The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecisionPlugin`"
+            "The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecision`"
             " instead."
         )
         super().__init__(precision="32-true")


-class TPUBf16PrecisionPlugin(XLAPrecisionPlugin):
+class TPUBf16PrecisionPlugin(XLAPrecision):
     """Legacy class.

-    Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecisionPlugin` instead.
+    Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecision` instead.

     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         rank_zero_deprecation(
             "The `TPUBf16PrecisionPlugin` class is deprecated. Use"
-            " `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
+            " `lightning.pytorch.plugins.precision.XLAPrecision` instead."
         )
         super().__init__(precision="bf16-true")


-class XLABf16PrecisionPlugin(XLAPrecisionPlugin):
+class XLABf16PrecisionPlugin(XLAPrecision):
     """Legacy class.

-    Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecisionPlugin` instead.
+    Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecision` instead.

     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         rank_zero_deprecation(
             "The `XLABf16PrecisionPlugin` class is deprecated. Use"
-            " `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
+            " `lightning.pytorch.plugins.precision.XLAPrecision` instead."
         )
         super().__init__(precision="bf16-true")
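
As the updated legacy classes show, the old TPU/XLA plugin names remain constructible: each one warns on instantiation and pins its precision value. A small usage sketch (assuming the plugin exposes its `precision` string, as the XLA plugin does):

    from lightning.pytorch._graveyard.tpu import TPUBf16PrecisionPlugin

    plugin = TPUBf16PrecisionPlugin()  # emits the deprecation message shown above
    assert plugin.precision == "bf16-true"  # forced by the legacy __init__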

src/lightning/pytorch/plugins/__init__.py

Lines changed: 19 additions & 20 deletions
@@ -3,34 +3,33 @@
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, TorchCheckpointIO, XLACheckpointIO
 from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
 from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
-from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
-from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin
-from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
-from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
-from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, FSDPPrecisionPlugin
-from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
-from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
-from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin
-from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
+from lightning.pytorch.plugins.precision.amp import MixedPrecision
+from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision
+from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
+from lightning.pytorch.plugins.precision.double import DoublePrecision
+from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
+from lightning.pytorch.plugins.precision.half import HalfPrecision
+from lightning.pytorch.plugins.precision.precision import Precision
+from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision
+from lightning.pytorch.plugins.precision.xla import XLAPrecision

-PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
+PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync]
 PLUGIN_INPUT = Union[PLUGIN, str]

 __all__ = [
     "AsyncCheckpointIO",
     "CheckpointIO",
     "TorchCheckpointIO",
     "XLACheckpointIO",
-    "BitsandbytesPrecisionPlugin",
-    "DeepSpeedPrecisionPlugin",
-    "DoublePrecisionPlugin",
-    "HalfPrecisionPlugin",
-    "MixedPrecisionPlugin",
-    "PrecisionPlugin",
-    "TransformerEnginePrecisionPlugin",
-    "FSDPMixedPrecisionPlugin",
-    "FSDPPrecisionPlugin",
-    "XLAPrecisionPlugin",
+    "BitsandbytesPrecision",
+    "DeepSpeedPrecision",
+    "DoublePrecision",
+    "HalfPrecision",
+    "MixedPrecision",
+    "Precision",
+    "TransformerEnginePrecision",
+    "FSDPPrecision",
+    "XLAPrecision",
     "LayerSync",
     "TorchSyncBatchNorm",
 ]
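
With the re-exports updated, the package root stays the simplest import surface for the new names, and the `PLUGIN` union now types against `Precision`. A short usage sketch (the `"bf16-true"` argument is illustrative):

    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins import HalfPrecision, Precision

    plugin = HalfPrecision("bf16-true")
    assert isinstance(plugin, Precision)  # satisfies the PLUGIN union
    trainer = Trainer(plugins=[plugin])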

src/lightning/pytorch/plugins/precision/__init__.py

Lines changed: 18 additions & 19 deletions
@@ -11,25 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
-from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin
-from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
-from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
-from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, FSDPPrecisionPlugin
-from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
-from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
-from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin
-from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
+from lightning.pytorch.plugins.precision.amp import MixedPrecision
+from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision
+from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
+from lightning.pytorch.plugins.precision.double import DoublePrecision
+from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
+from lightning.pytorch.plugins.precision.half import HalfPrecision
+from lightning.pytorch.plugins.precision.precision import Precision
+from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision
+from lightning.pytorch.plugins.precision.xla import XLAPrecision

 __all__ = [
-    "BitsandbytesPrecisionPlugin",
-    "DeepSpeedPrecisionPlugin",
-    "DoublePrecisionPlugin",
-    "FSDPMixedPrecisionPlugin",
-    "FSDPPrecisionPlugin",
-    "HalfPrecisionPlugin",
-    "MixedPrecisionPlugin",
-    "PrecisionPlugin",
-    "TransformerEnginePrecisionPlugin",
-    "XLAPrecisionPlugin",
+    "BitsandbytesPrecision",
+    "DeepSpeedPrecision",
+    "DoublePrecision",
+    "FSDPPrecision",
+    "HalfPrecision",
+    "MixedPrecision",
+    "Precision",
+    "TransformerEnginePrecision",
+    "XLAPrecision",
 ]
