2 changes: 1 addition & 1 deletion requirements/typing.txt
@@ -1,4 +1,4 @@
-mypy==1.15.0
+mypy==1.16.0
 torch==2.7.1

 types-Markdown
4 changes: 2 additions & 2 deletions src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -226,7 +226,7 @@ class _Linear8bitLt(bnb.nn.Linear8bitLt):
     def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
         super().__init__(*args, device=device, threshold=threshold, **kwargs)
         self.weight = cast(bnb.nn.Int8Params, self.weight)  # type: ignore[has-type]
-        self.bias = cast(Optional[torch.nn.Parameter], self.bias)  # type: ignore[has-type]
+        self.bias: Optional[torch.nn.Parameter] = self.bias
         # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
         # filling the device memory with float32 weights which could lead to OOM
         if torch.tensor(0, device=device).device.type == "cuda":
@@ -310,7 +310,7 @@ class _Linear4bit(bnb.nn.Linear4bit):
     def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
         super().__init__(*args, device=device, **kwargs)
         self.weight = cast(bnb.nn.Params4bit, self.weight)  # type: ignore[has-type]
-        self.bias = cast(Optional[torch.nn.Parameter], self.bias)  # type: ignore[has-type]
+        self.bias: Optional[torch.nn.Parameter] = self.bias
         # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
         # filling the device memory with float32 weights which could lead to OOM
         if torch.tensor(0, device=device).device.type == "cuda":
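The bitsandbytes change above trades a `cast(...)` plus `# type: ignore[has-type]` for a plain attribute annotation: re-declaring the inherited `bias` attribute with an explicit type is enough for the checker. A minimal sketch of that annotation pattern, using made-up `Base`/`Child` modules rather than the actual bitsandbytes classes:

```python
from typing import Optional

import torch


class Base(torch.nn.Module):
    def __init__(self, use_bias: bool = True) -> None:
        super().__init__()
        # The parent may or may not create a bias parameter.
        self.bias = torch.nn.Parameter(torch.zeros(1)) if use_bias else None


class Child(Base):
    def __init__(self, use_bias: bool = True) -> None:
        super().__init__(use_bias)
        # Re-declaring the inherited attribute with an explicit annotation tells
        # the type checker what `self.bias` is, with no cast() and no ignore.
        self.bias: Optional[torch.nn.Parameter] = self.bias
```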
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/fsdp.py
@@ -237,7 +237,7 @@ def precision(self) -> FSDPPrecision:

     @precision.setter
     @override
-    def precision(self, precision: Optional[FSDPPrecision]) -> None:
+    def precision(self, precision: Optional[Precision]) -> None:
         if precision is not None and not isinstance(precision, FSDPPrecision):
             raise TypeError(f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision}")
         self._precision = precision
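The setter changes in this and the following strategy files all follow one pattern: the getter keeps returning the narrow plugin type, while the setter is widened to accept the base `Precision`/`CheckpointIO` type and narrows it with an `isinstance` check at runtime. Newer mypy releases (1.16.0 is pinned above) understand property setters whose parameter type differs from the getter's return type, which is presumably what lets the `# type: ignore[assignment]` comments at the `self.precision = precision` call sites in `strategy.py` below be dropped. A minimal, self-contained sketch of the idea with stand-in classes, not Lightning's actual implementation:

```python
from typing import Optional


class Precision:
    """Stand-in for the base precision plugin."""


class FSDPStrategy:
    """Illustrative strategy; only the property pattern matters here."""

    def __init__(self, precision: Optional[Precision] = None) -> None:
        self._precision: Optional[Precision] = None
        # Route construction through the setter so the validation below runs.
        self.precision = precision

    @property
    def precision(self) -> "FSDPPrecision":
        # The getter advertises the narrow type this strategy guarantees.
        plugin = self._precision
        if plugin is None:
            plugin = FSDPPrecision()
            self._precision = plugin
        assert isinstance(plugin, FSDPPrecision)  # guaranteed by the setter check
        return plugin

    @precision.setter
    def precision(self, precision: Optional[Precision]) -> None:
        # The setter accepts the broad base type and validates at runtime, so
        # `self.precision = precision` type-checks without an ignore comment.
        if precision is not None and not isinstance(precision, FSDPPrecision):
            raise TypeError(f"Expected an `FSDPPrecision` plugin, found {precision}")
        self._precision = precision


class FSDPPrecision(Precision):
    """Stand-in for the FSDP-specific precision plugin."""
```

With this shape, `FSDPStrategy(precision=FSDPPrecision())` works, passing a bare `Precision()` fails fast with a `TypeError`, and a caller assigning a value typed as `Optional[Precision]` to `strategy.precision` no longer needs a type-ignore.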
6 changes: 3 additions & 3 deletions src/lightning/fabric/strategies/single_xla.py
@@ -18,7 +18,7 @@

 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
-from lightning.fabric.plugins import XLAPrecision
+from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.single_device import SingleDeviceStrategy
@@ -61,7 +61,7 @@ def checkpoint_io(self) -> XLACheckpointIO:

     @checkpoint_io.setter
     @override
-    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
         if io is not None and not isinstance(io, XLACheckpointIO):
             raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
@@ -77,7 +77,7 @@ def precision(self) -> XLAPrecision:

     @precision.setter
     @override
-    def precision(self, precision: Optional[XLAPrecision]) -> None:
+    def precision(self, precision: Optional[Precision]) -> None:
         if precision is not None and not isinstance(precision, XLAPrecision):
             raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}")
         self._precision = precision
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/strategy.py
@@ -52,7 +52,7 @@ def __init__(
         self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io
         self._precision: Optional[Precision] = None
         # Call the precision setter for input validation
-        self.precision = precision  # type: ignore[assignment]
+        self.precision = precision
         self._launcher: Optional[_Launcher] = None
         self._backward_sync_control: Optional[_BackwardSyncControl] = None

6 changes: 3 additions & 3 deletions src/lightning/fabric/strategies/xla.py
@@ -23,7 +23,7 @@

 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
-from lightning.fabric.plugins import XLAPrecision
+from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
@@ -83,7 +83,7 @@ def checkpoint_io(self) -> XLACheckpointIO:

     @checkpoint_io.setter
     @override
-    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
         if io is not None and not isinstance(io, XLACheckpointIO):
             raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
@@ -99,7 +99,7 @@ def precision(self) -> XLAPrecision:

     @precision.setter
     @override
-    def precision(self, precision: Optional[XLAPrecision]) -> None:
+    def precision(self, precision: Optional[Precision]) -> None:
         if precision is not None and not isinstance(precision, XLAPrecision):
             raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}")
         self._precision = precision
6 changes: 3 additions & 3 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -26,7 +26,7 @@

 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
-from lightning.fabric.plugins import XLAPrecision
+from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
@@ -134,7 +134,7 @@ def checkpoint_io(self) -> XLACheckpointIO:

     @checkpoint_io.setter
     @override
-    def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
+    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
         if io is not None and not isinstance(io, XLACheckpointIO):
             raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
@@ -150,7 +150,7 @@ def precision(self) -> XLAPrecision:

     @precision.setter
     @override
-    def precision(self, precision: Optional[XLAPrecision]) -> None:
+    def precision(self, precision: Optional[Precision]) -> None:
         if precision is not None and not isinstance(precision, XLAPrecision):
             raise TypeError(f"The XLA FSDP strategy can only work with the `XLAPrecision` plugin, found {precision}")
         self._precision = precision
2 changes: 1 addition & 1 deletion src/lightning/pytorch/core/module.py
@@ -218,7 +218,7 @@ def trainer(self) -> "pl.Trainer":
     def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
         for v in self.children():
             if isinstance(v, LightningModule):
-                v.trainer = trainer  # type: ignore[assignment]
+                v.trainer = trainer
         self._trainer = trainer

     @property
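Here (and in the `servable_module_validator.py` and `trainer.py` hunks below) the fix is on the caller side: the `trainer` property's getter returns a non-optional `"pl.Trainer"` while its setter accepts `Optional["pl.Trainer"]`, so once the type checker understands asymmetric getter/setter types, assigning `None` or another possibly-`None` value no longer needs `# type: ignore[assignment]`. A rough, illustrative sketch of such a property, not Lightning's actual implementation:

```python
from typing import Optional


class Trainer:
    """Stand-in for pl.Trainer."""


class Module:
    def __init__(self) -> None:
        self._trainer: Optional[Trainer] = None

    @property
    def trainer(self) -> Trainer:
        # The getter promises a Trainer and fails loudly when detached.
        if self._trainer is None:
            raise RuntimeError("The module is not attached to a trainer.")
        return self._trainer

    @trainer.setter
    def trainer(self, trainer: Optional[Trainer]) -> None:
        # The setter also accepts None so the module can be detached, e.g.
        # before deepcopying it into a separate process.
        self._trainer = trainer


module = Module()
module.trainer = Trainer()  # attach
module.trainer = None       # detach: no `# type: ignore[assignment]` needed
```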
2 changes: 1 addition & 1 deletion src/lightning/pytorch/serve/servable_module_validator.py
@@ -93,7 +93,7 @@ def on_train_start(self, trainer: "pl.Trainer", servable_module: "pl.LightningMo

         # Note: The Trainer needs to be detached from the pl_module before starting the process.
         # This would fail during the deepcopy with DDP.
-        servable_module.trainer = None  # type: ignore[assignment]
+        servable_module.trainer = None

         process = Process(target=self._start_server, args=(servable_module, self.host, self.port, self.optimization))
         process.start()
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/fsdp.py
@@ -227,7 +227,7 @@ def precision_plugin(self) -> FSDPPrecision:

     @precision_plugin.setter
     @override
-    def precision_plugin(self, precision_plugin: Optional[FSDPPrecision]) -> None:
+    def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
         if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecision):
             raise TypeError(
                 f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision_plugin}"
6 changes: 3 additions & 3 deletions src/lightning/pytorch/strategies/single_xla.py
@@ -19,7 +19,7 @@

 import lightning.pytorch as pl
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
-from lightning.fabric.plugins import XLACheckpointIO
+from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.types import _DEVICE
@@ -67,7 +67,7 @@ def checkpoint_io(self) -> Union[XLACheckpointIO, _WrappingCheckpointIO]:

     @checkpoint_io.setter
     @override
-    def checkpoint_io(self, io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]]) -> None:
+    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
         if io is not None and not isinstance(io, (XLACheckpointIO, _WrappingCheckpointIO)):
             raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
@@ -83,7 +83,7 @@ def precision_plugin(self) -> XLAPrecision:

     @precision_plugin.setter
     @override
-    def precision_plugin(self, precision_plugin: Optional[XLAPrecision]) -> None:
+    def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
         if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecision):
             raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision_plugin}")
         self._precision_plugin = precision_plugin
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/strategy.py
@@ -57,7 +57,7 @@ def __init__(
         self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io
         self._precision_plugin: Optional[Precision] = None
         # Call the precision setter for input validation
-        self.precision_plugin = precision_plugin  # type: ignore[assignment]
+        self.precision_plugin = precision_plugin
         self._lightning_module: Optional[pl.LightningModule] = None
         self._model: Optional[Module] = None
         self._launcher: Optional[_Launcher] = None
6 changes: 3 additions & 3 deletions src/lightning/pytorch/strategies/xla.py
@@ -22,7 +22,7 @@

 import lightning.pytorch as pl
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1
-from lightning.fabric.plugins import XLACheckpointIO
+from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -81,7 +81,7 @@ def checkpoint_io(self) -> Union[XLACheckpointIO, _WrappingCheckpointIO]:

     @checkpoint_io.setter
     @override
-    def checkpoint_io(self, io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]]) -> None:
+    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
         if io is not None and not isinstance(io, (XLACheckpointIO, _WrappingCheckpointIO)):
             raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
         self._checkpoint_io = io
@@ -97,7 +97,7 @@ def precision_plugin(self) -> XLAPrecision:

     @precision_plugin.setter
     @override
-    def precision_plugin(self, precision_plugin: Optional[XLAPrecision]) -> None:
+    def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
         if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecision):
             raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision_plugin}")
         self._precision_plugin = precision_plugin
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/trainer.py
@@ -907,7 +907,7 @@ def _predict_impl(
         # --------------------
         log.debug(f"{self.__class__.__name__}: trainer predict stage")

-        self.predict_loop.return_predictions = return_predictions  # type: ignore[assignment]
+        self.predict_loop.return_predictions = return_predictions

         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(dataloaders, LightningDataModule):