Skip to content

Commit 7257c30

Browse files
clumsyazzhipa
andauthored
fix: honor mlflow server artifact_location (#1536) (#1538)
Signed-off-by: Alexander Zhipa <azzhipa@amazon.com> Co-authored-by: Alexander Zhipa <azzhipa@amazon.com>
1 parent 55dc433 commit 7257c30

File tree

2 files changed

+116
-19
lines changed

2 files changed

+116
-19
lines changed

nemo_rl/utils/logger.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class MLflowConfig(TypedDict):
6565
experiment_name: str
6666
run_name: str
6767
tracking_uri: NotRequired[str]
68+
artifact_location: NotRequired[str | None]
6869

6970

7071
class GPUMonitoringConfig(TypedDict):
@@ -724,31 +725,30 @@ def __init__(self, cfg: MLflowConfig, log_dir: Optional[str] = None):
724725
725726
Args:
726727
cfg: MLflow configuration
727-
log_dir: Optional log directory
728+
log_dir: Optional log directory (used as fallback if artifact_location not in cfg)
728729
"""
729-
if cfg["tracking_uri"]:
730-
mlflow.set_tracking_uri(cfg["tracking_uri"])
730+
tracking_uri = cfg.get("tracking_uri")
731+
if tracking_uri:
732+
mlflow.set_tracking_uri(tracking_uri)
731733

732-
experiment = mlflow.get_experiment_by_name(cfg["experiment_name"])
734+
experiment_name = cfg["experiment_name"]
735+
experiment = mlflow.get_experiment_by_name(experiment_name)
733736
if experiment is None:
734-
if log_dir:
735-
mlflow.create_experiment(
736-
name=cfg["experiment_name"],
737-
artifact_location=log_dir,
738-
)
739-
else:
740-
mlflow.create_experiment(cfg["experiment_name"])
737+
mlflow.create_experiment(
738+
name=experiment_name,
739+
**{"artifact_location": cfg.get("artifact_location", log_dir)}
740+
if "artifact_location" in cfg or log_dir
741+
else {},
742+
)
741743
else:
742-
mlflow.set_experiment(cfg["experiment_name"])
744+
mlflow.set_experiment(experiment_name)
743745

744746
# Start run
745-
run_kwargs: dict[str, str] = {}
746-
run_kwargs["run_name"] = cfg["run_name"]
747-
747+
run_name = cfg["run_name"]
748+
run_kwargs = {"run_name": run_name}
748749
self.run = mlflow.start_run(**run_kwargs)
749750
print(
750-
f"Initialized MLflowLogger for experiment {cfg['experiment_name']}, "
751-
f"run {cfg['run_name']}"
751+
f"Initialized MLflowLogger for experiment {experiment_name}, run {run_name}"
752752
)
753753

754754
def log_metrics(
@@ -847,8 +847,10 @@ def __init__(self, cfg: LoggerConfig):
847847
self.loggers.append(tensorboard_logger)
848848

849849
if cfg["mlflow_enabled"]:
850-
mlflow_log_dir = os.path.join(self.base_log_dir, "mlflow")
851-
os.makedirs(mlflow_log_dir, exist_ok=True)
850+
mlflow_log_dir = self.base_log_dir
851+
if mlflow_log_dir:
852+
mlflow_log_dir = os.path.join(mlflow_log_dir, "mlflow")
853+
os.makedirs(mlflow_log_dir, exist_ok=True)
852854
mlflow_logger = MLflowLogger(cfg["mlflow"], log_dir=mlflow_log_dir)
853855
self.loggers.append(mlflow_logger)
854856

tests/unit/utils/test_logger.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,101 @@ def test_cleanup(self, mock_mlflow, temp_dir):
548548
# Check that end_run was called
549549
mock_mlflow.end_run.assert_called_once()
550550

551+
@patch("nemo_rl.utils.logger.mlflow")
552+
def test_init_with_none_log_dir(self, mock_mlflow):
553+
"""Test initialization with None log_dir uses server default artifact location."""
554+
cfg = {
555+
"experiment_name": "test-experiment",
556+
"run_name": "test-run",
557+
"tracking_uri": "http://localhost:5000",
558+
}
559+
mock_mlflow.get_experiment_by_name.return_value = None
560+
561+
MLflowLogger(cfg, log_dir=None)
562+
563+
# Verify create_experiment was called without artifact_location
564+
mock_mlflow.create_experiment.assert_called_once_with(name="test-experiment")
565+
mock_mlflow.start_run.assert_called_once_with(run_name="test-run")
566+
567+
@patch("nemo_rl.utils.logger.mlflow")
568+
def test_init_with_custom_log_dir(self, mock_mlflow):
569+
"""Test initialization with custom log_dir sets artifact_location."""
570+
cfg = {
571+
"experiment_name": "test-experiment",
572+
"run_name": "test-run",
573+
"tracking_uri": "http://localhost:5000",
574+
}
575+
mock_mlflow.get_experiment_by_name.return_value = None
576+
577+
MLflowLogger(cfg, log_dir="/custom/path")
578+
579+
# Verify create_experiment was called with artifact_location
580+
mock_mlflow.create_experiment.assert_called_once_with(
581+
name="test-experiment", artifact_location="/custom/path"
582+
)
583+
mock_mlflow.start_run.assert_called_once_with(run_name="test-run")
584+
585+
@patch("nemo_rl.utils.logger.mlflow")
586+
def test_init_with_artifact_location_in_config(self, mock_mlflow):
587+
"""Test initialization with artifact_location in config takes precedence over log_dir."""
588+
cfg = {
589+
"experiment_name": "test-experiment",
590+
"run_name": "test-run",
591+
"tracking_uri": "http://localhost:5000",
592+
"artifact_location": "/config/artifact/path",
593+
}
594+
mock_mlflow.get_experiment_by_name.return_value = None
595+
596+
MLflowLogger(cfg, log_dir="/fallback/path")
597+
598+
# Verify create_experiment was called with artifact_location from config
599+
mock_mlflow.create_experiment.assert_called_once_with(
600+
name=cfg["experiment_name"], artifact_location=cfg["artifact_location"]
601+
)
602+
mock_mlflow.set_tracking_uri.assert_called_once_with(cfg["tracking_uri"])
603+
mock_mlflow.start_run.assert_called_once_with(run_name=cfg["run_name"])
604+
605+
@patch("nemo_rl.utils.logger.mlflow")
606+
def test_init_with_artifact_location_none_in_config(self, mock_mlflow):
607+
"""Test initialization with artifact_location=None in config uses server default."""
608+
cfg = {
609+
"experiment_name": "test-experiment",
610+
"run_name": "test-run",
611+
"tracking_uri": "http://localhost:5000",
612+
"artifact_location": None,
613+
}
614+
mock_mlflow.get_experiment_by_name.return_value = None
615+
616+
MLflowLogger(cfg, log_dir="/fallback/path")
617+
618+
# Verify create_experiment was called without artifact_location
619+
# (None is explicitly set, so we don't pass it to MLflow)
620+
mock_mlflow.create_experiment.assert_called_once_with(
621+
name=cfg["experiment_name"], artifact_location=cfg["artifact_location"]
622+
)
623+
mock_mlflow.set_tracking_uri.assert_called_once_with(cfg["tracking_uri"])
624+
mock_mlflow.start_run.assert_called_once_with(run_name=cfg["run_name"])
625+
626+
@patch("nemo_rl.utils.logger.mlflow")
627+
def test_init_without_artifact_location_uses_log_dir(self, mock_mlflow):
628+
"""Test initialization without artifact_location in config uses log_dir."""
629+
cfg = {
630+
"experiment_name": "test-experiment",
631+
"run_name": "test-run",
632+
"tracking_uri": "http://localhost:5000",
633+
}
634+
mock_mlflow.get_experiment_by_name.return_value = None
635+
636+
log_dir = "/fallback/path"
637+
MLflowLogger(cfg, log_dir=log_dir)
638+
639+
# Verify create_experiment was called with log_dir as artifact_location
640+
mock_mlflow.create_experiment.assert_called_once_with(
641+
name=cfg["experiment_name"], artifact_location=log_dir
642+
)
643+
mock_mlflow.set_tracking_uri.assert_called_once_with(cfg["tracking_uri"])
644+
mock_mlflow.start_run.assert_called_once_with(run_name=cfg["run_name"])
645+
551646

552647
class TestRayGpuMonitorLogger:
553648
"""Test the RayGpuMonitorLogger class."""

0 commit comments

Comments
 (0)