2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

-torch >=2.1.0, <2.6.0
+torch >=2.1.0, <2.7.0
fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.11.0
2 changes: 1 addition & 1 deletion requirements/fabric/examples.txt
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

-torchvision >=0.16.0, <0.21.0
+torchvision >=0.16.0, <0.22.0
torchmetrics >=0.10.0, <1.8.0
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

-torch >=2.1.0, <2.6.0
+torch >=2.1.0, <2.7.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2024.4.0
2 changes: 1 addition & 1 deletion requirements/pytorch/examples.txt
@@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

requests <2.32.0
-torchvision >=0.16.0, <0.21.0
+torchvision >=0.16.0, <0.22.0
ipython[all] <8.19.0
torchmetrics >=0.10.0, <1.8.0
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/typing.txt
@@ -1,5 +1,5 @@
mypy==1.11.0
-torch==2.5.1
+torch==2.6.0

types-Markdown
types-PyYAML
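Taken together, the requirement bumps above move the supported range to `torch <2.7.0` and `torchvision <0.22.0`, with the typing environment pinned to `torch==2.6.0`. A small, illustrative sanity check for a local environment (not part of the diff; `packaging` is assumed to be installed, as the base requirements already list it):

```python
# Illustrative only: verify an installed environment sits inside the bumped bounds.
from packaging.version import Version

import torch
import torchvision

torch_version = Version(torch.__version__)
vision_version = Version(torchvision.__version__)

assert Version("2.1.0") <= torch_version < Version("2.7.0"), f"torch {torch_version} is outside the pinned range"
assert Version("0.16.0") <= vision_version < Version("0.22.0"), f"torchvision {vision_version} is outside the pinned range"
print(f"torch {torch_version} / torchvision {vision_version} satisfy the pins")
```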
18 changes: 9 additions & 9 deletions src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -50,7 +50,7 @@ def world_size(self) -> int:

@override
def broadcast(self, tensor: Tensor, src: int) -> Tensor:
-dist.broadcast(tensor, src, group=self.group)
+dist.broadcast(tensor, src, group=self.group) # type: ignore[arg-type]
return tensor

@override
@@ -62,7 +62,7 @@ def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum"
@override
def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
op = self._convert_to_native_op(op)
-dist.reduce(tensor, dst, op=op, group=self.group)
+dist.reduce(tensor, dst, op=op, group=self.group) # type: ignore[arg-type]
return tensor

@override
@@ -72,12 +72,12 @@ def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]:

@override
def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]:
-dist.gather(tensor, gather_list, dst, group=self.group)
+dist.gather(tensor, gather_list, dst, group=self.group) # type: ignore[arg-type]
return gather_list

@override
def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor:
-dist.scatter(tensor, scatter_list, src, group=self.group)
+dist.scatter(tensor, scatter_list, src, group=self.group) # type: ignore[arg-type]
return tensor

@override
@@ -109,27 +109,27 @@ def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]:
def broadcast_object_list(
self, object_list: list[Any], src: int, device: Optional[torch.device] = None
) -> list[Any]:
-dist.broadcast_object_list(object_list, src, group=self.group, device=device)
+dist.broadcast_object_list(object_list, src, group=self.group, device=device) # type: ignore[arg-type]
return object_list

def gather_object(self, obj: Any, object_gather_list: list[Any], dst: int = 0) -> list[Any]:
-dist.gather_object(obj, object_gather_list, dst, group=self.group)
+dist.gather_object(obj, object_gather_list, dst, group=self.group) # type: ignore[arg-type]
return object_gather_list

def scatter_object_list(
self, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any], src: int = 0
) -> list[Any]:
-dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
+dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group) # type: ignore[arg-type]
return scatter_object_output_list

@override
def barrier(self, device_ids: Optional[list[int]] = None) -> None:
if self.group == dist.GroupMember.NON_GROUP_MEMBER:
return
-dist.barrier(group=self.group, device_ids=device_ids)
+dist.barrier(group=self.group, device_ids=device_ids) # type: ignore[arg-type]

def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
-dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
+dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks) # type: ignore[arg-type]

@override
def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self:
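The `# type: ignore[arg-type]` comments added throughout this file suppress a single mypy error code on each call rather than every check on the line. A minimal, self-contained illustration of how the scoped ignore behaves (the names below are invented for the example and are not Lightning or torch APIs):

```python
from typing import Optional

def send_to(dst: int) -> None:
    """Toy stand-in for a call whose signature requires a plain int."""

maybe_dst: Optional[int] = None

# Without the comment, mypy reports:
#   Argument 1 to "send_to" has incompatible type "Optional[int]"; expected "int"  [arg-type]
send_to(maybe_dst)  # type: ignore[arg-type]

# Only the [arg-type] code is silenced; any other error code raised on the same
# line (or elsewhere) is still reported, unlike a bare "# type: ignore".
```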
1 change: 1 addition & 0 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -295,6 +295,7 @@ def clip_gradients_norm(
) -> Tensor:
"""Clip gradients by norm."""
self.precision.unscale_gradients(optimizer)
+assert callable(module.clip_grad_norm_)
return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type)

@override
3 changes: 2 additions & 1 deletion src/lightning/fabric/utilities/init.py
@@ -67,7 +67,8 @@ def _materialize(module: Module, device: _DEVICE) -> None:
f"Materialization requires that the `{type(module).__name__}.reset_parameters` method is implemented."
" This method is used to initialize any children parameters or buffers in this module."
)
-module.reset_parameters()
+if callable(module.reset_parameters):
+    module.reset_parameters()


def _materialize_meta_tensors(module: Module, device: _DEVICE) -> None:
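For context on the new guard in `_materialize`: `reset_parameters` is a convention followed by built-in modules such as `nn.Linear`, but it is not guaranteed to be a method on arbitrary modules, hence the `callable` check before the call. A rough sketch of the materialization flow this helper supports, assuming the module was built on the meta device:

```python
import torch
from torch import nn

# Build a module on the meta device: shapes and dtypes exist, but no real storage is allocated.
with torch.device("meta"):
    layer = nn.Linear(8, 8)

# Allocate uninitialized storage on the target device...
layer = layer.to_empty(device="cpu")

# ...then re-run the module's own initialization, guarded the same way as above.
reset = getattr(layer, "reset_parameters", None)
if callable(reset):
    reset()
```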
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/finetuning.py
@@ -133,7 +133,7 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -

if isinstance(modules, Iterable):
_flatten_modules = []
-for m in modules: # type: ignore[union-attr]
+for m in modules:
_flatten_modules.extend(BaseFinetuning.flatten_modules(m))

_modules = iter(_flatten_modules)
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -140,7 +140,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
# this assumes that all iterations used the same batch size
samples=iter_num * batch_size,
lengths=None if self.length_fn is None else self._lengths[stage],
-flops=flops_per_batch,
+flops=flops_per_batch, # type: ignore[arg-type]
)

def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:
4 changes: 1 addition & 3 deletions src/lightning/pytorch/overrides/distributed.py
@@ -163,9 +163,7 @@ def _register_ddp_comm_hook(

def _sync_module_states(module: torch.nn.Module) -> None:
"""Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682."""
-parameters_to_ignore = (
-    set(module._ddp_params_and_buffers_to_ignore) if hasattr(module, "_ddp_params_and_buffers_to_ignore") else set()
-)
+parameters_to_ignore = set(getattr(module, "_ddp_params_and_buffers_to_ignore", []))
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.utils import _sync_module_states as torch_sync_module_states

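The `_sync_module_states` change replaces the `hasattr` ternary with a single `getattr` carrying a default, which behaves identically when the attribute is missing. A quick illustration with a throwaway class (the class and values are made up for the example):

```python
class _Module:
    """Stand-in object that may or may not carry the ignore list."""

m = _Module()

# Before: explicit hasattr branch.
before = (
    set(m._ddp_params_and_buffers_to_ignore) if hasattr(m, "_ddp_params_and_buffers_to_ignore") else set()
)

# After: getattr with a default collapses the branch into one expression.
after = set(getattr(m, "_ddp_params_and_buffers_to_ignore", []))

assert before == after == set()

# Once the attribute is set, both spellings see the same values.
m._ddp_params_and_buffers_to_ignore = ["bn.running_mean"]
assert set(getattr(m, "_ddp_params_and_buffers_to_ignore", [])) == {"bn.running_mean"}
```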
4 changes: 3 additions & 1 deletion src/lightning/pytorch/plugins/precision/double.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections.abc import Generator
+from collections.abc import Generator, Iterable
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Literal

@@ -72,6 +72,8 @@ class LightningDoublePrecisionModule(_DeviceDtypeModuleMixin, nn.Module):

"""

+_ddp_params_and_buffers_to_ignore: Iterable[str]
+
def __init__(self, pl_module: "pl.LightningModule") -> None:
super().__init__()
rank_zero_deprecation(
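The added `_ddp_params_and_buffers_to_ignore: Iterable[str]` line is a bare class-level annotation: it tells static checkers that the attribute may exist on instances without creating a class-level default at runtime. A minimal example of the pattern (the class and attribute name here are illustrative):

```python
from collections.abc import Iterable

class Wrapper:
    # Declaration only: recorded for type checkers, not created as a class attribute.
    params_to_ignore: Iterable[str]

assert "params_to_ignore" not in Wrapper.__dict__     # no runtime default exists
assert "params_to_ignore" in Wrapper.__annotations__  # but the annotation is recorded

w = Wrapper()
w.params_to_ignore = ["layer.bias"]                   # assigned per instance later
```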