Commit 119039b

Add @override for files in src/lightning/fabric/plugins/io (#19157)
1 parent 5c36e99 commit 119039b
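
For context: `@override` (PEP 698; provided by `typing_extensions` for Python versions before 3.12) marks a method as intentionally overriding a base-class method. It has essentially no runtime effect; its value is that static type checkers such as mypy or pyright report an error if the decorated method no longer matches anything in the base class, e.g. after a rename. A minimal sketch of that behavior, using made-up class names that are not part of this commit:

    from typing_extensions import override

    class CheckpointBase:  # hypothetical base class, for illustration only
        def save_checkpoint(self, path: str) -> None:
            print(f"saving to {path}")

    class MyCheckpointIO(CheckpointBase):
        @override
        def save_checkpoint(self, path: str) -> None:  # OK: overrides a real base method
            super().save_checkpoint(path)

        @override
        def save_checkpont(self, path: str) -> None:  # typo: no matching base method,
            ...                                       # so mypy/pyright flag this line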

2 files changed: +7 -0 lines changed

src/lightning/fabric/plugins/io/torch_io.py

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,8 @@
 import os
 from typing import Any, Callable, Dict, Optional
 
+from typing_extensions import override
+
 from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.utilities.cloud_io import _atomic_save, get_filesystem
 from lightning.fabric.utilities.cloud_io import _load as pl_load
@@ -31,6 +33,7 @@ class TorchCheckpointIO(CheckpointIO):
 
     """
 
+    @override
     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
@@ -54,6 +57,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         fs.makedirs(os.path.dirname(path), exist_ok=True)
         _atomic_save(checkpoint, path)
 
+    @override
     def load_checkpoint(
         self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
     ) -> Dict[str, Any]:
@@ -78,6 +82,7 @@ def load_checkpoint(
 
         return pl_load(path, map_location=map_location)
 
+    @override
     def remove_checkpoint(self, path: _PATH) -> None:
         """Remove checkpoint file from the filesystem.
 
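The three decorated methods above are the whole `CheckpointIO` surface this plugin implements. A usage sketch based on the signatures visible in the diff, assuming `lightning` is installed and using an illustrative path:

    from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO

    io = TorchCheckpointIO()
    io.save_checkpoint({"state_dict": {}}, "checkpoints/demo.ckpt")  # path is made up
    ckpt = io.load_checkpoint("checkpoints/demo.ckpt")
    io.remove_checkpoint("checkpoints/demo.ckpt")
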
src/lightning/fabric/plugins/io/xla.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
 from lightning_utilities.core.imports import RequirementCache
+from typing_extensions import override
 
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
@@ -39,6 +40,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    @override
     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
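One note on safety: per PEP 698, the decorator only attempts to set an `__override__ = True` attribute on the function and returns it unchanged, so this commit does not alter how checkpoints are saved, loaded, or removed; the checks happen purely in static analysis. A quick way to confirm, assuming `typing_extensions` is installed:

    from typing_extensions import override

    class Base:
        def run(self) -> None: ...

    class Sub(Base):
        @override
        def run(self) -> None: ...

    print(getattr(Sub.run, "__override__", False))  # True: the only runtime effect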