
Commit b0fc3fe
Commit message: typing
1 parent 30da1d2

File tree

5 files changed (+7, -4 lines changed)

src/lightning/fabric/strategies/xla_fsdp.py

Lines changed: 1 addition & 0 deletions

@@ -295,6 +295,7 @@ def clip_gradients_norm(
     ) -> Tensor:
         """Clip gradients by norm."""
         self.precision.unscale_gradients(optimizer)
+        assert callable(module.clip_grad_norm_)
         return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type)

     @override
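
Note on the pattern: `assert callable(...)` narrows an attribute that the type checker only knows as possibly-missing or possibly-None before it is called, and it documents the runtime expectation without changing behaviour when the assertion holds. A minimal sketch of the idea with a hypothetical `Wrapped` class (not the actual XLA FSDP wrapper):

    from typing import Callable, Optional

    class Wrapped:
        # hypothetical attribute that the checker sees as possibly None
        clip_grad_norm_: Optional[Callable[[float], float]] = None

    def clip(module: Wrapped, max_norm: float) -> float:
        # callable() narrows Optional[Callable] to Callable for mypy
        assert callable(module.clip_grad_norm_)
        return module.clip_grad_norm_(max_norm)

    w = Wrapped()
    w.clip_grad_norm_ = lambda n: min(n, 1.0)
    print(clip(w, 5.0))  # 1.0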

src/lightning/fabric/utilities/init.py

Lines changed: 2 additions & 1 deletion

@@ -67,7 +67,8 @@ def _materialize(module: Module, device: _DEVICE) -> None:
             f"Materialization requires that the `{type(module).__name__}.reset_parameters` method is implemented."
             " This method is used to initialize any children parameters or buffers in this module."
         )
-    module.reset_parameters()
+    if callable(module.reset_parameters):
+        module.reset_parameters()


 def _materialize_meta_tensors(module: Module, device: _DEVICE) -> None:
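
For context, guarding with `callable(module.reset_parameters)` avoids invoking a `reset_parameters` attribute that exists but is not a method, and it gives mypy a concrete check to narrow on. A rough sketch of the guarded call (the helper name `materialize_sketch` is made up for illustration):

    import torch.nn as nn

    def materialize_sketch(module: nn.Module) -> None:
        # only call reset_parameters when the attribute exists and is callable
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()

    materialize_sketch(nn.Linear(4, 2))  # Linear defines reset_parameters, so it runs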

src/lightning/pytorch/callbacks/finetuning.py

Lines changed: 1 addition & 1 deletion

@@ -133,7 +133,7 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -

         if isinstance(modules, Iterable):
             _flatten_modules = []
-            for m in modules:  # type: ignore[union-attr]
+            for m in modules:
                 _flatten_modules.extend(BaseFinetuning.flatten_modules(m))

         _modules = iter(_flatten_modules)
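
The removed `# type: ignore[union-attr]` is unnecessary because the enclosing `isinstance(modules, Iterable)` check already narrows the `Union` for the type checker. A small, generic illustration of that narrowing (not the actual `flatten_modules` implementation):

    from collections.abc import Iterable
    from typing import Union
    import torch.nn as nn

    def flatten_sketch(modules: Union[nn.Module, Iterable]) -> list:
        flat: list = []
        if isinstance(modules, Iterable):
            # in this branch mypy knows `modules` is Iterable, so no ignore is needed
            for m in modules:
                flat.extend(flatten_sketch(m))
        else:
            flat.append(modules)
        return flat

    print(flatten_sketch(nn.Sequential(nn.Linear(2, 2), nn.ReLU())))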

src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 1 addition & 1 deletion

@@ -125,7 +125,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             self._lengths[stage] += self.length_fn(batch)

         if hasattr(pl_module, "flops_per_batch"):
-            flops_per_batch = pl_module.flops_per_batch
+            flops_per_batch = int(pl_module.flops_per_batch)
         else:
             rank_zero_warn(
                 "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property"

src/lightning/pytorch/plugins/precision/double.py

Lines changed: 2 additions & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.
 from collections.abc import Generator
 from contextlib import AbstractContextManager, contextmanager
-from typing import Any, Literal
+from typing import Any, Literal, Iterable

 import torch
 import torch.nn as nn

@@ -71,6 +71,7 @@ class LightningDoublePrecisionModule(_DeviceDtypeModuleMixin, nn.Module):
         pl_module: the model to wrap

     """
+    _ddp_params_and_buffers_to_ignore: Iterable[str]

     def __init__(self, pl_module: "pl.LightningModule") -> None:
         super().__init__()
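
The class-level line `_ddp_params_and_buffers_to_ignore: Iterable[str]` is an annotation without an assignment: it declares the attribute's type for the checker but creates nothing at runtime, which suits attributes that other code (here, the DDP ignore-list mechanism) may set later. A minimal sketch of the pattern with a hypothetical wrapper:

    from typing import Iterable
    import torch.nn as nn

    class WrapperSketch(nn.Module):
        # annotation only: known to the type checker, created only if set later
        _ddp_params_and_buffers_to_ignore: Iterable[str]

        def __init__(self, inner: nn.Module) -> None:
            super().__init__()
            self.inner = inner

    w = WrapperSketch(nn.Linear(2, 2))
    w._ddp_params_and_buffers_to_ignore = ["inner.bias"]  # e.g. mirrored from the wrapped module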
