
Commit 64132fb

typing

1 parent b0fc3fe

2 files changed, +2 −4 lines changed

src/lightning/pytorch/callbacks/throughput_monitor.py
Lines changed: 1 addition & 1 deletion

@@ -125,7 +125,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             self._lengths[stage] += self.length_fn(batch)
 
         if hasattr(pl_module, "flops_per_batch"):
-            flops_per_batch = int(pl_module.flops_per_batch)
+            flops_per_batch = pl_module.flops_per_batch
         else:
             rank_zero_warn(
                 "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property"

src/lightning/pytorch/overrides/distributed.py
Lines changed: 1 addition & 3 deletions

@@ -163,9 +163,7 @@ def _register_ddp_comm_hook(
 
 def _sync_module_states(module: torch.nn.Module) -> None:
     """Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682."""
-    parameters_to_ignore = (
-        set(module._ddp_params_and_buffers_to_ignore) if hasattr(module, "_ddp_params_and_buffers_to_ignore") else set()
-    )
+    parameters_to_ignore = set(getattr(module, "_ddp_params_and_buffers_to_ignore", []))
     from torch.distributed.distributed_c10d import _get_default_group
     from torch.distributed.utils import _sync_module_states as torch_sync_module_states
 
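The rewrite collapses the hasattr check and the conditional expression into a single getattr with a default, which behaves identically when the attribute is absent. A standalone sketch of the pattern follows; the Linear module and the ignored-parameter name are hypothetical.

# Standalone illustration of the getattr-with-default pattern; the
# module instance and the ignored-parameter name are hypothetical.
import torch

module = torch.nn.Linear(4, 4)

# Before: explicit hasattr check plus a conditional expression.
before = (
    set(module._ddp_params_and_buffers_to_ignore)
    if hasattr(module, "_ddp_params_and_buffers_to_ignore")
    else set()
)

# After: getattr with a [] default covers the missing-attribute case.
after = set(getattr(module, "_ddp_params_and_buffers_to_ignore", []))
assert before == after == set()

# DDP stores this attribute on the wrapped module when parameters are
# marked to be ignored; simulate that here.
module._ddp_params_and_buffers_to_ignore = ["weight"]
assert set(getattr(module, "_ddp_params_and_buffers_to_ignore", [])) == {"weight"}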
0 commit comments
