Skip to content

Commit 896dbf7

Browse files
committed
fix type hints in XLA, SLURM, and Comet logger modules for improved clarity
1 parent df3eb79 commit 896dbf7

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

src/lightning/fabric/accelerators/xla.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import functools
15+
import warnings
1516
from typing import Any
1617

1718
import torch

src/lightning/fabric/plugins/environments/slurm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def auto_requeue(self) -> bool:
5656
return self.slurm_impl.auto_requeue
5757

5858
@property
59-
def requeue_signal(self) -> Optional[signal.Signals]:
59+
def requeue_signal(self) -> signal.Signals | None:
6060
return self.slurm_impl.requeue_signal
6161

6262
@property
@@ -93,7 +93,7 @@ def job_name() -> str | None:
9393
return os.environ.get("SLURM_JOB_NAME")
9494

9595
@staticmethod
96-
def job_id() -> Optional[int]:
96+
def job_id() -> int | None:
9797
_raise_enterprise_not_available()
9898
from pytorch_lightning_enterprise.plugins.environments.slurm import (
9999
SLURMEnvironment as EnterpriseSLURMEnvironment,

src/lightning/pytorch/loggers/comet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None:
242242

243243
@override
244244
@rank_zero_only
245-
def log_metrics(self, metrics: Mapping[str, Tensor | float], step: Optional[int] = None) -> None:
245+
def log_metrics(self, metrics: Mapping[str, Tensor | float], step: int | None = None) -> None:
246246
return self.logger_impl.log_metrics(metrics, step)
247247

248248
@override
@@ -287,5 +287,5 @@ def version(self) -> str | None:
287287
return self.logger_impl.version
288288

289289
@override
290-
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
290+
def log_graph(self, model: Module, input_array: Tensor | None = None) -> None:
291291
return self.logger_impl.log_graph(model, input_array)

0 commit comments

Comments
 (0)