
Commit 5a03612

Fix incorrect str plugin type (#19133)
1 parent 58b4bfa commit 5a03612

4 files changed: +14 additions, -28 deletions


src/lightning/fabric/connector.py

Lines changed: 4 additions & 11 deletions

@@ -65,8 +65,7 @@
 from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
 from lightning.fabric.utilities.imports import _IS_INTERACTIVE
 
-_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
-_PLUGIN_INPUT = Union[_PLUGIN, str]
+_PLUGIN_INPUT = Union[Precision, ClusterEnvironment, CheckpointIO]
 
 
 class _Connector:
@@ -84,15 +83,9 @@ class _Connector:
 backend (registed these too, and _strategy_type could be deprecated)
 
 C. plugins flag could be:
-1. List of str, which could contain:
-i. precision str (Not supported in the old accelerator_connector version)
-ii. checkpoint_io str (Not supported in the old accelerator_connector version)
-iii. cluster_environment str (Not supported in the old accelerator_connector version)
-2. List of class, which could contains:
-i. precision class (should be removed, and precision flag should allow user pass classes)
-ii. checkpoint_io class
-iii. cluster_environment class
-
+1. precision class (should be removed, and precision flag should allow user pass classes)
+2. checkpoint_io class
+3. cluster_environment class
 
 priorities which to take when:
 A. Class > str
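
Not part of the diff above: a minimal sketch of what the narrowed Fabric annotation now describes, namely that the plugins argument takes plugin instances (Precision, ClusterEnvironment, CheckpointIO) rather than strings. TorchCheckpointIO and SLURMEnvironment are existing Fabric plugins, used here purely for illustration.

from lightning.fabric import Fabric
from lightning.fabric.plugins import TorchCheckpointIO
from lightning.fabric.plugins.environments import SLURMEnvironment

# Plugin *instances* only; a bare string is no longer covered by the
# _PLUGIN_INPUT annotation introduced in the hunk above.
fabric = Fabric(
    accelerator="cpu",
    plugins=[TorchCheckpointIO(), SLURMEnvironment()],
)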

src/lightning/pytorch/plugins/__init__.py

Lines changed: 1 addition & 2 deletions

@@ -13,8 +13,7 @@
 from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision
 from lightning.pytorch.plugins.precision.xla import XLAPrecision
 
-PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync]
-PLUGIN_INPUT = Union[PLUGIN, str]
+_PLUGIN_INPUT = Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync]
 
 __all__ = [
 "AsyncCheckpointIO",

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 7 additions & 13 deletions

@@ -36,7 +36,7 @@
 from lightning.pytorch.accelerators.mps import MPSAccelerator
 from lightning.pytorch.accelerators.xla import XLAAccelerator
 from lightning.pytorch.plugins import (
-PLUGIN_INPUT,
+_PLUGIN_INPUT,
 CheckpointIO,
 DeepSpeedPrecision,
 DoublePrecision,
@@ -81,7 +81,7 @@ def __init__(
 num_nodes: int = 1,
 accelerator: Union[str, Accelerator] = "auto",
 strategy: Union[str, Strategy] = "auto",
-plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
+plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
 precision: Optional[_PRECISION_INPUT] = None,
 sync_batchnorm: bool = False,
 benchmark: Optional[bool] = None,
@@ -96,20 +96,14 @@
 2. accelerator str
 3. accelerator auto
 
-B. strategy flag could be :
+B. strategy flag could be:
 1. strategy class
 2. strategy str registered with StrategyRegistry
 
 C. plugins flag could be:
-1. List of str, which could contain:
-i. precision str (Not supported in the old accelerator_connector version)
-ii. checkpoint_io str (Not supported in the old accelerator_connector version)
-iii. cluster_environment str (Not supported in the old accelerator_connector version)
-2. List of class, which could contains:
-i. precision class (should be removed, and precision flag should allow user pass classes)
-ii. checkpoint_io class
-iii. cluster_environment class
-
+1. precision class (should be removed, and precision flag should allow user pass classes)
+2. checkpoint_io class
+3. cluster_environment class
 
 priorities which to take when:
 A. Class > str
@@ -175,7 +169,7 @@ def _check_config_and_set_final_flags(
 strategy: Union[str, Strategy],
 accelerator: Union[str, Accelerator],
 precision: Optional[_PRECISION_INPUT],
-plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]],
+plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]],
 sync_batchnorm: bool,
 ) -> None:
 """This method checks:

src/lightning/pytorch/trainer/trainer.py

Lines changed: 2 additions & 2 deletions

@@ -47,7 +47,7 @@
 from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop
 from lightning.pytorch.loops.fit_loop import _FitLoop
 from lightning.pytorch.loops.utilities import _parse_loop_limits, _reset_progress
-from lightning.pytorch.plugins import PLUGIN_INPUT, Precision
+from lightning.pytorch.plugins import _PLUGIN_INPUT, Precision
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies import ParallelStrategy, Strategy
 from lightning.pytorch.trainer import call, setup
@@ -128,7 +128,7 @@ def __init__(
 profiler: Optional[Union[Profiler, str]] = None,
 detect_anomaly: bool = False,
 barebones: bool = False,
-plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
+plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
 sync_batchnorm: bool = False,
 reload_dataloaders_every_n_epochs: int = 0,
 default_root_dir: Optional[_PATH] = None,
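
For completeness, a usage sketch (assumed, not taken from the commit) of the Trainer signature after the change: plugins are passed as instances, matching the _PLUGIN_INPUT annotation imported above. AsyncCheckpointIO and LightningEnvironment are existing plugins, chosen here only as examples.

from lightning.pytorch import Trainer
from lightning.pytorch.plugins import AsyncCheckpointIO
from lightning.fabric.plugins.environments import LightningEnvironment

trainer = Trainer(
    accelerator="cpu",
    max_epochs=1,
    plugins=[AsyncCheckpointIO(), LightningEnvironment()],
)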
