Skip to content

Commit bcc8de8

Browse files
authored
Update Trainer's ckpt_path type for pathlib Path (#19362)
1 parent b0e1ee2 commit bcc8de8

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/lightning/pytorch/trainer/trainer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def fit(
506506
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
507507
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
508508
datamodule: Optional[LightningDataModule] = None,
509-
ckpt_path: Optional[str] = None,
509+
ckpt_path: Optional[_PATH] = None,
510510
) -> None:
511511
r"""Runs the full optimization routine.
512512
@@ -550,7 +550,7 @@ def _fit_impl(
550550
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
551551
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
552552
datamodule: Optional[LightningDataModule] = None,
553-
ckpt_path: Optional[str] = None,
553+
ckpt_path: Optional[_PATH] = None,
554554
) -> None:
555555
log.debug(f"{self.__class__.__name__}: trainer fit stage")
556556

@@ -586,7 +586,7 @@ def validate(
586586
self,
587587
model: Optional["pl.LightningModule"] = None,
588588
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
589-
ckpt_path: Optional[str] = None,
589+
ckpt_path: Optional[_PATH] = None,
590590
verbose: bool = True,
591591
datamodule: Optional[LightningDataModule] = None,
592592
) -> _EVALUATE_OUTPUT:
@@ -649,7 +649,7 @@ def _validate_impl(
649649
self,
650650
model: Optional["pl.LightningModule"] = None,
651651
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
652-
ckpt_path: Optional[str] = None,
652+
ckpt_path: Optional[_PATH] = None,
653653
verbose: bool = True,
654654
datamodule: Optional[LightningDataModule] = None,
655655
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
@@ -694,7 +694,7 @@ def test(
694694
self,
695695
model: Optional["pl.LightningModule"] = None,
696696
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
697-
ckpt_path: Optional[str] = None,
697+
ckpt_path: Optional[_PATH] = None,
698698
verbose: bool = True,
699699
datamodule: Optional[LightningDataModule] = None,
700700
) -> _EVALUATE_OUTPUT:
@@ -758,7 +758,7 @@ def _test_impl(
758758
self,
759759
model: Optional["pl.LightningModule"] = None,
760760
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
761-
ckpt_path: Optional[str] = None,
761+
ckpt_path: Optional[_PATH] = None,
762762
verbose: bool = True,
763763
datamodule: Optional[LightningDataModule] = None,
764764
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
@@ -805,7 +805,7 @@ def predict(
805805
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
806806
datamodule: Optional[LightningDataModule] = None,
807807
return_predictions: Optional[bool] = None,
808-
ckpt_path: Optional[str] = None,
808+
ckpt_path: Optional[_PATH] = None,
809809
) -> Optional[_PREDICT_OUTPUT]:
810810
r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to
811811
perform distributed and batched predictions. Logging is disabled in the predict hooks.
@@ -870,7 +870,7 @@ def _predict_impl(
870870
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
871871
datamodule: Optional[LightningDataModule] = None,
872872
return_predictions: Optional[bool] = None,
873-
ckpt_path: Optional[str] = None,
873+
ckpt_path: Optional[_PATH] = None,
874874
) -> Optional[_PREDICT_OUTPUT]:
875875
# --------------------
876876
# SETUP HOOK

0 commit comments

Comments
 (0)