Skip to content

Commit 21f4a22

Browse files
Update dependency lightning to >=2.6,<2.7 (#76)
Set `weights_only=False` when loading ckpts, since Lightning now defers to torch's default (`True`) * See PR on this change: Lightning-AI/pytorch-lightning#21072 --------- Co-authored-by: Nathan Painchaud <[email protected]>
1 parent 162695e commit 21f4a22

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ dependencies = [
99
"hydra-colorlog>=1.2.0,<2",
1010
"hydra-core>=1.3.2,<2",
1111
"hydra-optuna-sweeper>=1.2.0,<2",
12-
"lightning>=2.5.0.post0,<2.6",
12+
"lightning>=2.6,<2.7",
1313
"rich>=14,<15",
1414
"rootutils>=1.0.7,<2",
1515
"torchmetrics>=1.8,<1.9",

src/lightning_hydra_template/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
5757
log_hyperparameters(object_dict)
5858

5959
log.info("Starting testing!")
60-
trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
60+
trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path, weights_only=False)
6161

6262
# for predictions use trainer.predict(...)
6363
# predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

src/lightning_hydra_template/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
7474

7575
if cfg.get("train"):
7676
log.info("Starting training!")
77-
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
77+
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"), weights_only=False)
7878

7979
train_metrics = trainer.callback_metrics
8080

@@ -84,7 +84,7 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
8484
if ckpt_path == "":
8585
log.warning("Best ckpt not found! Using current weights for testing...")
8686
ckpt_path = None
87-
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
87+
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path, weights_only=False)
8888
log.info(f"Best ckpt path: {ckpt_path}")
8989

9090
test_metrics = trainer.callback_metrics

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)