Skip to content

Commit 1b70032

Browse files
topikachugongypre-commit-ci[bot]
authored
Add mlflow logger support (#1985)
Co-authored-by: gongy <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 07def67 commit 1b70032

File tree

8 files changed

+15
-12
lines changed

8 files changed

+15
-12
lines changed

extensions/thunder/pretrain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def setup(
7474
devices: Union[int, str] = "auto",
7575
num_nodes: int = 1,
7676
tokenizer_dir: Optional[Path] = None,
77-
logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
77+
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "tensorboard",
7878
seed: int = 42,
7979
compiler: Optional[Literal["thunder", "torch"]] = "thunder",
8080
executors: Optional[List[str]] = ("sdpa", "torchcompile", "nvfuser", "torch"),
@@ -156,7 +156,7 @@ def setup(
156156
)
157157

158158
fabric.print(pprint.pformat(hparams))
159-
if logger_name in ("tensorboard", "wandb"):
159+
if logger_name in ("tensorboard", "wandb", "mlflow"):
160160
fabric.logger.log_hyperparams(hparams)
161161

162162
main(

litgpt/finetune/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def setup(
6363
),
6464
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
6565
optimizer: Union[str, Dict] = "AdamW",
66-
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
66+
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
6767
seed: int = 1337,
6868
access_token: Optional[str] = None,
6969
) -> None:

litgpt/finetune/adapter_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def setup(
6565
),
6666
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
6767
optimizer: Union[str, Dict] = "AdamW",
68-
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
68+
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
6969
seed: int = 1337,
7070
access_token: Optional[str] = None,
7171
) -> None:

litgpt/finetune/full.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def setup(
5959
),
6060
eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
6161
optimizer: Union[str, Dict] = "AdamW",
62-
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
62+
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
6363
seed: int = 1337,
6464
access_token: Optional[str] = None,
6565
) -> None:

litgpt/finetune/lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def setup(
7373
),
7474
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
7575
optimizer: Union[str, Dict] = "AdamW",
76-
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
76+
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
7777
seed: int = 1337,
7878
access_token: Optional[str] = None,
7979
) -> None:

litgpt/pretrain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def setup(
6666
devices: Union[int, str] = "auto",
6767
num_nodes: int = 1,
6868
tokenizer_dir: Optional[Path] = None,
69-
logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
69+
logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "tensorboard",
7070
seed: int = 42,
7171
):
7272
"""Pretrain a model.
@@ -143,7 +143,7 @@ def setup(
143143
fabric.launch()
144144

145145
fabric.print(pprint.pformat(hparams))
146-
if logger_name in ("tensorboard", "wandb"):
146+
if logger_name in ("tensorboard", "wandb", "mlflow"):
147147
fabric.logger.log_hyperparams(hparams)
148148

149149
main(

litgpt/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from lightning.fabric.strategies import FSDPStrategy
2828
from lightning.fabric.utilities.load import _lazy_load as lazy_load
2929
from lightning.pytorch.cli import instantiate_class
30-
from lightning.pytorch.loggers import WandbLogger
30+
from lightning.pytorch.loggers import MLFlowLogger, WandbLogger
3131
from lightning_utilities.core.imports import module_available
3232
from packaging import version
3333
from torch.serialization import normalize_storage_type
@@ -544,7 +544,7 @@ def parse_devices(devices: Union[str, int]) -> int:
544544

545545

546546
def choose_logger(
547-
logger_name: Literal["csv", "tensorboard", "wandb"],
547+
logger_name: Literal["csv", "tensorboard", "wandb", "mlflow"],
548548
out_dir: Path,
549549
name: str,
550550
log_interval: int = 1,
@@ -557,6 +557,8 @@ def choose_logger(
557557
return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs)
558558
if logger_name == "wandb":
559559
return WandbLogger(project=name, resume=resume, **kwargs)
560+
if logger_name == "mlflow":
561+
return MLFlowLogger(experiment_name=name, **kwargs)
560562
raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
561563

562564

tests/test_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from lightning import Fabric
1515
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
1616
from lightning.fabric.plugins import BitsandbytesPrecision
17-
from lightning.pytorch.loggers import WandbLogger
17+
from lightning.pytorch.loggers import MLFlowLogger, WandbLogger
1818
from lightning_utilities.core.imports import RequirementCache
1919

2020
from litgpt import GPT
@@ -307,7 +307,8 @@ def test_choose_logger(tmp_path):
307307
assert isinstance(choose_logger("tensorboard", out_dir=tmp_path, name="tb"), TensorBoardLogger)
308308
if RequirementCache("wandb"):
309309
assert isinstance(choose_logger("wandb", out_dir=tmp_path, name="wandb"), WandbLogger)
310-
310+
if RequirementCache("mlflow") or RequirementCache("mlflow-skinny"):
311+
assert isinstance(choose_logger("mlflow", out_dir=tmp_path, name="wandb"), MLFlowLogger)
311312
with pytest.raises(ValueError, match="`--logger_name=foo` is not a valid option."):
312313
choose_logger("foo", out_dir=tmp_path, name="foo")
313314

0 commit comments

Comments
 (0)