From f2a61dc3e3a36966ac889450830ce09b5af74176 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 22 Apr 2025 13:13:15 +0000 Subject: [PATCH 01/12] build(deps): bump torch from 2.5.1 to 2.6.0 in /requirements Bumps [torch](https://github.com/pytorch/pytorch) from 2.5.1 to 2.6.0. - [Release notes](https://github.com/pytorch/pytorch/releases) - [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md) - [Commits](https://github.com/pytorch/pytorch/compare/v2.5.1...v2.6.0) --- updated-dependencies: - dependency-name: torch dependency-version: 2.6.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- requirements/fabric/base.txt | 2 +- requirements/pytorch/base.txt | 2 +- requirements/typing.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 70cd75c1c0d37..4c31dac1c922a 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -1,7 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torch >=2.1.0, <2.6.0 +torch >=2.1.0, <2.7.0 fsspec[http] >=2022.5.0, <2024.4.0 packaging >=20.0, <=23.1 typing-extensions >=4.4.0, <4.11.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 4b32e44d2cacc..02beb77bb7b87 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -1,7 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torch >=2.1.0, <2.6.0 +torch >=2.1.0, <2.7.0 tqdm >=4.57.0, <4.67.0 PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2024.4.0 diff --git a/requirements/typing.txt b/requirements/typing.txt index 71414998dd7f3..52f4df899e9d8 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ mypy==1.11.0 -torch==2.5.1 +torch==2.6.0 types-Markdown types-PyYAML From bcc636a1a63b63d059511f1181ee0a2895e4ed4d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 22 Apr 2025 15:32:48 +0200 Subject: [PATCH 02/12] build(deps): update torchvision requirement from <0.21.0,>=0.16.0 to >=0.16.0,<0.22.0 in /requirements (#20736) build(deps): update torchvision requirement in /requirements Updates the requirements on [torchvision](https://github.com/pytorch/vision) to permit the latest version. - [Release notes](https://github.com/pytorch/vision/releases) - [Commits](https://github.com/pytorch/vision/compare/v0.16.0...v0.21.0) --- updated-dependencies: - dependency-name: torchvision dependency-version: 0.21.0 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/fabric/examples.txt | 2 +- requirements/pytorch/examples.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/fabric/examples.txt b/requirements/fabric/examples.txt index 6be089ebb9767..ba68ae5fec613 100644 --- a/requirements/fabric/examples.txt +++ b/requirements/fabric/examples.txt @@ -1,6 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torchvision >=0.16.0, <0.21.0 +torchvision >=0.16.0, <0.22.0 torchmetrics >=0.10.0, <1.8.0 lightning-utilities >=0.8.0, <0.12.0 diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt index 8a19179b813e0..f64eca95c0243 100644 --- a/requirements/pytorch/examples.txt +++ b/requirements/pytorch/examples.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment requests <2.32.0 -torchvision >=0.16.0, <0.21.0 +torchvision >=0.16.0, <0.22.0 ipython[all] <8.19.0 torchmetrics >=0.10.0, <1.8.0 lightning-utilities >=0.8.0, <0.12.0 From 783f905d8998bb4b829a316f805c960eb65c4d73 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 18:28:16 +0200 Subject: [PATCH 03/12] Protocol, ProcessGroup --- src/lightning/fabric/utilities/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py index 1d7235fa36383..deeb465ce8726 100644 --- a/src/lightning/fabric/utilities/types.py +++ b/src/lightning/fabric/utilities/types.py @@ -61,7 +61,7 @@ def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ... @runtime_checkable -class CollectibleGroup(Protocol): +class CollectibleGroup(Protocol, ProcessGroup): def size(self) -> int: ... def rank(self) -> int: ... @@ -81,7 +81,7 @@ def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: @runtime_checkable -class Optimizable(Steppable, Protocol): +class Optimizable(Steppable): """To structurally type ``optimizer``""" param_groups: list[dict[Any, Any]] From e4532fe0fe48dbf8cf92deefd0d3f3b19d234e9e Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 18:40:51 +0200 Subject: [PATCH 04/12] Revert "Protocol, ProcessGroup" This reverts commit 783f905d8998bb4b829a316f805c960eb65c4d73. --- src/lightning/fabric/utilities/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py index deeb465ce8726..1d7235fa36383 100644 --- a/src/lightning/fabric/utilities/types.py +++ b/src/lightning/fabric/utilities/types.py @@ -61,7 +61,7 @@ def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ... @runtime_checkable -class CollectibleGroup(Protocol, ProcessGroup): +class CollectibleGroup(Protocol): def size(self) -> int: ... def rank(self) -> int: ... @@ -81,7 +81,7 @@ def step(self, closure: Optional[Callable[[], float]] = ...) 
-> Optional[float]: @runtime_checkable -class Optimizable(Steppable): +class Optimizable(Steppable, Protocol): """To structurally type ``optimizer``""" param_groups: list[dict[Any, Any]] From 432b0caa46b9824f2e031480d21f57eab32c414d Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 19:05:04 +0200 Subject: [PATCH 05/12] cast(super().group, ProcessGroup) --- .../fabric/plugins/collectives/torch_collective.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py index 81e15a33cb983..bd9db96389ef8 100644 --- a/src/lightning/fabric/plugins/collectives/torch_collective.py +++ b/src/lightning/fabric/plugins/collectives/torch_collective.py @@ -1,6 +1,6 @@ import datetime import os -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast import torch import torch.distributed as dist @@ -8,7 +8,7 @@ from typing_extensions import Self, override from lightning.fabric.plugins.collectives.collective import Collective -from lightning.fabric.utilities.types import CollectibleGroup, RedOpType, ReduceOp +from lightning.fabric.utilities.types import CollectibleGroup, RedOpType, ReduceOp, ProcessGroup if dist.is_available(): from torch.distributed.constants import default_pg_timeout @@ -32,10 +32,10 @@ def __init__(self) -> None: @property @override - def group(self) -> CollectibleGroup: + def group(self) -> ProcessGroup: if self._group is None: self._group = dist.GroupMember.WORLD - return super().group + return cast(super().group, ProcessGroup) @property @override From 40096bf5eb487dde8a10ac22fb8c02ceaeb4534a Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 19:25:36 +0200 Subject: [PATCH 06/12] Revert "cast(super().group, ProcessGroup)" This reverts commit 432b0caa46b9824f2e031480d21f57eab32c414d. 
--- .../fabric/plugins/collectives/torch_collective.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py index bd9db96389ef8..81e15a33cb983 100644 --- a/src/lightning/fabric/plugins/collectives/torch_collective.py +++ b/src/lightning/fabric/plugins/collectives/torch_collective.py @@ -1,6 +1,6 @@ import datetime import os -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union import torch import torch.distributed as dist @@ -8,7 +8,7 @@ from typing_extensions import Self, override from lightning.fabric.plugins.collectives.collective import Collective -from lightning.fabric.utilities.types import CollectibleGroup, RedOpType, ReduceOp, ProcessGroup +from lightning.fabric.utilities.types import CollectibleGroup, RedOpType, ReduceOp if dist.is_available(): from torch.distributed.constants import default_pg_timeout @@ -32,10 +32,10 @@ def __init__(self) -> None: @property @override - def group(self) -> ProcessGroup: + def group(self) -> CollectibleGroup: if self._group is None: self._group = dist.GroupMember.WORLD - return cast(super().group, ProcessGroup) + return super().group @property @override From 30da1d2a338db721151de8657774bc35be57bcab Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 19:29:08 +0200 Subject: [PATCH 07/12] # type: ignore[arg-type] --- .../plugins/collectives/torch_collective.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py index 81e15a33cb983..883380bb881aa 100644 --- a/src/lightning/fabric/plugins/collectives/torch_collective.py +++ b/src/lightning/fabric/plugins/collectives/torch_collective.py @@ -50,7 +50,7 @@ def world_size(self) -> int: @override def broadcast(self, tensor: Tensor, src: int) -> Tensor: - dist.broadcast(tensor, src, group=self.group) + dist.broadcast(tensor, src, group=self.group) # type: ignore[arg-type] return tensor @override @@ -62,7 +62,7 @@ def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum" @override def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor: op = self._convert_to_native_op(op) - dist.reduce(tensor, dst, op=op, group=self.group) + dist.reduce(tensor, dst, op=op, group=self.group) # type: ignore[arg-type] return tensor @override @@ -72,12 +72,12 @@ def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]: @override def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]: - dist.gather(tensor, gather_list, dst, group=self.group) + dist.gather(tensor, gather_list, dst, group=self.group) # type: ignore[arg-type] return gather_list @override def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor: - dist.scatter(tensor, scatter_list, src, group=self.group) + dist.scatter(tensor, scatter_list, src, group=self.group) # type: ignore[arg-type] return tensor @override @@ -109,27 +109,27 @@ def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]: def broadcast_object_list( self, object_list: list[Any], src: int, device: Optional[torch.device] = None ) -> list[Any]: - dist.broadcast_object_list(object_list, src, group=self.group, device=device) + dist.broadcast_object_list(object_list, src, group=self.group, 
device=device) # type: ignore[arg-type] return object_list def gather_object(self, obj: Any, object_gather_list: list[Any], dst: int = 0) -> list[Any]: - dist.gather_object(obj, object_gather_list, dst, group=self.group) + dist.gather_object(obj, object_gather_list, dst, group=self.group) # type: ignore[arg-type] return object_gather_list def scatter_object_list( self, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any], src: int = 0 ) -> list[Any]: - dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group) + dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group) # type: ignore[arg-type] return scatter_object_output_list @override def barrier(self, device_ids: Optional[list[int]] = None) -> None: if self.group == dist.GroupMember.NON_GROUP_MEMBER: return - dist.barrier(group=self.group, device_ids=device_ids) + dist.barrier(group=self.group, device_ids=device_ids) # type: ignore[arg-type] def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None: - dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks) + dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks) # type: ignore[arg-type] @override def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self: From b0fc3feadb27e4b2988ed87e09d31970f0fbc814 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 20:54:10 +0200 Subject: [PATCH 08/12] typing --- src/lightning/fabric/strategies/xla_fsdp.py | 1 + src/lightning/fabric/utilities/init.py | 3 ++- src/lightning/pytorch/callbacks/finetuning.py | 2 +- src/lightning/pytorch/callbacks/throughput_monitor.py | 2 +- src/lightning/pytorch/plugins/precision/double.py | 3 ++- 5 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 935ef72713bcc..87e45293e5e47 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -295,6 +295,7 @@ def clip_gradients_norm( ) -> Tensor: """Clip gradients by norm.""" self.precision.unscale_gradients(optimizer) + assert callable(module.clip_grad_norm_) return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type) @override diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index 2760c6bd227c1..4f8519eec9610 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -67,7 +67,8 @@ def _materialize(module: Module, device: _DEVICE) -> None: f"Materialization requires that the `{type(module).__name__}.reset_parameters` method is implemented." " This method is used to initialize any children parameters or buffers in this module." 
) - module.reset_parameters() + if callable(module.reset_parameters): + module.reset_parameters() def _materialize_meta_tensors(module: Module, device: _DEVICE) -> None: diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py index 356ab221777ae..cec83fee0f4d7 100644 --- a/src/lightning/pytorch/callbacks/finetuning.py +++ b/src/lightning/pytorch/callbacks/finetuning.py @@ -133,7 +133,7 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) - if isinstance(modules, Iterable): _flatten_modules = [] - for m in modules: # type: ignore[union-attr] + for m in modules: _flatten_modules.extend(BaseFinetuning.flatten_modules(m)) _modules = iter(_flatten_modules) diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index a49610a912e57..00d80132e4a98 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -125,7 +125,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, self._lengths[stage] += self.length_fn(batch) if hasattr(pl_module, "flops_per_batch"): - flops_per_batch = pl_module.flops_per_batch + flops_per_batch = int(pl_module.flops_per_batch) else: rank_zero_warn( "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property" diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index efa1aa008a35e..566b99aa19b60 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager -from typing import Any, Literal +from typing import Any, Literal, Iterable import torch import torch.nn as nn @@ -71,6 +71,7 @@ class LightningDoublePrecisionModule(_DeviceDtypeModuleMixin, nn.Module): pl_module: the model to wrap """ + _ddp_params_and_buffers_to_ignore: Iterable[str] def __init__(self, pl_module: "pl.LightningModule") -> None: super().__init__() From 44314cda8300bd7495f97ba5110e4b4dc31c04c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Apr 2025 18:56:01 +0000 Subject: [PATCH 09/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/plugins/precision/double.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 566b99aa19b60..739d8f1d06526 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from collections.abc import Generator +from collections.abc import Generator, Iterable from contextlib import AbstractContextManager, contextmanager -from typing import Any, Literal, Iterable +from typing import Any, Literal import torch import torch.nn as nn @@ -71,6 +71,7 @@ class LightningDoublePrecisionModule(_DeviceDtypeModuleMixin, nn.Module): pl_module: the model to wrap """ + _ddp_params_and_buffers_to_ignore: Iterable[str] def __init__(self, pl_module: "pl.LightningModule") -> None: From 64132fb8c01705ed5aa586b8dc80e6c4618eb893 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 21:36:20 +0200 Subject: [PATCH 10/12] typing --- src/lightning/pytorch/callbacks/throughput_monitor.py | 2 +- src/lightning/pytorch/overrides/distributed.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index 00d80132e4a98..a49610a912e57 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -125,7 +125,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, self._lengths[stage] += self.length_fn(batch) if hasattr(pl_module, "flops_per_batch"): - flops_per_batch = int(pl_module.flops_per_batch) + flops_per_batch = pl_module.flops_per_batch else: rank_zero_warn( "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property" diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index 196008b7ed29f..92d444338ff0f 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -163,9 +163,7 @@ def _register_ddp_comm_hook( def _sync_module_states(module: torch.nn.Module) -> None: """Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682.""" - parameters_to_ignore = ( - set(module._ddp_params_and_buffers_to_ignore) if hasattr(module, "_ddp_params_and_buffers_to_ignore") else set() - ) + parameters_to_ignore = set(getattr(module, "_ddp_params_and_buffers_to_ignore", [])) from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.utils import _sync_module_states as torch_sync_module_states From 41932ec59163200e0d4f9de842a1434d325941aa Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 22:09:29 +0200 Subject: [PATCH 11/12] typing --- src/lightning/pytorch/callbacks/throughput_monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index a49610a912e57..288e6ec7af018 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -140,7 +140,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, # this assumes that all iterations used the same batch size samples=iter_num * batch_size, lengths=None if self.length_fn is None else self._lengths[stage], - flops=flops_per_batch, + flops=int(flops_per_batch) if flops_per_batch is not None else None, ) def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None: From e610b0f2511f646e6ef8e49a0457c8b0723dbef5 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 22 Apr 2025 22:30:30 +0200 Subject: [PATCH 12/12] type: ignore[arg-type] --- 
src/lightning/pytorch/callbacks/throughput_monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index 288e6ec7af018..8b618ae2be912 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -140,7 +140,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, # this assumes that all iterations used the same batch size samples=iter_num * batch_size, lengths=None if self.length_fn is None else self._lengths[stage], - flops=int(flops_per_batch) if flops_per_batch is not None else None, + flops=flops_per_batch, # type: ignore[arg-type] ) def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:
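
(Reviewer sketch, not part of the patch series above.) The typing back-and-forth in patches 03-10 comes down to three idioms: a runtime_checkable Protocol cannot take a concrete class such as ProcessGroup as a base, which is presumably why patch 03 is reverted in patch 04; typing.cast expects the target type as its first argument, so the reverted patch 05 had the arguments to cast() swapped; and the series instead keeps the CollectibleGroup protocol, papering over the ProcessGroup mismatch on the dist.* calls with `# type: ignore[arg-type]` in patch 07, while patch 10 swaps a hasattr branch for getattr(..., default). The stdlib-only Python sketch below restates those idioms; FakeGroup, narrow, and params_to_ignore are made-up names used purely for illustration and do not exist in Lightning or PyTorch.

from typing import Any, Protocol, cast, runtime_checkable


@runtime_checkable
class CollectibleGroup(Protocol):
    # Mirrors lightning.fabric.utilities.types: any object that exposes
    # size() and rank() satisfies this protocol structurally.
    def size(self) -> int: ...

    def rank(self) -> int: ...


class FakeGroup:
    """Illustrative stand-in for a process group (hypothetical name)."""

    def size(self) -> int:
        return 1

    def rank(self) -> int:
        return 0


def narrow(group: Any) -> CollectibleGroup:
    # typing.cast takes the target type first, then the value:
    # cast(CollectibleGroup, group), not cast(group, CollectibleGroup).
    return cast(CollectibleGroup, group)


def params_to_ignore(module: Any) -> set[str]:
    # The getattr(..., default) form from patch 10 replaces the
    # hasattr/if-else branch and stays friendly to the type checker.
    return set(getattr(module, "_ddp_params_and_buffers_to_ignore", []))


if __name__ == "__main__":
    g = narrow(FakeGroup())
    assert isinstance(g, CollectibleGroup)  # allowed because of @runtime_checkable
    print(g.size(), g.rank(), params_to_ignore(object()))

The sketch only illustrates the direction the series settles on: keep the structural CollectibleGroup type for the public surface and localize the ProcessGroup mismatch at the call sites, rather than subclassing or casting.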