Skip to content

Commit fdce0f6

Browse files
authored
fix: bug for mlflow offline logging (#675)
## Description After merging the support for azure, we missed a small bug that breaks the offline logging. This PR fixes the bug and add some tests to capture this better in the future. Main issue before was that for forking and resuming the logger was not setting the tracking_uri to the save_dir. ## What problem does this change solve? <!-- Describe if it's a bugfix, new feature, doc update, or breaking change --> ## What issue or task does this change relate to? <!-- link to Issue Number --> ## Additional notes ## <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent 921e108 commit fdce0f6

File tree

2 files changed

+86
-8
lines changed

2 files changed

+86
-8
lines changed

training/src/anemoi/training/diagnostics/mlflow/logger.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,16 @@ def __init__(
346346
LOGGER.info("Maximum number of params allowed to be logged is: %s", max_params_length)
347347

348348
self.tracking_uri = tracking_uri
349-
if (self._resumed or self._forked) and self.offline:
350-
self.tracking_uri = save_dir
351-
352349
# Before creating the run we need to overwrite the tracking_uri and save_dir if offline
353350
if self.offline:
354-
# OFFLINE - When we run offline we can pass a save_dir pointing to a local path
355-
self.tracking_uri = None
356-
if save_dir is None:
357-
# otherwise, by default we create a dir called "None"... not ideal
358-
save_dir = "./mlruns"
351+
if self._resumed or self._forked:
352+
self.tracking_uri = save_dir
353+
else:
354+
# OFFLINE - When we run offline we can pass a save_dir pointing to a local path
355+
self.tracking_uri = None
356+
if save_dir is None:
357+
# otherwise, by default we create a dir called "None"... not ideal
358+
save_dir = "./mlruns"
359359

360360
else:
361361
# ONLINE - When we pass a tracking_uri to mlflow then it will ignore the

training/tests/unit/diagnostics/mlflow/test_mlflow_logger.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from pathlib import Path
22

3+
import omegaconf
34
import pytest
5+
import yaml
6+
from hydra.utils import instantiate
47

58
from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger
69
from anemoi.training.schemas.diagnostics import MlflowSchema
@@ -19,6 +22,35 @@ def tmp_uri(monkeypatch: pytest.MonkeyPatch, tmp_path: str) -> Path:
1922
return uri
2023

2124

25+
@pytest.fixture
26+
def default_offline_config(tmp_path: str) -> omegaconf.DictConfig:
27+
base = """
28+
diagnostics:
29+
log:
30+
mlflow:
31+
_target_: anemoi.training.diagnostics.mlflow.logger.AnemoiMLflowLogger
32+
offline: Trure
33+
authentication: False
34+
tracking_uri: 'https:test.int'
35+
experiment_name: 'anemoi-debug'
36+
project_name: 'Anemoi'
37+
system: False
38+
terminal: True
39+
run_name: null
40+
on_resume_create_child: True
41+
expand_hyperparams:
42+
- config
43+
http_max_retries: 35
44+
max_params_length: 2000
45+
save_dir: '/scratch/example'
46+
"""
47+
48+
cfg = omegaconf.OmegaConf.create(yaml.safe_load(base))
49+
cfg.diagnostics.log.mlflow.save_dir = tmp_path
50+
51+
return cfg
52+
53+
2254
@pytest.fixture
2355
def default_logger(tmp_path: str, tmp_uri: str) -> AnemoiMLflowLogger:
2456
return AnemoiMLflowLogger(
@@ -31,6 +63,52 @@ def default_logger(tmp_path: str, tmp_uri: str) -> AnemoiMLflowLogger:
3163
)
3264

3365

66+
def create_run(save_dir: str, experiment_name: str) -> str:
67+
import contextlib
68+
69+
import mlflow
70+
71+
mlflow.set_tracking_uri(f"file://{save_dir}")
72+
with contextlib.suppress(mlflow.exceptions.MlflowException):
73+
mlflow.create_experiment(experiment_name)
74+
mlflow.set_experiment(experiment_name)
75+
with mlflow.start_run():
76+
mlflow.log_param("lr", 0.001)
77+
return mlflow.active_run().info.run_id
78+
79+
80+
def test_offline_logger(default_offline_config: omegaconf.DictConfig) -> None:
81+
mlflow_logger = instantiate(default_offline_config.diagnostics.log.mlflow)
82+
assert not mlflow_logger.tracking_uri
83+
84+
85+
def test_offline_resumed_logger(default_offline_config: omegaconf.DictConfig) -> None:
86+
87+
run_id = create_run(
88+
save_dir=default_offline_config.diagnostics.log.mlflow.save_dir,
89+
experiment_name=default_offline_config.diagnostics.log.mlflow.experiment_name,
90+
)
91+
logger_resumed = instantiate(
92+
default_offline_config.diagnostics.log.mlflow,
93+
run_id=run_id,
94+
fork_run_id=None,
95+
)
96+
assert logger_resumed.tracking_uri == default_offline_config.diagnostics.log.mlflow.save_dir
97+
98+
99+
def test_offline_forked_logger(default_offline_config: omegaconf.DictConfig) -> None:
100+
fork_run_id = create_run(
101+
save_dir=default_offline_config.diagnostics.log.mlflow.save_dir,
102+
experiment_name=default_offline_config.diagnostics.log.mlflow.experiment_name,
103+
)
104+
logger_forked = instantiate(
105+
default_offline_config.diagnostics.log.mlflow,
106+
run_id=None,
107+
fork_run_id=fork_run_id,
108+
)
109+
assert logger_forked.tracking_uri == default_offline_config.diagnostics.log.mlflow.save_dir
110+
111+
34112
def test_mlflowlogger_params_limit(default_logger: AnemoiMLflowLogger) -> None:
35113

36114
default_logger._max_params_length = 3

0 commit comments

Comments
 (0)