@@ -548,6 +548,101 @@ def test_cleanup(self, mock_mlflow, temp_dir):
548548 # Check that end_run was called
549549 mock_mlflow .end_run .assert_called_once ()
550550
551+ @patch ("nemo_rl.utils.logger.mlflow" )
552+ def test_init_with_none_log_dir (self , mock_mlflow ):
553+ """Test initialization with None log_dir uses server default artifact location."""
554+ cfg = {
555+ "experiment_name" : "test-experiment" ,
556+ "run_name" : "test-run" ,
557+ "tracking_uri" : "http://localhost:5000" ,
558+ }
559+ mock_mlflow .get_experiment_by_name .return_value = None
560+
561+ MLflowLogger (cfg , log_dir = None )
562+
563+ # Verify create_experiment was called without artifact_location
564+ mock_mlflow .create_experiment .assert_called_once_with (name = "test-experiment" )
565+ mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" )
566+
567+ @patch ("nemo_rl.utils.logger.mlflow" )
568+ def test_init_with_custom_log_dir (self , mock_mlflow ):
569+ """Test initialization with custom log_dir sets artifact_location."""
570+ cfg = {
571+ "experiment_name" : "test-experiment" ,
572+ "run_name" : "test-run" ,
573+ "tracking_uri" : "http://localhost:5000" ,
574+ }
575+ mock_mlflow .get_experiment_by_name .return_value = None
576+
577+ MLflowLogger (cfg , log_dir = "/custom/path" )
578+
579+ # Verify create_experiment was called with artifact_location
580+ mock_mlflow .create_experiment .assert_called_once_with (
581+ name = "test-experiment" , artifact_location = "/custom/path"
582+ )
583+ mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" )
584+
585+ @patch ("nemo_rl.utils.logger.mlflow" )
586+ def test_init_with_artifact_location_in_config (self , mock_mlflow ):
587+ """Test initialization with artifact_location in config takes precedence over log_dir."""
588+ cfg = {
589+ "experiment_name" : "test-experiment" ,
590+ "run_name" : "test-run" ,
591+ "tracking_uri" : "http://localhost:5000" ,
592+ "artifact_location" : "/config/artifact/path" ,
593+ }
594+ mock_mlflow .get_experiment_by_name .return_value = None
595+
596+ MLflowLogger (cfg , log_dir = "/fallback/path" )
597+
598+ # Verify create_experiment was called with artifact_location from config
599+ mock_mlflow .create_experiment .assert_called_once_with (
600+ name = cfg ["experiment_name" ], artifact_location = cfg ["artifact_location" ]
601+ )
602+ mock_mlflow .set_tracking_uri .assert_called_once_with (cfg ["tracking_uri" ])
603+ mock_mlflow .start_run .assert_called_once_with (run_name = cfg ["run_name" ])
604+
605+ @patch ("nemo_rl.utils.logger.mlflow" )
606+ def test_init_with_artifact_location_none_in_config (self , mock_mlflow ):
607+ """Test initialization with artifact_location=None in config uses server default."""
608+ cfg = {
609+ "experiment_name" : "test-experiment" ,
610+ "run_name" : "test-run" ,
611+ "tracking_uri" : "http://localhost:5000" ,
612+ "artifact_location" : None ,
613+ }
614+ mock_mlflow .get_experiment_by_name .return_value = None
615+
616+ MLflowLogger (cfg , log_dir = "/fallback/path" )
617+
618+ # Verify create_experiment was called without artifact_location
619+ # (None is explicitly set, so we don't pass it to MLflow)
620+ mock_mlflow .create_experiment .assert_called_once_with (
621+ name = cfg ["experiment_name" ], artifact_location = cfg ["artifact_location" ]
622+ )
623+ mock_mlflow .set_tracking_uri .assert_called_once_with (cfg ["tracking_uri" ])
624+ mock_mlflow .start_run .assert_called_once_with (run_name = cfg ["run_name" ])
625+
626+ @patch ("nemo_rl.utils.logger.mlflow" )
627+ def test_init_without_artifact_location_uses_log_dir (self , mock_mlflow ):
628+ """Test initialization without artifact_location in config uses log_dir."""
629+ cfg = {
630+ "experiment_name" : "test-experiment" ,
631+ "run_name" : "test-run" ,
632+ "tracking_uri" : "http://localhost:5000" ,
633+ }
634+ mock_mlflow .get_experiment_by_name .return_value = None
635+
636+ log_dir = "/fallback/path"
637+ MLflowLogger (cfg , log_dir = log_dir )
638+
639+ # Verify create_experiment was called with log_dir as artifact_location
640+ mock_mlflow .create_experiment .assert_called_once_with (
641+ name = cfg ["experiment_name" ], artifact_location = log_dir
642+ )
643+ mock_mlflow .set_tracking_uri .assert_called_once_with (cfg ["tracking_uri" ])
644+ mock_mlflow .start_run .assert_called_once_with (run_name = cfg ["run_name" ])
645+
551646
552647class TestRayGpuMonitorLogger :
553648 """Test the RayGpuMonitorLogger class."""
0 commit comments