@@ -775,3 +775,165 @@ def test_verbose_logging(self, runner, app):
775775 mock_configure .reset_mock ()
776776 runner .invoke (app , ["error-command" ])
777777 mock_configure .assert_called_once_with (False )
778+
779+
780+ class TestTorchrunAndConfirmation :
781+ """Test torchrun detection and confirmation behavior."""
782+
783+ @patch ("os.environ" , {"WORLD_SIZE" : "2" })
784+ def test_is_torchrun_true (self ):
785+ """Test that _is_torchrun returns True when WORLD_SIZE > 1."""
786+ from nemo_run .cli .api import _is_torchrun
787+
788+ assert _is_torchrun () is True
789+
790+ @patch ("os.environ" , {})
791+ def test_is_torchrun_false_no_env (self ):
792+ """Test that _is_torchrun returns False when WORLD_SIZE not in environment."""
793+ from nemo_run .cli .api import _is_torchrun
794+
795+ assert _is_torchrun () is False
796+
797+ @patch ("os.environ" , {"WORLD_SIZE" : "1" })
798+ def test_is_torchrun_false_size_one (self ):
799+ """Test that _is_torchrun returns False when WORLD_SIZE = 1."""
800+ from nemo_run .cli .api import _is_torchrun
801+
802+ assert _is_torchrun () is False
803+
804+ @patch ("nemo_run.cli.api._is_torchrun" , return_value = True )
805+ def test_should_continue_torchrun (self , mock_torchrun ):
806+ """Test that _should_continue returns True under torchrun."""
807+ ctx = run .cli .RunContext (name = "test" )
808+ assert ctx ._should_continue (False ) is True
809+ mock_torchrun .assert_called_once ()
810+
811+ @patch ("nemo_run.cli.api._is_torchrun" , return_value = False )
812+ @patch ("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION" , True )
813+ def test_should_continue_global_flag_true (self , mock_torchrun ):
814+ """Test that _should_continue respects global NEMORUN_SKIP_CONFIRMATION flag."""
815+ ctx = run .cli .RunContext (name = "test" )
816+ assert ctx ._should_continue (False ) is True
817+ mock_torchrun .assert_called_once ()
818+
819+ @patch ("nemo_run.cli.api._is_torchrun" , return_value = False )
820+ @patch ("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION" , False )
821+ def test_should_continue_global_flag_false (self , mock_torchrun ):
822+ """Test that _should_continue respects global NEMORUN_SKIP_CONFIRMATION flag."""
823+ ctx = run .cli .RunContext (name = "test" )
824+ assert ctx ._should_continue (False ) is False
825+ mock_torchrun .assert_called_once ()
826+
827+ @patch ("nemo_run.cli.api._is_torchrun" , return_value = False )
828+ @patch ("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION" , None )
829+ def test_should_continue_skip_confirmation (self , mock_torchrun ):
830+ """Test that _should_continue respects skip_confirmation parameter."""
831+ ctx = run .cli .RunContext (name = "test" )
832+ assert ctx ._should_continue (True ) is True
833+ mock_torchrun .assert_called_once ()
834+
835+
836+ class TestRunContextLaunch :
837+ """Test RunContext.launch method."""
838+
839+ def test_launch_with_dryrun (self ):
840+ """Test launch with dryrun."""
841+ ctx = run .cli .RunContext (name = "test_run" , dryrun = True )
842+ mock_experiment = Mock (spec = run .Experiment )
843+
844+ ctx .launch (mock_experiment )
845+
846+ mock_experiment .dryrun .assert_called_once ()
847+ mock_experiment .run .assert_not_called ()
848+
849+ def test_launch_normal (self ):
850+ """Test launch without dryrun."""
851+ ctx = run .cli .RunContext (name = "test_run" , direct = True , tail_logs = True )
852+ mock_experiment = Mock (spec = run .Experiment )
853+
854+ ctx .launch (mock_experiment )
855+
856+ mock_experiment .run .assert_called_once_with (
857+ sequential = False , detach = False , direct = True , tail_logs = True
858+ )
859+
860+ def test_launch_with_executor (self ):
861+ """Test launch with executor specified."""
862+ ctx = run .cli .RunContext (name = "test_run" )
863+ ctx .executor = Mock (spec = run .LocalExecutor )
864+ mock_experiment = Mock (spec = run .Experiment )
865+
866+ ctx .launch (mock_experiment )
867+
868+ mock_experiment .run .assert_called_once_with (
869+ sequential = False , detach = False , direct = False , tail_logs = False
870+ )
871+
872+ def test_launch_sequential (self ):
873+ """Test launch with sequential=True."""
874+ ctx = run .cli .RunContext (name = "test_run" )
875+ # Initialize executor to None explicitly
876+ ctx .executor = None
877+ mock_experiment = Mock (spec = run .Experiment )
878+
879+ ctx .launch (mock_experiment , sequential = True )
880+
881+ mock_experiment .run .assert_called_once_with (
882+ sequential = True , detach = False , direct = True , tail_logs = False
883+ )
884+
885+
886+ class TestParsePrefixedArgs :
887+ """Test _parse_prefixed_args function."""
888+
889+ def test_parse_prefixed_args_simple (self ):
890+ """Test parsing simple prefixed arguments."""
891+ from nemo_run .cli .api import _parse_prefixed_args
892+
893+ args = ["executor=local" , "other=value" ]
894+ prefix_value , prefix_args , other_args = _parse_prefixed_args (args , "executor" )
895+
896+ assert prefix_value == "local"
897+ assert prefix_args == []
898+ assert other_args == ["other=value" ]
899+
900+ def test_parse_prefixed_args_with_dot_notation (self ):
901+ """Test parsing prefixed arguments with dot notation."""
902+ from nemo_run .cli .api import _parse_prefixed_args
903+
904+ args = ["executor=local" , "executor.gpu=2" , "other=value" ]
905+ prefix_value , prefix_args , other_args = _parse_prefixed_args (args , "executor" )
906+
907+ assert prefix_value == "local"
908+ assert prefix_args == ["gpu=2" ]
909+ assert other_args == ["other=value" ]
910+
911+ def test_parse_prefixed_args_with_brackets (self ):
912+ """Test parsing prefixed arguments with bracket notation."""
913+ from nemo_run .cli .api import _parse_prefixed_args
914+
915+ args = ["plugins=list" , "plugins[0].name=test" , "other=value" ]
916+ prefix_value , prefix_args , other_args = _parse_prefixed_args (args , "plugins" )
917+
918+ assert prefix_value == "list"
919+ assert prefix_args == ["[0].name=test" ]
920+ assert other_args == ["other=value" ]
921+
922+ def test_parse_prefixed_args_invalid_format (self ):
923+ """Test parsing prefixed arguments with invalid format."""
924+ from nemo_run .cli .api import _parse_prefixed_args
925+
926+ args = ["executorblah" , "other=value" ]
927+ with pytest .raises (ValueError , match = "Executor overwrites must start with 'executor.'" ):
928+ _parse_prefixed_args (args , "executor" )
929+
930+ def test_parse_prefixed_args_no_prefix (self ):
931+ """Test parsing when no prefixed arguments are present."""
932+ from nemo_run .cli .api import _parse_prefixed_args
933+
934+ args = ["arg1=value1" , "arg2=value2" ]
935+ prefix_value , prefix_args , other_args = _parse_prefixed_args (args , "executor" )
936+
937+ assert prefix_value is None
938+ assert prefix_args == []
939+ assert other_args == ["arg1=value1" , "arg2=value2" ]
0 commit comments