Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extensions/thunder/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def setup(
devices: Union[int, str] = "auto",
num_nodes: int = 1,
tokenizer_dir: Optional[Path] = None,
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "tensorboard",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "tensorboard",
seed: int = 42,
compiler: Optional[Literal["thunder", "torch"]] = "thunder",
executors: Optional[List[str]] = ("sdpa", "torchcompile", "nvfuser", "torch"),
Expand Down
2 changes: 1 addition & 1 deletion litgpt/finetune/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def setup(
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
log: LogArgs = LogArgs(),
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litgpt/finetune/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def setup(
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
log: LogArgs = LogArgs(),
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litgpt/finetune/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def setup(
eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
log: LogArgs = LogArgs(),
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litgpt/finetune/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def setup(
log: LogArgs = LogArgs(),
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litgpt/finetune/lora_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def setup(
log: LogArgs = LogArgs(),
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litgpt/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def setup(
devices: Union[int, str] = "auto",
num_nodes: int = 1,
tokenizer_dir: Optional[Path] = None,
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "tensorboard",
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "tensorboard",
seed: int = 42,
):
"""Pretrain a model.
Expand Down
28 changes: 25 additions & 3 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from lightning.fabric.strategies import FSDPStrategy, ModelParallelStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from lightning.pytorch.cli import instantiate_class
from lightning.pytorch.loggers import MLFlowLogger, WandbLogger
from lightning.pytorch.loggers import MLFlowLogger, SwanLabLogger, WandbLogger
from lightning_utilities.core.imports import RequirementCache, module_available
from packaging import version
from torch.serialization import normalize_storage_type
Expand Down Expand Up @@ -567,7 +567,7 @@ def parse_devices(devices: Union[str, int]) -> int:


def choose_logger(
logger_name: Literal["csv", "tensorboard", "wandb", "mlflow"],
logger_name: Literal["csv", "tensorboard", "wandb", "mlflow", "swanlab"],
out_dir: Path,
name: str,
log_interval: int = 1,
Expand All @@ -586,7 +586,29 @@ def choose_logger(
return WandbLogger(project=project, name=run, group=group, resume=resume, **kwargs)
if logger_name == "mlflow":
return MLFlowLogger(experiment_name=name, **kwargs)
raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
if logger_name == "swanlab":
project = log_args.pop("project", name) if log_args else name
experiment_name = log_args.pop("run", None) if log_args else None
description = log_args.pop("description", None) if log_args else None
logdir = (
log_args.pop("logdir", str(out_dir / "logs" / "swanlab")) if log_args else str(out_dir / "logs" / "swanlab")
)
mode = log_args.pop("mode", "cloud") if log_args else "cloud"
config = log_args.pop("config", None) if log_args else None
version = log_args.pop("version", None) if log_args else None
return SwanLabLogger(
project=project,
experiment_name=experiment_name,
description=description,
logdir=logdir,
mode=mode,
version=version if resume else None,
config=config,
**kwargs,
)
raise ValueError(
f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb', 'mlflow', 'swanlab'."
)


def get_argument_names(cls):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from litgpt.utils import (
CLI,
CycleIterator,
SwanLabLogger,
_RunIf,
capture_hparams,
check_file_size_on_cpu_and_warn,
Expand Down Expand Up @@ -298,6 +299,15 @@ def test_choose_logger(tmp_path):
assert isinstance(choose_logger("wandb", out_dir=tmp_path, name="wandb"), WandbLogger)
if RequirementCache("mlflow") or RequirementCache("mlflow-skinny"):
assert isinstance(choose_logger("mlflow", out_dir=tmp_path, name="wandb"), MLFlowLogger)
if RequirementCache("swanlab"):
assert isinstance(
choose_logger(
"swanlab", out_dir=tmp_path, name="swanlab", log_args={"project": "test", "mode": "disabled"}
),
SwanLabLogger,
)
with pytest.raises(ValueError, match="`--logger_name=foo` is not a valid option."):
choose_logger("foo", out_dir=tmp_path, name="foo")
with pytest.raises(ValueError, match="`--logger_name=foo` is not a valid option."):
choose_logger("foo", out_dir=tmp_path, name="foo")

Expand Down
Loading