Skip to content

Commit dd76b26

Browse files
committed
Control checkpoint location
1 parent a437c6e commit dd76b26

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

caduceus_distill/distill.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import os
44
from datetime import UTC, datetime
55
from functools import partial
6+
from pathlib import Path
67
from typing import Annotated, Any, Literal
78

8-
import fsspec
99
import lightning as L
1010
import numpy as np
1111
import torch
@@ -513,9 +513,12 @@ def main(
513513
"--cosine_anneal", help="Use cosine annealing for learning rate scheduling"
514514
),
515515
] = False,
516-
gcs_bucket: Annotated[
517-
str, typer.Option(help="GCS bucket name (e.g. for checkpoints)")
518-
] = "cadu-distill",
516+
checkpoint_dirpath: Annotated[
517+
str | None,
518+
typer.Option(
519+
help="Path to the checkpoint directory; this can be a local or fsspec-compatible path"
520+
),
521+
] = None,
519522
) -> None:
520523
L.seed_everything(42, workers=True)
521524

@@ -556,19 +559,14 @@ def main(
556559
if not no_wandb:
557560
wandb_logger = WandbLogger(project=project_name, name=full_run_name)
558561

559-
try:
560-
from gcsfs import GCSFileSystem
561-
562-
fs: GCSFileSystem = fsspec.filesystem("gs")
563-
assert len(fs.info(f"gs://{gcs_bucket}")) > 0
564-
checkpoint_dirpath = f"gs://{gcs_bucket}/checkpoints/{full_run_name}/"
565-
except Exception:
566-
logger.exception(
567-
"Failed to probe GCS, will use local filesystem for checkpoints."
568-
)
562+
if checkpoint_dirpath is not None:
563+
checkpoint_dirpath = f"{checkpoint_dirpath}/{full_run_name}/"
564+
else:
569565
checkpoint_dirpath = f"checkpoints/{full_run_name}/"
566+
Path(checkpoint_dirpath).mkdir(parents=True, exist_ok=True)
567+
568+
logger.info(f"Using checkpoint directory: {checkpoint_dirpath}")
570569

571-
logger.info(f"Checkpoint directory: {checkpoint_dirpath}")
572570
checkpoint_callback = ModelCheckpoint(
573571
dirpath=checkpoint_dirpath,
574572
filename="student-caduceus__epoch={epoch:02d}__val_loss_total={val/loss/total:.3f}__step={step}",

0 commit comments

Comments
 (0)