Skip to content

Commit bc0e840

Browse files
authored
Improve overall test coverage (#177)
1 parent 41f05d9 commit bc0e840

19 files changed

+6482
-680
lines changed

nemo_run/run/experiment.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def __init__(
303303
self._title = title
304304
self._id = id or f"{title}_{int(time.time())}"
305305

306-
base_dir = base_dir or get_nemorun_home()
306+
base_dir = str(base_dir or get_nemorun_home())
307307
self._exp_dir = os.path.join(base_dir, "experiments", title, self._id)
308308

309309
self.log_level = log_level
@@ -963,7 +963,7 @@ def reset(self) -> "Experiment":
963963
self.console.log(
964964
f"[bold magenta]Experiment {self._id} has not run yet, skipping reset..."
965965
)
966-
return
966+
return self
967967

968968
old_id, old_exp_dir, old_launched = self._id, self._exp_dir, self._launched
969969
self._id = f"{self._title}_{int(time.time())}"
@@ -1233,18 +1233,19 @@ def maybe_load_external_main(exp_dir: str):
12331233
_LOADED_MAINS.add(main_file)
12341234

12351235
spec = importlib.util.spec_from_file_location("__external_main__", main_file)
1236-
new_main_module = importlib.util.module_from_spec(spec)
1237-
spec.loader.exec_module(new_main_module)
1236+
if spec is not None and spec.loader is not None:
1237+
new_main_module = importlib.util.module_from_spec(spec)
1238+
spec.loader.exec_module(new_main_module)
12381239

1239-
if "__external_main__" not in sys.modules:
1240-
sys.modules["__external_main__"] = new_main_module
1241-
else:
1242-
external = sys.modules["__external_main__"]
1240+
if "__external_main__" not in sys.modules:
1241+
sys.modules["__external_main__"] = new_main_module
1242+
else:
1243+
external = sys.modules["__external_main__"]
1244+
for attr in dir(new_main_module):
1245+
if not attr.startswith("__"):
1246+
setattr(external, attr, getattr(new_main_module, attr))
1247+
1248+
existing_main = sys.modules["__main__"]
12431249
for attr in dir(new_main_module):
12441250
if not attr.startswith("__"):
1245-
setattr(external, attr, getattr(new_main_module, attr))
1246-
1247-
existing_main = sys.modules["__main__"]
1248-
for attr in dir(new_main_module):
1249-
if not attr.startswith("__"):
1250-
setattr(existing_main, attr, getattr(new_main_module, attr))
1251+
setattr(existing_main, attr, getattr(new_main_module, attr))

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ dev = [
6464
"pytest-mock>=3.14.0",
6565
"ipykernel>=6.29.4",
6666
"ipywidgets>=8.1.2",
67-
"jupyter>=1.1.1"
67+
"jupyter>=1.1.1",
68+
"pytest-cov"
6869
]
6970

7071
lint = [

test/cli/test_api.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,3 +775,165 @@ def test_verbose_logging(self, runner, app):
775775
mock_configure.reset_mock()
776776
runner.invoke(app, ["error-command"])
777777
mock_configure.assert_called_once_with(False)
778+
779+
780+
class TestTorchrunAndConfirmation:
781+
"""Test torchrun detection and confirmation behavior."""
782+
783+
@patch("os.environ", {"WORLD_SIZE": "2"})
784+
def test_is_torchrun_true(self):
785+
"""Test that _is_torchrun returns True when WORLD_SIZE > 1."""
786+
from nemo_run.cli.api import _is_torchrun
787+
788+
assert _is_torchrun() is True
789+
790+
@patch("os.environ", {})
791+
def test_is_torchrun_false_no_env(self):
792+
"""Test that _is_torchrun returns False when WORLD_SIZE not in environment."""
793+
from nemo_run.cli.api import _is_torchrun
794+
795+
assert _is_torchrun() is False
796+
797+
@patch("os.environ", {"WORLD_SIZE": "1"})
798+
def test_is_torchrun_false_size_one(self):
799+
"""Test that _is_torchrun returns False when WORLD_SIZE = 1."""
800+
from nemo_run.cli.api import _is_torchrun
801+
802+
assert _is_torchrun() is False
803+
804+
@patch("nemo_run.cli.api._is_torchrun", return_value=True)
805+
def test_should_continue_torchrun(self, mock_torchrun):
806+
"""Test that _should_continue returns True under torchrun."""
807+
ctx = run.cli.RunContext(name="test")
808+
assert ctx._should_continue(False) is True
809+
mock_torchrun.assert_called_once()
810+
811+
@patch("nemo_run.cli.api._is_torchrun", return_value=False)
812+
@patch("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION", True)
813+
def test_should_continue_global_flag_true(self, mock_torchrun):
814+
"""Test that _should_continue respects global NEMORUN_SKIP_CONFIRMATION flag."""
815+
ctx = run.cli.RunContext(name="test")
816+
assert ctx._should_continue(False) is True
817+
mock_torchrun.assert_called_once()
818+
819+
@patch("nemo_run.cli.api._is_torchrun", return_value=False)
820+
@patch("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION", False)
821+
def test_should_continue_global_flag_false(self, mock_torchrun):
822+
"""Test that _should_continue respects global NEMORUN_SKIP_CONFIRMATION flag."""
823+
ctx = run.cli.RunContext(name="test")
824+
assert ctx._should_continue(False) is False
825+
mock_torchrun.assert_called_once()
826+
827+
@patch("nemo_run.cli.api._is_torchrun", return_value=False)
828+
@patch("nemo_run.cli.api.NEMORUN_SKIP_CONFIRMATION", None)
829+
def test_should_continue_skip_confirmation(self, mock_torchrun):
830+
"""Test that _should_continue respects skip_confirmation parameter."""
831+
ctx = run.cli.RunContext(name="test")
832+
assert ctx._should_continue(True) is True
833+
mock_torchrun.assert_called_once()
834+
835+
836+
class TestRunContextLaunch:
837+
"""Test RunContext.launch method."""
838+
839+
def test_launch_with_dryrun(self):
840+
"""Test launch with dryrun."""
841+
ctx = run.cli.RunContext(name="test_run", dryrun=True)
842+
mock_experiment = Mock(spec=run.Experiment)
843+
844+
ctx.launch(mock_experiment)
845+
846+
mock_experiment.dryrun.assert_called_once()
847+
mock_experiment.run.assert_not_called()
848+
849+
def test_launch_normal(self):
850+
"""Test launch without dryrun."""
851+
ctx = run.cli.RunContext(name="test_run", direct=True, tail_logs=True)
852+
mock_experiment = Mock(spec=run.Experiment)
853+
854+
ctx.launch(mock_experiment)
855+
856+
mock_experiment.run.assert_called_once_with(
857+
sequential=False, detach=False, direct=True, tail_logs=True
858+
)
859+
860+
def test_launch_with_executor(self):
861+
"""Test launch with executor specified."""
862+
ctx = run.cli.RunContext(name="test_run")
863+
ctx.executor = Mock(spec=run.LocalExecutor)
864+
mock_experiment = Mock(spec=run.Experiment)
865+
866+
ctx.launch(mock_experiment)
867+
868+
mock_experiment.run.assert_called_once_with(
869+
sequential=False, detach=False, direct=False, tail_logs=False
870+
)
871+
872+
def test_launch_sequential(self):
873+
"""Test launch with sequential=True."""
874+
ctx = run.cli.RunContext(name="test_run")
875+
# Initialize executor to None explicitly
876+
ctx.executor = None
877+
mock_experiment = Mock(spec=run.Experiment)
878+
879+
ctx.launch(mock_experiment, sequential=True)
880+
881+
mock_experiment.run.assert_called_once_with(
882+
sequential=True, detach=False, direct=True, tail_logs=False
883+
)
884+
885+
886+
class TestParsePrefixedArgs:
887+
"""Test _parse_prefixed_args function."""
888+
889+
def test_parse_prefixed_args_simple(self):
890+
"""Test parsing simple prefixed arguments."""
891+
from nemo_run.cli.api import _parse_prefixed_args
892+
893+
args = ["executor=local", "other=value"]
894+
prefix_value, prefix_args, other_args = _parse_prefixed_args(args, "executor")
895+
896+
assert prefix_value == "local"
897+
assert prefix_args == []
898+
assert other_args == ["other=value"]
899+
900+
def test_parse_prefixed_args_with_dot_notation(self):
901+
"""Test parsing prefixed arguments with dot notation."""
902+
from nemo_run.cli.api import _parse_prefixed_args
903+
904+
args = ["executor=local", "executor.gpu=2", "other=value"]
905+
prefix_value, prefix_args, other_args = _parse_prefixed_args(args, "executor")
906+
907+
assert prefix_value == "local"
908+
assert prefix_args == ["gpu=2"]
909+
assert other_args == ["other=value"]
910+
911+
def test_parse_prefixed_args_with_brackets(self):
912+
"""Test parsing prefixed arguments with bracket notation."""
913+
from nemo_run.cli.api import _parse_prefixed_args
914+
915+
args = ["plugins=list", "plugins[0].name=test", "other=value"]
916+
prefix_value, prefix_args, other_args = _parse_prefixed_args(args, "plugins")
917+
918+
assert prefix_value == "list"
919+
assert prefix_args == ["[0].name=test"]
920+
assert other_args == ["other=value"]
921+
922+
def test_parse_prefixed_args_invalid_format(self):
923+
"""Test parsing prefixed arguments with invalid format."""
924+
from nemo_run.cli.api import _parse_prefixed_args
925+
926+
args = ["executorblah", "other=value"]
927+
with pytest.raises(ValueError, match="Executor overwrites must start with 'executor.'"):
928+
_parse_prefixed_args(args, "executor")
929+
930+
def test_parse_prefixed_args_no_prefix(self):
931+
"""Test parsing when no prefixed arguments are present."""
932+
from nemo_run.cli.api import _parse_prefixed_args
933+
934+
args = ["arg1=value1", "arg2=value2"]
935+
prefix_value, prefix_args, other_args = _parse_prefixed_args(args, "executor")
936+
937+
assert prefix_value is None
938+
assert prefix_args == []
939+
assert other_args == ["arg1=value1", "arg2=value2"]

test/cli/test_cli_parser.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,30 @@
1515

1616
import sys
1717
from pathlib import Path
18-
from test.dummy_factory import DummyModel
1918
from typing import Any, Dict, List, Literal, Optional, Type, Union
2019

2120
import pytest
2221

2322
from nemo_run.cli.cli_parser import (
2423
ArgumentParsingError,
2524
ArgumentValueError,
25+
CLIException,
26+
CollectionParseError,
2627
DictParseError,
2728
ListParseError,
2829
LiteralParseError,
2930
OperationError,
3031
ParseError,
3132
PythonicParser,
3233
TypeParser,
34+
TypeParsingError,
3335
UndefinedVariableError,
3436
UnknownTypeError,
3537
parse_cli_args,
3638
parse_value,
3739
)
3840
from nemo_run.config import Config, Partial
41+
from test.dummy_factory import DummyModel
3942

4043

4144
class TestSimpleValueParsing:
@@ -664,3 +667,91 @@ def func(a: List[Dict[str, Union[int, List[str]]]]):
664667

665668
result = parse_cli_args(func, ["a=[{'x': 1, 'y': ['a', 'b']}, {'z': 2}]"])
666669
assert result.a == [{"x": 1, "y": ["a", "b"]}, {"z": 2}]
670+
671+
672+
class TestCLIException:
673+
"""Test the CLIException class hierarchy."""
674+
675+
def test_cli_exception_base(self):
676+
"""Test the base CLIException class."""
677+
ex = CLIException("Test message", "test_arg", {"key": "value"})
678+
assert "Test message" in str(ex)
679+
assert "test_arg" in str(ex)
680+
assert "{'key': 'value'}" in str(ex)
681+
assert ex.arg == "test_arg"
682+
assert ex.context == {"key": "value"}
683+
684+
def test_user_friendly_message(self):
685+
"""Test the user_friendly_message method."""
686+
ex = CLIException("Test message", "test_arg", {"key": "value"})
687+
friendly = ex.user_friendly_message()
688+
assert "Error processing argument 'test_arg'" in friendly
689+
assert "Test message" in friendly
690+
691+
def test_argument_parsing_error(self):
692+
"""Test ArgumentParsingError."""
693+
ex = ArgumentParsingError("Invalid syntax", "bad=arg", {"line": 10})
694+
assert isinstance(ex, CLIException)
695+
assert "Invalid syntax" in str(ex)
696+
697+
def test_type_parsing_error(self):
698+
"""Test TypeParsingError."""
699+
ex = TypeParsingError("Type mismatch", "arg=value", {"expected": "int"})
700+
assert isinstance(ex, CLIException)
701+
assert "Type mismatch" in str(ex)
702+
703+
def test_operation_error(self):
704+
"""Test OperationError."""
705+
ex = OperationError("Invalid operation", "arg+=value", {"op": "+="})
706+
assert isinstance(ex, CLIException)
707+
assert "Invalid operation" in str(ex)
708+
709+
def test_argument_value_error(self):
710+
"""Test ArgumentValueError."""
711+
ex = ArgumentValueError("Invalid value", "arg=value", {"expected": "option"})
712+
assert isinstance(ex, CLIException)
713+
assert "Invalid value" in str(ex)
714+
715+
def test_undefined_variable_error(self):
716+
"""Test UndefinedVariableError."""
717+
ex = UndefinedVariableError("Variable not defined", "undefined+=1", {})
718+
assert isinstance(ex, CLIException)
719+
assert "Variable not defined" in str(ex)
720+
721+
def test_parse_error(self):
722+
"""Test ParseError."""
723+
ex = ParseError("abc", int, "Cannot convert string to int")
724+
assert isinstance(ex, CLIException)
725+
assert "Failed to parse 'abc' as <class 'int'>" in str(ex)
726+
assert ex.value == "abc"
727+
assert ex.reason == "Cannot convert string to int"
728+
729+
def test_literal_parse_error(self):
730+
"""Test LiteralParseError."""
731+
ex = LiteralParseError("red", Literal, "Expected one of ['blue', 'green']")
732+
assert isinstance(ex, ParseError)
733+
assert "Failed to parse 'red'" in str(ex)
734+
735+
def test_collection_parse_error(self):
736+
"""Test CollectionParseError."""
737+
ex = CollectionParseError("[1,2,", list, "Invalid syntax")
738+
assert isinstance(ex, ParseError)
739+
assert "Failed to parse '[1,2,'" in str(ex)
740+
741+
def test_list_parse_error(self):
742+
"""Test ListParseError."""
743+
ex = ListParseError("[1,2,", list, "Invalid syntax")
744+
assert isinstance(ex, CollectionParseError)
745+
assert "Failed to parse '[1,2,'" in str(ex)
746+
747+
def test_dict_parse_error(self):
748+
"""Test DictParseError."""
749+
ex = DictParseError("{1:2,", dict, "Invalid syntax")
750+
assert isinstance(ex, CollectionParseError)
751+
assert "Failed to parse '{1:2,'" in str(ex)
752+
753+
def test_unknown_type_error(self):
754+
"""Test UnknownTypeError."""
755+
ex = UnknownTypeError("value", str, "Unknown type")
756+
assert isinstance(ex, ParseError)
757+
assert "Failed to parse 'value'" in str(ex)

0 commit comments

Comments
 (0)