
Commit 9cb6e8c

Add @override for files in src/lightning/fabric/loggers (#19090)
1 parent 32f5ddd commit 9cb6e8c

2 files changed: 19 additions, 0 deletions
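
For context: `typing_extensions.override` (PEP 698) marks a method as an intentional override of a base-class method, so a static type checker can report the subclass when the base method is renamed or removed. A minimal sketch of the idea, using hypothetical class names rather than the Fabric loggers touched in this commit:

```python
from typing_extensions import override


class BaseLogger:
    def save(self) -> None:
        """Flush any buffered records."""


class FileLogger(BaseLogger):
    @override
    def save(self) -> None:  # OK: `save` exists on BaseLogger
        print("flushing to disk")

    @override
    def finalise(self) -> None:  # a type checker flags this: BaseLogger has no `finalise`
        print("closing")
```

At runtime the decorator does nothing beyond tagging the function, so adding it does not change behavior; the benefit is that refactors of the `Logger` base class surface as type-check errors in every subclass.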

src/lightning/fabric/loggers/csv_logs.py

Lines changed: 9 additions & 0 deletions
@@ -19,6 +19,7 @@
 from typing import Any, Dict, List, Optional, Set, Union
 
 from torch import Tensor
+from typing_extensions import override
 
 from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
 from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
@@ -73,6 +74,7 @@ def __init__(
         self._flush_logs_every_n_steps = flush_logs_every_n_steps
 
     @property
+    @override
     def name(self) -> str:
         """Gets the name of the experiment.
 
@@ -83,6 +85,7 @@ def name(self) -> str:
         return self._name
 
     @property
+    @override
     def version(self) -> Union[int, str]:
         """Gets the version of the experiment.
 
@@ -95,11 +98,13 @@ def version(self) -> Union[int, str]:
         return self._version
 
     @property
+    @override
     def root_dir(self) -> str:
         """Gets the save directory where the versioned CSV experiments are saved."""
         return self._root_dir
 
     @property
+    @override
     def log_dir(self) -> str:
         """The log directory for this run.
 
@@ -128,10 +133,12 @@ def experiment(self) -> "_ExperimentWriter":
         self._experiment = _ExperimentWriter(log_dir=self.log_dir)
         return self._experiment
 
+    @override
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
         raise NotImplementedError("The `CSVLogger` does not yet support logging hyperparameters.")
 
+    @override
     @rank_zero_only
     def log_metrics(  # type: ignore[override]
         self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
@@ -143,11 +150,13 @@ def log_metrics(  # type: ignore[override]
         if (step + 1) % self._flush_logs_every_n_steps == 0:
             self.save()
 
+    @override
    @rank_zero_only
     def save(self) -> None:
         super().save()
         self.experiment.save()
 
+    @override
     @rank_zero_only
     def finalize(self, status: str) -> None:
         if self._experiment is None:
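
The hunks above stack `@override` with `@property` (decorators apply bottom-up, so `override` tags the raw function before `property` wraps it) and place it outermost above `@rank_zero_only`. A standalone sketch of that stacking, with a simplified stand-in for the rank guard rather than Lightning's `rank_zero_only`:

```python
from typing import Any, Callable

from typing_extensions import override


def rank_guard(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Simplified stand-in for a decorator like `rank_zero_only`; it just calls through.
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)

    return wrapped


class BaseLogger:
    @property
    def name(self) -> str:
        return "base"

    def save(self) -> None:
        """Flush buffered records."""


class FileLogger(BaseLogger):
    @property
    @override  # applied to the raw function first, then wrapped by `property`
    def name(self) -> str:
        return "file"

    @override  # outermost, matching the diff: `@override` above the rank guard
    @rank_guard
    def save(self) -> None:
        print("saved")


logger = FileLogger()
print(logger.name)  # -> file
logger.save()       # -> saved
```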

src/lightning/fabric/loggers/tensorboard.py

Lines changed: 10 additions & 0 deletions
@@ -20,6 +20,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override
 
 from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
 from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
@@ -109,6 +110,7 @@ def __init__(
         self._kwargs = kwargs
 
     @property
+    @override
     def name(self) -> str:
         """Get the name of the experiment.
 
@@ -119,6 +121,7 @@ def name(self) -> str:
         return self._name
 
     @property
+    @override
     def version(self) -> Union[int, str]:
         """Get the experiment version.
 
@@ -131,6 +134,7 @@ def version(self) -> Union[int, str]:
         return self._version
 
     @property
+    @override
     def root_dir(self) -> str:
         """Gets the save directory where the TensorBoard experiments are saved.
 
@@ -141,6 +145,7 @@ def root_dir(self) -> str:
         return self._root_dir
 
     @property
+    @override
     def log_dir(self) -> str:
         """The directory for this run's tensorboard checkpoint.
 
@@ -191,6 +196,7 @@ def experiment(self) -> "SummaryWriter":
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
 
+    @override
     @rank_zero_only
     def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
@@ -212,6 +218,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
                        f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor."
                    ) from ex
 
+    @override
     @rank_zero_only
     def log_hyperparams(  # type: ignore[override]
         self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None
@@ -251,6 +258,7 @@ def log_hyperparams(  # type: ignore[override]
         writer.add_summary(ssi)
         writer.add_summary(sei)
 
+    @override
     @rank_zero_only
     def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
         model_example_input = getattr(model, "example_input_array", None)
@@ -278,10 +286,12 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None
         else:
             self.experiment.add_graph(model, input_array)
 
+    @override
     @rank_zero_only
     def save(self) -> None:
         self.experiment.flush()
 
+    @override
     @rank_zero_only
     def finalize(self, status: str) -> None:
         if self._experiment is not None:
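
At runtime, `typing_extensions.override` only attempts to set an `__override__` attribute on the decorated callable (and stays silent if it cannot), so none of these changes alter how the loggers behave. A small check, assuming a reasonably recent `typing_extensions`:

```python
from typing_extensions import override


class Base:
    def finalize(self, status: str) -> None:
        ...


class Child(Base):
    @override
    def finalize(self, status: str) -> None:
        ...


# The marker can be inspected, e.g. in tests that assert overrides are declared explicitly.
print(getattr(Child.finalize, "__override__", False))  # True with a recent typing_extensions
```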
