|
3 | 3 | import os |
4 | 4 | from datetime import UTC, datetime |
5 | 5 | from functools import partial |
| 6 | +from pathlib import Path |
6 | 7 | from typing import Annotated, Any, Literal |
7 | 8 |
|
8 | | -import fsspec |
9 | 9 | import lightning as L |
10 | 10 | import numpy as np |
11 | 11 | import torch |
@@ -513,9 +513,12 @@ def main( |
513 | 513 | "--cosine_anneal", help="Use cosine annealing for learning rate scheduling" |
514 | 514 | ), |
515 | 515 | ] = False, |
516 | | - gcs_bucket: Annotated[ |
517 | | - str, typer.Option(help="GCS bucket name (e.g. for checkpoints)") |
518 | | - ] = "cadu-distill", |
| 516 | + checkpoint_dirpath: Annotated[ |
| 517 | + str | None, |
| 518 | + typer.Option( |
| 519 | + help="Path to the checkpoint directory, this can be local or ffspec compatible path" |
| 520 | + ), |
| 521 | + ] = None, |
519 | 522 | ) -> None: |
520 | 523 | L.seed_everything(42, workers=True) |
521 | 524 |
|
@@ -556,19 +559,14 @@ def main( |
556 | 559 | if not no_wandb: |
557 | 560 | wandb_logger = WandbLogger(project=project_name, name=full_run_name) |
558 | 561 |
|
559 | | - try: |
560 | | - from gcsfs import GCSFileSystem |
561 | | - |
562 | | - fs: GCSFileSystem = fsspec.filesystem("gs") |
563 | | - assert len(fs.info(f"gs://{gcs_bucket}")) > 0 |
564 | | - checkpoint_dirpath = f"gs://{gcs_bucket}/checkpoints/{full_run_name}/" |
565 | | - except Exception: |
566 | | - logger.exception( |
567 | | - "Failed to probe GCS, will use local filesystem for checkpoints." |
568 | | - ) |
| 562 | + if checkpoint_dirpath is not None: |
| 563 | + checkpoint_dirpath = f"{checkpoint_dirpath}/{full_run_name}/" |
| 564 | + else: |
569 | 565 | checkpoint_dirpath = f"checkpoints/{full_run_name}/" |
| 566 | + Path(checkpoint_dirpath).mkdir(parents=True, exist_ok=True) |
| 567 | + |
| 568 | + logger.info(f"Using checkpoint directory: {checkpoint_dirpath}") |
570 | 569 |
|
571 | | - logger.info(f"Checkpoint directory: {checkpoint_dirpath}") |
572 | 570 | checkpoint_callback = ModelCheckpoint( |
573 | 571 | dirpath=checkpoint_dirpath, |
574 | 572 | filename="student-caduceus__epoch={epoch:02d}__val_loss_total={val/loss/total:.3f}__step={step}", |
|
0 commit comments