diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index 70cd75c1c0d37..4c31dac1c922a 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -1,7 +1,7 @@
 # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 
-torch >=2.1.0, <2.6.0
+torch >=2.1.0, <2.7.0
 fsspec[http] >=2022.5.0, <2024.4.0
 packaging >=20.0, <=23.1
 typing-extensions >=4.4.0, <4.11.0
diff --git a/requirements/fabric/examples.txt b/requirements/fabric/examples.txt
index 6be089ebb9767..ba68ae5fec613 100644
--- a/requirements/fabric/examples.txt
+++ b/requirements/fabric/examples.txt
@@ -1,6 +1,6 @@
 # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 
-torchvision >=0.16.0, <0.21.0
+torchvision >=0.16.0, <0.22.0
 torchmetrics >=0.10.0, <1.8.0
 lightning-utilities >=0.8.0, <0.12.0
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index 4b32e44d2cacc..02beb77bb7b87 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -1,7 +1,7 @@
 # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 
-torch >=2.1.0, <2.6.0
+torch >=2.1.0, <2.7.0
 tqdm >=4.57.0, <4.67.0
 PyYAML >=5.4, <6.1.0
 fsspec[http] >=2022.5.0, <2024.4.0
diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt
index 8a19179b813e0..f64eca95c0243 100644
--- a/requirements/pytorch/examples.txt
+++ b/requirements/pytorch/examples.txt
@@ -2,7 +2,7 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
 
 requests <2.32.0
-torchvision >=0.16.0, <0.21.0
+torchvision >=0.16.0, <0.22.0
 ipython[all] <8.19.0
 torchmetrics >=0.10.0, <1.8.0
 lightning-utilities >=0.8.0, <0.12.0
diff --git a/requirements/typing.txt b/requirements/typing.txt
index 71414998dd7f3..52f4df899e9d8 100644
--- a/requirements/typing.txt
+++ b/requirements/typing.txt
@@ -1,5 +1,5 @@
 mypy==1.11.0
-torch==2.5.1
+torch==2.6.0
 
 types-Markdown
 types-PyYAML
diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py
index 81e15a33cb983..883380bb881aa 100644
--- a/src/lightning/fabric/plugins/collectives/torch_collective.py
+++ b/src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -50,7 +50,7 @@ def world_size(self) -> int:
 
     @override
     def broadcast(self, tensor: Tensor, src: int) -> Tensor:
-        dist.broadcast(tensor, src, group=self.group)
+        dist.broadcast(tensor, src, group=self.group)  # type: ignore[arg-type]
         return tensor
 
     @override
@@ -62,7 +62,7 @@ def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum"
     @override
     def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
         op = self._convert_to_native_op(op)
-        dist.reduce(tensor, dst, op=op, group=self.group)
+        dist.reduce(tensor, dst, op=op, group=self.group)  # type: ignore[arg-type]
         return tensor
 
     @override
@@ -72,12 +72,12 @@ def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]:
 
     @override
     def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]:
-        dist.gather(tensor, gather_list, dst, group=self.group)
+        dist.gather(tensor, gather_list, dst, group=self.group)  # type: ignore[arg-type]
         return gather_list
 
     @override
     def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor:
-        dist.scatter(tensor, scatter_list, src, group=self.group)
+        dist.scatter(tensor, scatter_list, src, group=self.group)  # type: ignore[arg-type]
         return tensor
 
     @override
@@ -109,27 +109,27 @@ def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]:
     def broadcast_object_list(
         self, object_list: list[Any], src: int, device: Optional[torch.device] = None
     ) -> list[Any]:
-        dist.broadcast_object_list(object_list, src, group=self.group, device=device)
+        dist.broadcast_object_list(object_list, src, group=self.group, device=device)  # type: ignore[arg-type]
         return object_list
 
     def gather_object(self, obj: Any, object_gather_list: list[Any], dst: int = 0) -> list[Any]:
-        dist.gather_object(obj, object_gather_list, dst, group=self.group)
+        dist.gather_object(obj, object_gather_list, dst, group=self.group)  # type: ignore[arg-type]
         return object_gather_list
 
     def scatter_object_list(
         self, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any], src: int = 0
     ) -> list[Any]:
-        dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
+        dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)  # type: ignore[arg-type]
         return scatter_object_output_list
 
     @override
     def barrier(self, device_ids: Optional[list[int]] = None) -> None:
         if self.group == dist.GroupMember.NON_GROUP_MEMBER:
             return
-        dist.barrier(group=self.group, device_ids=device_ids)
+        dist.barrier(group=self.group, device_ids=device_ids)  # type: ignore[arg-type]
 
     def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
-        dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
+        dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)  # type: ignore[arg-type]
 
     @override
     def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self:
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index 935ef72713bcc..87e45293e5e47 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -295,6 +295,7 @@ def clip_gradients_norm(
     ) -> Tensor:
         """Clip gradients by norm."""
         self.precision.unscale_gradients(optimizer)
+        assert callable(module.clip_grad_norm_)
         return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type)
 
     @override
diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py
index 2760c6bd227c1..4f8519eec9610 100644
--- a/src/lightning/fabric/utilities/init.py
+++ b/src/lightning/fabric/utilities/init.py
@@ -67,7 +67,8 @@ def _materialize(module: Module, device: _DEVICE) -> None:
             f"Materialization requires that the `{type(module).__name__}.reset_parameters` method is implemented."
             " This method is used to initialize any children parameters or buffers in this module."
         )
-    module.reset_parameters()
+    if callable(module.reset_parameters):
+        module.reset_parameters()
 
 
 def _materialize_meta_tensors(module: Module, device: _DEVICE) -> None:
diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py
index 356ab221777ae..cec83fee0f4d7 100644
--- a/src/lightning/pytorch/callbacks/finetuning.py
+++ b/src/lightning/pytorch/callbacks/finetuning.py
@@ -133,7 +133,7 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
 
         if isinstance(modules, Iterable):
             _flatten_modules = []
-            for m in modules:  # type: ignore[union-attr]
+            for m in modules:
                 _flatten_modules.extend(BaseFinetuning.flatten_modules(m))
 
             _modules = iter(_flatten_modules)
diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py
index a49610a912e57..8b618ae2be912 100644
--- a/src/lightning/pytorch/callbacks/throughput_monitor.py
+++ b/src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -140,7 +140,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             # this assumes that all iterations used the same batch size
             samples=iter_num * batch_size,
             lengths=None if self.length_fn is None else self._lengths[stage],
-            flops=flops_per_batch,
+            flops=flops_per_batch,  # type: ignore[arg-type]
         )
 
     def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:
diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py
index 196008b7ed29f..92d444338ff0f 100644
--- a/src/lightning/pytorch/overrides/distributed.py
+++ b/src/lightning/pytorch/overrides/distributed.py
@@ -163,9 +163,7 @@ def _register_ddp_comm_hook(
 
 def _sync_module_states(module: torch.nn.Module) -> None:
     """Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682."""
-    parameters_to_ignore = (
-        set(module._ddp_params_and_buffers_to_ignore) if hasattr(module, "_ddp_params_and_buffers_to_ignore") else set()
-    )
+    parameters_to_ignore = set(getattr(module, "_ddp_params_and_buffers_to_ignore", []))
     from torch.distributed.distributed_c10d import _get_default_group
     from torch.distributed.utils import _sync_module_states as torch_sync_module_states
 
diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py
index efa1aa008a35e..739d8f1d06526 100644
--- a/src/lightning/pytorch/plugins/precision/double.py
+++ b/src/lightning/pytorch/plugins/precision/double.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from collections.abc import Generator
+from collections.abc import Generator, Iterable
 from contextlib import AbstractContextManager, contextmanager
 from typing import Any, Literal
 
@@ -72,6 +72,8 @@ class LightningDoublePrecisionModule(_DeviceDtypeModuleMixin, nn.Module):
 
     """
 
+    _ddp_params_and_buffers_to_ignore: Iterable[str]
+
    def __init__(self, pl_module: "pl.LightningModule") -> None:
        super().__init__()
        rank_zero_deprecation(
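
Note on the `_sync_module_states` change: `set(getattr(module, "_ddp_params_and_buffers_to_ignore", []))` keeps the old behaviour (missing attribute still yields an empty set) while being a single expression that mypy can check against the `Iterable[str]` annotation added in `double.py`. A minimal sketch of that equivalence, using a hypothetical `Toy` module that is not part of the patch:

    # Illustration only: `Toy` is a made-up module used to show the getattr() pattern.
    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.zeros(1))

    plain = Toy()
    tagged = Toy()
    # DDP consults this optional attribute to skip syncing the listed parameters/buffers.
    tagged._ddp_params_and_buffers_to_ignore = ["weight"]

    # Missing attribute -> empty set; present -> set of names, same as the removed hasattr() ternary.
    print(set(getattr(plain, "_ddp_params_and_buffers_to_ignore", [])))   # set()
    print(set(getattr(tagged, "_ddp_params_and_buffers_to_ignore", [])))  # {'weight'}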