Skip to content

Commit 1ef3940

Browse files
authored
Merge branch 'master' into ci/bump-pt-2.6
2 parents 74b11a4 + 5073ac1 commit 1ef3940

File tree

13 files changed

+165
-14
lines changed

13 files changed

+165
-14
lines changed

.github/workflows/call-clear-cache.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@ on:
2323
jobs:
2424
cron-clear:
2525
if: github.event_name == 'schedule' || github.event_name == 'pull_request'
26-
uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
26+
uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.0
2727
with:
28-
scripts-ref: v0.11.8
28+
scripts-ref: v0.14.0
2929
dry-run: ${{ github.event_name == 'pull_request' }}
3030
pattern: "latest|docs"
3131
age-days: 7
3232

3333
direct-clear:
3434
if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request'
35-
uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
35+
uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.0
3636
with:
37-
scripts-ref: v0.11.8
37+
scripts-ref: v0.14.0
3838
dry-run: ${{ github.event_name == 'pull_request' }}
3939
pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging
4040
age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging

.github/workflows/ci-check-md-links.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ on:
1414

1515
jobs:
1616
check-md-links:
17-
uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.12.0
17+
uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.14.0
1818
with:
1919
config-file: ".github/markdown-links-config.json"
2020
base-branch: "master"

.github/workflows/ci-schema.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
check:
11-
uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.12.0
11+
uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.14.0
1212
with:
1313
# skip azure due to the wrong schema file by MSFT
1414
# https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607

docs/source-pytorch/visualize/loggers.rst

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,37 @@ Track and Visualize Experiments
5454

5555
</div>
5656
</div>
57+
58+
.. _mlflow_logger:
59+
60+
MLflow Logger
61+
-------------
62+
63+
The MLflow logger in PyTorch Lightning now includes a `checkpoint_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.
64+
65+
Example usage:
66+
67+
.. code-block:: python
68+
69+
import lightning as L
70+
from lightning.pytorch.loggers import MLFlowLogger
71+
72+
mlf_logger = MLFlowLogger(
73+
experiment_name="lightning_logs",
74+
tracking_uri="file:./ml-runs",
75+
checkpoint_path_prefix="my_prefix"
76+
)
77+
trainer = L.Trainer(logger=mlf_logger)
78+
79+
# Your LightningModule definition
80+
class LitModel(L.LightningModule):
81+
def training_step(self, batch, batch_idx):
82+
# example
83+
self.logger.experiment.whatever_ml_flow_supports(...)
84+
85+
def any_lightning_module_function_or_hook(self):
86+
self.logger.experiment.whatever_ml_flow_supports(...)
87+
88+
# Train your model
89+
model = LitModel()
90+
trainer.fit(model)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,31 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88

99
### Added
1010

11+
- Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))
12+
13+
1114
### Changed
1215

16+
- Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))
17+
18+
19+
- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))
20+
21+
22+
1323
### Removed
1424

25+
-
26+
27+
1528
### Fixed
1629

1730
- Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
1831

1932

33+
- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
34+
35+
2036
## [2.5.0] - 2024-12-19
2137

2238
### Added

src/lightning/pytorch/cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def __init__(
314314
trainer_defaults: Optional[dict[str, Any]] = None,
315315
seed_everything_default: Union[bool, int] = True,
316316
parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None,
317+
parser_class: type[LightningArgumentParser] = LightningArgumentParser,
317318
subclass_mode_model: bool = False,
318319
subclass_mode_data: bool = False,
319320
args: ArgsType = None,
@@ -367,6 +368,7 @@ def __init__(
367368
self.trainer_defaults = trainer_defaults or {}
368369
self.seed_everything_default = seed_everything_default
369370
self.parser_kwargs = parser_kwargs or {}
371+
self.parser_class = parser_class
370372
self.auto_configure_optimizers = auto_configure_optimizers
371373

372374
self.model_class = model_class
@@ -404,7 +406,7 @@ def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str,
404406
def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
405407
"""Method that instantiates the argument parser."""
406408
kwargs.setdefault("dump_header", [f"lightning.pytorch=={pl.__version__}"])
407-
parser = LightningArgumentParser(**kwargs)
409+
parser = self.parser_class(**kwargs)
408410
parser.add_argument(
409411
"-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
410412
)

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
9797
:paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
9898
which also logs every checkpoint during training.
9999
* if ``log_model == False`` (default), no checkpoint is logged.
100-
100+
checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
101101
prefix: A string to put at the beginning of metric keys.
102102
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
103103
default.
@@ -121,6 +121,7 @@ def __init__(
121121
tags: Optional[dict[str, Any]] = None,
122122
save_dir: Optional[str] = "./mlruns",
123123
log_model: Literal[True, False, "all"] = False,
124+
checkpoint_path_prefix: str = "",
124125
prefix: str = "",
125126
artifact_location: Optional[str] = None,
126127
run_id: Optional[str] = None,
@@ -147,6 +148,7 @@ def __init__(
147148
self._artifact_location = artifact_location
148149
self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
149150
self._initialized = False
151+
self._checkpoint_path_prefix = checkpoint_path_prefix
150152

151153
from mlflow.tracking import MlflowClient
152154

@@ -361,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
361363
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
362364

363365
# Artifact path on mlflow
364-
artifact_path = Path(p).stem
366+
artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem
365367

366368
# Log the checkpoint
367369
self.experiment.log_artifact(self._run_id, p, artifact_path)

src/lightning/pytorch/loggers/wandb.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,11 @@ def experiment(self) -> Union["Run", "RunDisabled"]:
410410
if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
411411
self._experiment, "define_metric", None
412412
):
413-
self._experiment.define_metric("trainer/global_step")
414-
self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
413+
if self._wandb_init.get("sync_tensorboard"):
414+
self._experiment.define_metric("*", step_metric="global_step")
415+
else:
416+
self._experiment.define_metric("trainer/global_step")
417+
self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
415418

416419
return self._experiment
417420

@@ -434,7 +437,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
434437
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
435438

436439
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
437-
if step is not None:
440+
if step is not None and not self._wandb_init.get("sync_tensorboard"):
438441
self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
439442
else:
440443
self.experiment.log(metrics)

src/lightning/pytorch/trainer/call.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import lightning.pytorch as pl
2222
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
2323
from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
24+
from lightning.pytorch.loggers import WandbLogger
2425
from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
2526
from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
2627
from lightning.pytorch.trainer.states import TrainerStatus
@@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
9192
if isinstance(module, _DeviceDtypeModuleMixin):
9293
module._device = trainer.strategy.root_device
9394

95+
# wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
96+
# https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
97+
loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
98+
9499
# Trigger lazy creation of experiment in loggers so loggers have their metadata available
95-
for logger in trainer.loggers:
100+
for logger in loggers:
96101
if hasattr(logger, "experiment"):
97102
_ = logger.experiment
98103

tests/tests_pytorch/core/test_results.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from functools import partial
1515

16+
import pytest
1617
import torch
1718
import torch.distributed as dist
1819

@@ -48,6 +49,8 @@ def result_reduce_ddp_fn(strategy):
4849
assert actual.item() == dist.get_world_size()
4950

5051

52+
# flaky with "process 0 terminated with signal SIGABRT"
53+
@pytest.mark.flaky(reruns=3, only_rerun="torch.multiprocessing.spawn.ProcessExitedException")
5154
@RunIf(skip_windows=True)
5255
def test_result_reduce_ddp():
5356
spawn_launch(result_reduce_ddp_fn, [torch.device("cpu")] * 2)

0 commit comments

Comments
 (0)