Skip to content

Commit 3fe0476

Browse files
committed
fix mypy issue
1 parent 2d208dc commit 3fe0476

File tree

5 files changed

+15
-13
lines changed

5 files changed

+15
-13
lines changed

src/lightning/fabric/plugins/precision/deepspeed.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torch.nn import Module
2020
from typing_extensions import override
2121

22-
from lightning.fabric.plugins.precision.precision import Precision
22+
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, Precision
2323
from lightning.fabric.utilities.imports import _raise_enterprise_not_available
2424
from lightning.fabric.utilities.types import Steppable
2525

@@ -84,11 +84,11 @@ def optimizer_step(
8484
return self.deepspeed_impl.optimizer_step(optimizer, **kwargs)
8585

8686
@property
87-
def precision(self) -> _PRECISION_INPUT:
87+
def precision(self) -> _PRECISION_INPUT_STR:
8888
return self.deepspeed_impl.precision
8989

9090
@precision.setter
91-
def precision(self, precision: _PRECISION_INPUT) -> None:
91+
def precision(self, precision: _PRECISION_INPUT_STR) -> None:
9292
self.deepspeed_impl.precision = precision
9393

9494
@property

src/lightning/fabric/plugins/precision/xla.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
from typing_extensions import override
1818

19-
from lightning.fabric.plugins.precision.precision import Precision
19+
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, Precision
2020
from lightning.fabric.utilities.imports import _raise_enterprise_not_available
2121
from lightning.fabric.utilities.types import Optimizable
2222

@@ -63,9 +63,9 @@ def _desired_dtype(self, dtype: torch.dtype) -> None:
6363
self.xla_impl._desired_dtype = dtype
6464

6565
@property
66-
def precision(self) -> _PRECISION_INPUT:
66+
def precision(self) -> _PRECISION_INPUT_STR:
6767
return self.xla_impl.precision
6868

6969
@precision.setter
70-
def precision(self, precision: _PRECISION_INPUT) -> None:
70+
def precision(self, precision: _PRECISION_INPUT_STR) -> None:
7171
self.xla_impl.precision = precision

src/lightning/pytorch/loggers/neptune.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
271271

272272
@override
273273
@rank_zero_only
274-
def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
274+
def log_metrics( # type: ignore[override]
275+
self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None
276+
) -> None:
275277
"""Log metrics (numeric values) in Neptune runs.
276278
277279
Args:

src/lightning/pytorch/plugins/precision/deepspeed.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing_extensions import override
2222

2323
import lightning.pytorch as pl
24-
from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
24+
from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT, _PRECISION_INPUT_STR
2525
from lightning.fabric.utilities.imports import _raise_enterprise_not_available
2626
from lightning.fabric.utilities.types import Steppable
2727
from lightning.pytorch.plugins.precision.precision import Precision
@@ -111,11 +111,11 @@ def clip_gradients(
111111
)
112112

113113
@property
114-
def precision(self) -> str:
114+
def precision(self) -> _PRECISION_INPUT_STR:
115115
return self.deepspeed_precision_impl.precision
116116

117117
@precision.setter
118-
def precision(self, value: str) -> None:
118+
def precision(self, value: _PRECISION_INPUT_STR) -> None:
119119
self.deepspeed_precision_impl.precision = value
120120

121121
@property

src/lightning/pytorch/plugins/precision/xla.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing_extensions import override
1818

1919
import lightning.pytorch as pl
20-
from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT
20+
from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT, _PRECISION_INPUT_STR
2121
from lightning.fabric.utilities.imports import _raise_enterprise_not_available
2222
from lightning.fabric.utilities.types import Optimizable
2323
from lightning.pytorch.plugins.precision.precision import Precision
@@ -54,11 +54,11 @@ def optimizer_step( # type: ignore[override]
5454
return self.xla_impl.optimizer_step(optimizer, model, closure, **kwargs)
5555

5656
@property
57-
def precision(self) -> _PRECISION_INPUT:
57+
def precision(self) -> _PRECISION_INPUT_STR:
5858
return self.xla_impl.precision
5959

6060
@precision.setter
61-
def precision(self, precision: _PRECISION_INPUT) -> None:
61+
def precision(self, precision: _PRECISION_INPUT_STR) -> None:
6262
self.xla_impl.precision = precision
6363

6464
@property

0 commit comments

Comments
 (0)