
Commit 73b785a

Revert "Add checkpoint artifact path prefix to MLflow logger (#20538)"
This reverts commit 87108d8.

1 parent: f6ef409
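
In practice, this revert removes the `checkpoint_path_prefix` argument from MLFlowLogger, so checkpoint artifacts are once again stored under a path derived from the checkpoint filename alone. A minimal sketch of post-revert usage, adapted from the documentation removed below (the experiment name and tracking URI are illustrative):

    import lightning as L
    from lightning.pytorch.loggers import MLFlowLogger

    # Post-revert: MLFlowLogger no longer accepts `checkpoint_path_prefix`;
    # each checkpoint is logged as an MLflow artifact under the file's stem.
    mlf_logger = MLFlowLogger(
        experiment_name="lightning_logs",  # illustrative
        tracking_uri="file:./ml-runs",     # illustrative local store
        log_model="all",                   # log every checkpoint during training
    )
    trainer = L.Trainer(logger=mlf_logger)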

File tree

4 files changed (+2, -69 lines)

docs/source-pytorch/visualize/loggers.rst
src/lightning/pytorch/CHANGELOG.md
src/lightning/pytorch/loggers/mlflow.py
tests/tests_pytorch/loggers/test_mlflow.py

docs/source-pytorch/visualize/loggers.rst

Lines changed: 0 additions & 34 deletions

@@ -54,37 +54,3 @@ Track and Visualize Experiments
 
     </div>
     </div>
-
-.. _mlflow_logger:
-
-MLflow Logger
--------------
-
-The MLflow logger in PyTorch Lightning now includes a `checkpoint_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.
-
-Example usage:
-
-.. code-block:: python
-
-    import lightning as L
-    from lightning.pytorch.loggers import MLFlowLogger
-
-    mlf_logger = MLFlowLogger(
-        experiment_name="lightning_logs",
-        tracking_uri="file:./ml-runs",
-        checkpoint_path_prefix="my_prefix"
-    )
-    trainer = L.Trainer(logger=mlf_logger)
-
-    # Your LightningModule definition
-    class LitModel(L.LightningModule):
-        def training_step(self, batch, batch_idx):
-            # example
-            self.logger.experiment.whatever_ml_flow_supports(...)
-
-        def any_lightning_module_function_or_hook(self):
-            self.logger.experiment.whatever_ml_flow_supports(...)
-
-    # Train your model
-    model = LitModel()
-    trainer.fit(model)
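
With the parameter gone, a comparable prefix can still be applied by logging a checkpoint manually through the MlflowClient that `logger.experiment` exposes. A hedged sketch, not taken from the Lightning docs; the checkpoint path and prefix are illustrative:

    from pathlib import Path

    # Assumes `mlf_logger` is the MLFlowLogger from the example above and
    # `ckpt_path` points to a checkpoint already saved on disk.
    ckpt_path = "checkpoints/epoch=1-step=6.ckpt"  # illustrative
    mlf_logger.experiment.log_artifact(
        mlf_logger.run_id,
        ckpt_path,
        artifact_path=str(Path("my_prefix") / Path(ckpt_path).stem),
    )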

src/lightning/pytorch/CHANGELOG.md

Lines changed: 0 additions & 1 deletion

@@ -35,7 +35,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))
 - Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))
-- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))
 - CometML logger was updated to support the recent Comet SDK ([#20275](https://github.com/Lightning-AI/pytorch-lightning/pull/20275))
 - bump: testing with latest `torch` 2.6 ([#20509](https://github.com/Lightning-AI/pytorch-lightning/pull/20509))
 
src/lightning/pytorch/loggers/mlflow.py

Lines changed: 2 additions & 4 deletions

@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
            :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
            which also logs every checkpoint during training.
          * if ``log_model == False`` (default), no checkpoint is logged.
-        checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
+
         prefix: A string to put at the beginning of metric keys.
         artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
             default.

@@ -121,7 +121,6 @@ def __init__(
         tags: Optional[dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
         log_model: Literal[True, False, "all"] = False,
-        checkpoint_path_prefix: str = "",
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,

@@ -148,7 +147,6 @@ def __init__(
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
         self._initialized = False
-        self._checkpoint_path_prefix = checkpoint_path_prefix
 
         from mlflow.tracking import MlflowClient
 

@@ -363,7 +361,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
 
             # Artifact path on mlflow
-            artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem
+            artifact_path = Path(p).stem
 
             # Log the checkpoint
             self.experiment.log_artifact(self._run_id, p, artifact_path)
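
The last hunk is the behavioral core of the revert: the artifact path for each checkpoint is once again just the checkpoint file's stem, with no configurable prefix. A small sketch of the two computations (checkpoint path illustrative):

    from pathlib import Path

    p = "checkpoints/epoch=1-step=6.ckpt"  # illustrative checkpoint path

    # With #20538 (now reverted): the prefix was prepended to the artifact path.
    with_prefix = Path("my_prefix") / Path(p).stem  # my_prefix/epoch=1-step=6

    # After this revert: the stem alone.
    without_prefix = Path(p).stem  # "epoch=1-step=6"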

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 0 additions & 30 deletions

@@ -427,33 +427,3 @@ def test_set_tracking_uri(mlflow_mock):
     mlflow_mock.set_tracking_uri.assert_not_called()
     _ = logger.experiment
     mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
-
-
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
-def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path):
-    """Test that the logger creates the folders and files in the right place with a prefix."""
-    client = mlflow_mock.tracking.MlflowClient
-
-    # Get model, logger, trainer and train
-    model = BoringModel()
-    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_path_prefix="my_prefix")
-    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
-
-    trainer = Trainer(
-        default_root_dir=tmp_path,
-        logger=logger,
-        max_epochs=2,
-        limit_train_batches=3,
-        limit_val_batches=3,
-    )
-    trainer.fit(model)
-
-    # Checkpoint log
-    assert client.return_value.log_artifact.call_count == 2
-    # Metadata and aliases log
-    assert client.return_value.log_artifacts.call_count == 2
-
-    # Check that the prefix is used in the artifact path
-    for call in client.return_value.log_artifact.call_args_list:
-        args, _ = call
-        assert str(args[2]).startswith("my_prefix")
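
For completeness, the post-revert expectation the deleted test would map to: each `log_artifact` call receives the bare checkpoint stem as its artifact path. A self-contained sketch with a mock client (not a test from the repository):

    from pathlib import Path
    from unittest import mock

    # Stand-in for MlflowClient; mirrors what _scan_and_log_checkpoints now does.
    client = mock.MagicMock()
    ckpt = "checkpoints/epoch=1-step=6.ckpt"  # illustrative path
    client.log_artifact("run-id", ckpt, Path(ckpt).stem)

    (run_id, local_path, artifact_path), _ = client.log_artifact.call_args
    assert artifact_path == "epoch=1-step=6"  # no "my_prefix/" component anymore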
