
Commit 0b15d09

Set weights_only=False when loading ckpts, since Lightning now defers to torch's default (True)
See PR on this change: Lightning-AI/pytorch-lightning#21072
1 parent e014551 commit 0b15d09
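
For context: recent PyTorch releases changed the default of torch.load to weights_only=True, which uses a restricted unpickler that only accepts tensors and a small allowlist of types, and Lightning now forwards that default when restoring checkpoints. A full Lightning checkpoint pickles more than raw weights (hyperparameters, callback and loop state), so the restricted loader can reject it; passing weights_only=False keeps the previous behaviour for checkpoints produced by this template. A minimal stand-alone sketch of the difference, using an illustrative checkpoint path that is not taken from this repo:

import torch

ckpt_path = "logs/train/runs/last.ckpt"  # hypothetical path, for illustration only

try:
    # The restricted loader may reject Lightning checkpoints that pickle
    # non-tensor objects such as hyperparameters or callback state.
    ckpt = torch.load(ckpt_path, weights_only=True)
except Exception:
    # Full unpickling; only use this for checkpoints you trust, since
    # unpickling can execute arbitrary code.
    ckpt = torch.load(ckpt_path, weights_only=False)

print(list(ckpt))  # a Lightning ckpt typically contains 'state_dict', 'hyper_parameters', ...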

2 files changed: +3 additions, -3 deletions


src/lightning_hydra_template/eval.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
         log_hyperparameters(object_dict)
 
     log.info("Starting testing!")
-    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path, weights_only=False)
 
     # for predictions use trainer.predict(...)
     # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

src/lightning_hydra_template/train.py

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
 
     if cfg.get("train"):
         log.info("Starting training!")
-        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"), weights_only=False)
 
     train_metrics = trainer.callback_metrics
 
@@ -84,7 +84,7 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
         if ckpt_path == "":
             log.warning("Best ckpt not found! Using current weights for testing...")
             ckpt_path = None
-        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path, weights_only=False)
         log.info(f"Best ckpt path: {ckpt_path}")
 
     test_metrics = trainer.callback_metrics
