diff --git a/extensions/thunder/pretrain.py b/extensions/thunder/pretrain.py
index f5a47bb4ff..4b3d6b6bfc 100644
--- a/extensions/thunder/pretrain.py
+++ b/extensions/thunder/pretrain.py
@@ -76,7 +76,7 @@ def setup(
     devices: Union[int, str] = "auto",
     num_nodes: int = 1,
     tokenizer_dir: Optional[Path] = None,
-    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "tensorboard",
+    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "tensorboard",
     seed: int = 42,
     compiler: Optional[Literal["thunder", "torch"]] = "thunder",
     executors: Optional[List[str]] = ("sdpa", "torchcompile", "nvfuser", "torch"),
diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py
index b874231bf0..66f503208f 100644
--- a/litgpt/finetune/adapter.py
+++ b/litgpt/finetune/adapter.py
@@ -64,7 +64,7 @@ def setup(
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
     log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
-    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
+    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
     seed: int = 1337,
     access_token: Optional[str] = None,
 ) -> None:
diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index 8ec25c905c..d8646e8cbf 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -66,7 +66,7 @@ def setup(
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
     log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
-    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
+    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
     seed: int = 1337,
     access_token: Optional[str] = None,
 ) -> None:
diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index eec16bcde2..a6587e43a7 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -60,7 +60,7 @@ def setup(
     eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
     log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
-    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
+    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
     seed: int = 1337,
     access_token: Optional[str] = None,
 ) -> None:
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index eee57cff78..a92465bfe8 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -75,7 +75,7 @@ def setup(
     log: LogArgs = LogArgs(),
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
     optimizer: Union[str, Dict] = "AdamW",
-    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
+    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
     seed: int = 1337,
     access_token: Optional[str] = None,
 ) -> None:
diff --git a/litgpt/finetune/lora_legacy.py b/litgpt/finetune/lora_legacy.py
index 6575bc10db..6955b87f5c 100644
--- a/litgpt/finetune/lora_legacy.py
+++ b/litgpt/finetune/lora_legacy.py
@@ -74,7 +74,7 @@ def setup(
     log: LogArgs = LogArgs(),
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
     optimizer: Union[str, Dict] = "AdamW",
-    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
+    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "csv",
     seed: int = 1337,
     access_token: Optional[str] = None,
 ) -> None:
diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py
index e61b1494e7..d1524c11ef 100644
--- a/litgpt/pretrain.py
+++ b/litgpt/pretrain.py
@@ -70,7 +70,7 @@ def setup(
     devices: Union[int, str] = "auto",
     num_nodes: int = 1,
     tokenizer_dir: Optional[Path] = None,
-    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "tensorboard",
+    logger_name: Literal["wandb", "tensorboard", "csv", "mlflow", "swanlab"] = "tensorboard",
     seed: int = 42,
 ):
     """Pretrain a model.
diff --git a/litgpt/utils.py b/litgpt/utils.py
index 6a175d2c98..4239b8744d 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -28,7 +28,7 @@
 from lightning.fabric.strategies import FSDPStrategy, ModelParallelStrategy
 from lightning.fabric.utilities.load import _lazy_load as lazy_load
 from lightning.pytorch.cli import instantiate_class
-from lightning.pytorch.loggers import MLFlowLogger, WandbLogger
+from lightning.pytorch.loggers import MLFlowLogger, SwanLabLogger, WandbLogger
 from lightning_utilities.core.imports import RequirementCache, module_available
 from packaging import version
 from torch.serialization import normalize_storage_type
@@ -567,7 +567,7 @@ def parse_devices(devices: Union[str, int]) -> int:


 def choose_logger(
-    logger_name: Literal["csv", "tensorboard", "wandb", "mlflow"],
+    logger_name: Literal["csv", "tensorboard", "wandb", "mlflow", "swanlab"],
     out_dir: Path,
     name: str,
     log_interval: int = 1,
@@ -586,7 +586,29 @@ def choose_logger(
         return WandbLogger(project=project, name=run, group=group, resume=resume, **kwargs)
     if logger_name == "mlflow":
         return MLFlowLogger(experiment_name=name, **kwargs)
-    raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
+    if logger_name == "swanlab":
+        project = log_args.pop("project", name) if log_args else name
+        experiment_name = log_args.pop("run", None) if log_args else None
+        description = log_args.pop("description", None) if log_args else None
+        logdir = (
+            log_args.pop("logdir", str(out_dir / "logs" / "swanlab")) if log_args else str(out_dir / "logs" / "swanlab")
+        )
+        mode = log_args.pop("mode", "cloud") if log_args else "cloud"
+        config = log_args.pop("config", None) if log_args else None
+        version = log_args.pop("version", None) if log_args else None
+        return SwanLabLogger(
+            project=project,
+            experiment_name=experiment_name,
+            description=description,
+            logdir=logdir,
+            mode=mode,
+            version=version if resume else None,
+            config=config,
+            **kwargs,
+        )
+    raise ValueError(
+        f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb', 'mlflow', 'swanlab'."
+    )


 def get_argument_names(cls):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3a04c6ed57..f8c7dc94c9 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -23,6 +23,7 @@
 from litgpt.utils import (
     CLI,
     CycleIterator,
+    SwanLabLogger,
     _RunIf,
     capture_hparams,
     check_file_size_on_cpu_and_warn,
@@ -298,6 +299,13 @@ def test_choose_logger(tmp_path):
     assert isinstance(choose_logger("wandb", out_dir=tmp_path, name="wandb"), WandbLogger)
     if RequirementCache("mlflow") or RequirementCache("mlflow-skinny"):
         assert isinstance(choose_logger("mlflow", out_dir=tmp_path, name="wandb"), MLFlowLogger)
+    if RequirementCache("swanlab"):
+        assert isinstance(
+            choose_logger(
+                "swanlab", out_dir=tmp_path, name="swanlab", log_args={"project": "test", "mode": "disabled"}
+            ),
+            SwanLabLogger,
+        )
     with pytest.raises(ValueError, match="`--logger_name=foo` is not a valid option."):
         choose_logger("foo", out_dir=tmp_path, name="foo")
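
For reference, here is a minimal sketch of exercising the new branch once the patch is applied. It is illustrative only: it assumes the swanlab package is installed, that choose_logger accepts the same log_args dict used in the test above, and the project name, run name, output directory, and logged values are hypothetical.

# Illustrative sketch; assumes litgpt with this patch plus the swanlab package.
from pathlib import Path

from litgpt.utils import choose_logger

out_dir = Path("out/pretrain")  # hypothetical output directory
logger = choose_logger(
    "swanlab",
    out_dir=out_dir,
    name="tiny-llama",  # fallback SwanLab project name when log_args has no "project"
    log_args={
        "project": "litgpt-demo",  # popped into SwanLabLogger(project=...)
        "run": "run-01",           # popped into SwanLabLogger(experiment_name=...)
        "mode": "disabled",        # same mode the test uses; avoids contacting the SwanLab service
    },
)

# Standard Lightning logger interface, shared by the other choose_logger backends.
logger.log_hyperparams({"learning_rate": 4e-4, "micro_batch_size": 4})
logger.log_metrics({"train/loss": 2.31}, step=100)

From the litgpt CLI, the same path should be reachable with --logger_name swanlab, since every setup() entry point patched above now accepts "swanlab" in its logger_name literal.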