
Commit 0cb6cce

Merge branch 'master' into cli-custom-parser-class
2 parents: bdc917e + 71793c6

File tree

14 files changed: +142 -24 lines

.github/workflows/call-clear-cache.yml

Lines changed: 4 additions & 4 deletions
@@ -23,18 +23,18 @@ on:
 jobs:
   cron-clear:
     if: github.event_name == 'schedule' || github.event_name == 'pull_request'
-    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.0
     with:
-      scripts-ref: v0.11.8
+      scripts-ref: v0.14.0
       dry-run: ${{ github.event_name == 'pull_request' }}
       pattern: "latest|docs"
       age-days: 7

   direct-clear:
     if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request'
-    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.0
     with:
-      scripts-ref: v0.11.8
+      scripts-ref: v0.14.0
       dry-run: ${{ github.event_name == 'pull_request' }}
       pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging
       age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging

.github/workflows/ci-check-md-links.yml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ on:

 jobs:
   check-md-links:
-    uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.14.0
     with:
       config-file: ".github/markdown-links-config.json"
       base-branch: "master"

.github/workflows/ci-schema.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ on:

 jobs:
   check:
-    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.14.0
     with:
       # skip azure due to the wrong schema file by MSFT
       # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607

docs/source-pytorch/visualize/loggers.rst

Lines changed: 34 additions & 0 deletions
@@ -54,3 +54,37 @@ Track and Visualize Experiments

     </div>
 </div>
+
+.. _mlflow_logger:
+
+MLflow Logger
+-------------
+
+The MLflow logger in PyTorch Lightning now includes a `checkpoint_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.
+
+Example usage:
+
+.. code-block:: python
+
+    import lightning as L
+    from lightning.pytorch.loggers import MLFlowLogger
+
+    mlf_logger = MLFlowLogger(
+        experiment_name="lightning_logs",
+        tracking_uri="file:./ml-runs",
+        checkpoint_path_prefix="my_prefix"
+    )
+    trainer = L.Trainer(logger=mlf_logger)
+
+    # Your LightningModule definition
+    class LitModel(L.LightningModule):
+        def training_step(self, batch, batch_idx):
+            # example
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+        def any_lightning_module_function_or_hook(self):
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+    # Train your model
+    model = LitModel()
+    trainer.fit(model)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 13 additions & 0 deletions
@@ -10,12 +10,25 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))

+
 ### Changed

+- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored.
+
+
 ### Removed

+-
+
+
 ### Fixed

+- Fix CSVLogger logging hyperparameters at every write, which increased latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
+
+
+- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
+
+
 ## [2.5.0] - 2024-12-19

 ### Added

src/lightning/pytorch/loggers/csv_logs.py

Lines changed: 2 additions & 7 deletions
@@ -55,15 +55,10 @@ def __init__(self, log_dir: str) -> None:
         self.hparams: dict[str, Any] = {}

     def log_hparams(self, params: dict[str, Any]) -> None:
-        """Record hparams."""
+        """Record hparams and save into files."""
         self.hparams.update(params)
-
-    @override
-    def save(self) -> None:
-        """Save recorded hparams and metrics into files."""
         hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
         save_hparams_to_yaml(hparams_file, self.hparams)
-        return super().save()


 class CSVLogger(Logger, FabricCSVLogger):
@@ -144,7 +139,7 @@ def save_dir(self) -> str:

     @override
     @rank_zero_only
-    def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
+    def log_hyperparams(self, params: Optional[Union[dict[str, Any], Namespace]] = None) -> None:
         params = _convert_params(params)
         self.experiment.log_hparams(params)
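
With this change, the hparams YAML file is written once, at the moment `log_hyperparams` is called, rather than being rewritten on every later metrics flush. A minimal usage sketch (the directory name and hyperparameter values below are illustrative, not part of the change):

    from lightning.pytorch.loggers import CSVLogger

    # Illustrative save_dir/name and hyperparameters.
    logger = CSVLogger(save_dir="logs", name="csv_demo")

    # hparams.yaml is written here, once, instead of on every subsequent metrics save.
    logger.log_hyperparams({"lr": 1e-3, "batch_size": 32})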

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 4 additions & 2 deletions
@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
              :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
              which also logs every checkpoint during training.
            * if ``log_model == False`` (default), no checkpoint is logged.
-
+        checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
        prefix: A string to put at the beginning of metric keys.
        artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
            default.
@@ -121,6 +121,7 @@ def __init__(
         tags: Optional[dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
         log_model: Literal[True, False, "all"] = False,
+        checkpoint_path_prefix: str = "",
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -147,6 +148,7 @@ def __init__(
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
         self._initialized = False
+        self._checkpoint_path_prefix = checkpoint_path_prefix

         from mlflow.tracking import MlflowClient

@@ -361,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]

             # Artifact path on mlflow
-            artifact_path = Path(p).stem
+            artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem

             # Log the checkpoint
             self.experiment.log_artifact(self._run_id, p, artifact_path)
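
The effect of `checkpoint_path_prefix` on the logged artifact path can be sketched in isolation (the helper below is hypothetical and only mirrors the changed line; the checkpoint filename is an example, and paths are shown POSIX-style):

    from pathlib import Path

    def artifact_path_for(checkpoint_file: str, checkpoint_path_prefix: str = "") -> str:
        # Mirrors the changed line: the (possibly empty) prefix joined with the checkpoint file's stem.
        return str(Path(checkpoint_path_prefix) / Path(checkpoint_file).stem)

    print(artifact_path_for("epoch=2-step=300.ckpt"))               # epoch=2-step=300
    print(artifact_path_for("epoch=2-step=300.ckpt", "my_prefix"))  # my_prefix/epoch=2-step=300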

src/lightning/pytorch/trainer/call.py

Lines changed: 6 additions & 1 deletion
@@ -21,6 +21,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
 from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
 from lightning.pytorch.trainer.states import TrainerStatus
@@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
         if isinstance(module, _DeviceDtypeModuleMixin):
             module._device = trainer.strategy.root_device

+    # wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
+    # https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
+    loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
+
     # Trigger lazy creation of experiment in loggers so loggers have their metadata available
-    for logger in trainer.loggers:
+    for logger in loggers:
         if hasattr(logger, "experiment"):
             _ = logger.experiment
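
The ordering trick above relies on two Python facts: `False` sorts before `True`, and `sorted` is stable, so the non-Wandb loggers keep their original relative order. A standalone sketch with stand-in classes (not Lightning's actual logger types):

    class WandbLike:  # stand-in for WandbLogger
        pass

    class OtherLogger:  # stand-in for any other logger
        pass

    loggers = [OtherLogger(), WandbLike(), OtherLogger()]

    # The key returns False for WandbLike instances, so they sort first; stability preserves the rest.
    ordered = sorted(loggers, key=lambda lg: not isinstance(lg, WandbLike))
    print([type(lg).__name__ for lg in ordered])  # ['WandbLike', 'OtherLogger', 'OtherLogger']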

tests/tests_fabric/utilities/test_seed.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 import os
 import random
+import warnings
 from unittest import mock
 from unittest.mock import Mock

@@ -30,9 +31,9 @@ def test_seed_stays_same_with_multiple_seed_everything_calls():
     seed_everything()
     initial_seed = os.environ.get("PL_GLOBAL_SEED")

-    with pytest.warns(None) as record:
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
         seed_everything()
-    assert not record  # does not warn
     seed = os.environ.get("PL_GLOBAL_SEED")

     assert initial_seed == seed
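
The pattern above asserts "no warning is emitted" by escalating warnings to errors inside the context manager, replacing the `pytest.warns(None)` idiom that newer pytest versions no longer accept. A self-contained sketch (the function below is a hypothetical stand-in for `seed_everything`):

    import warnings

    def quiet_function() -> int:
        # Hypothetical stand-in that emits no warnings.
        return 42

    with warnings.catch_warnings():
        # Any warning raised in this block becomes an exception and fails the test.
        warnings.simplefilter("error")
        assert quiet_function() == 42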

tests/tests_pytorch/callbacks/test_lr_monitor.py

Lines changed: 2 additions & 2 deletions
@@ -548,10 +548,10 @@ def finetune_function(self, pl_module, epoch: int, optimizer):
             """Called when the epoch begins."""
             if epoch == 1 and isinstance(optimizer, torch.optim.SGD):
                 self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1)
-            if epoch == 2 and isinstance(optimizer, torch.optim.Adam):
+            if epoch == 2 and type(optimizer) is torch.optim.Adam:
                 self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1)

-            if epoch == 3 and isinstance(optimizer, torch.optim.Adam):
+            if epoch == 3 and type(optimizer) is torch.optim.Adam:
                 assert len(optimizer.param_groups) == 2
                 self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1)
                 assert len(optimizer.param_groups) == 3
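
The switch from `isinstance` to an exact `type(...) is` check matters because recent PyTorch releases implement `torch.optim.AdamW` as a subclass of `torch.optim.Adam`, so the `isinstance` check would also match `AdamW` optimizers. A minimal illustration (assuming such a PyTorch release):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.AdamW(params)

    # On releases where AdamW subclasses Adam, isinstance matches both classes...
    print(isinstance(opt, torch.optim.Adam))  # True
    # ...while an exact type check distinguishes them:
    print(type(opt) is torch.optim.Adam)      # False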
