 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override

 from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
 from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
@@ -109,6 +110,7 @@ def __init__(
         self._kwargs = kwargs

     @property
+    @override
     def name(self) -> str:
         """Get the name of the experiment.

@@ -119,6 +121,7 @@ def name(self) -> str:
         return self._name

     @property
+    @override
     def version(self) -> Union[int, str]:
         """Get the experiment version.

@@ -131,6 +134,7 @@ def version(self) -> Union[int, str]:
         return self._version

     @property
+    @override
     def root_dir(self) -> str:
         """Gets the save directory where the TensorBoard experiments are saved.

@@ -141,6 +145,7 @@ def root_dir(self) -> str:
         return self._root_dir

     @property
+    @override
     def log_dir(self) -> str:
         """The directory for this run's tensorboard checkpoint.

@@ -191,6 +196,7 @@ def experiment(self) -> "SummaryWriter":
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment

+    @override
     @rank_zero_only
     def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
@@ -212,6 +218,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
                         f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor."
                     ) from ex

+    @override
     @rank_zero_only
     def log_hyperparams(  # type: ignore[override]
         self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None
@@ -251,6 +258,7 @@ def log_hyperparams(  # type: ignore[override]
         writer.add_summary(ssi)
         writer.add_summary(sei)

+    @override
     @rank_zero_only
     def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
         model_example_input = getattr(model, "example_input_array", None)
@@ -278,10 +286,12 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None
         else:
             self.experiment.add_graph(model, input_array)

+    @override
     @rank_zero_only
     def save(self) -> None:
         self.experiment.flush()

+    @override
     @rank_zero_only
     def finalize(self, status: str) -> None:
         if self._experiment is not None:
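The substance of this change is the new `@override` decorator imported from `typing_extensions` (PEP 698). Below is a minimal sketch of what the decorator does, using simplified stand-in classes rather than the real lightning.fabric logger API: at runtime it only tags the function with `__override__ = True`, while a static type checker (mypy or pyright) can flag a method marked as an override that no longer matches anything in the base class.

# Minimal sketch, not part of this PR: illustrates typing_extensions.override
# with simplified stand-in classes (the real Lightning Logger API differs).
from typing_extensions import override


class Logger:
    def log_metrics(self, metrics: dict) -> None:
        print("base:", metrics)


class TensorBoardLogger(Logger):
    @override  # OK: Logger defines log_metrics, so this really is an override
    def log_metrics(self, metrics: dict) -> None:
        print("tensorboard:", metrics)

    # @override                                 # a type checker would reject this:
    # def log_metrcs(self, metrics) -> None:    # the typo means there is nothing to override
    #     ...


if __name__ == "__main__":
    TensorBoardLogger().log_metrics({"loss": 0.1})
    # The decorator performs no runtime check; it only sets a marker attribute.
    assert TensorBoardLogger.log_metrics.__override__ is True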