Skip to content

Commit 0870eec

Browse files
committed
fix
Signed-off-by: Hemil Desai <[email protected]>
1 parent a141c36 commit 0870eec

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

test/run/test_experiment.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,3 +1448,66 @@ def test_experiment_skip_status_at_exit(mock_get_runner, temp_dir):
14481448
with patch.object(exp, "status") as mock_status:
14491449
pass # Leaving the context triggers __exit__
14501450
mock_status.assert_not_called()
1451+
1452+
1453+
def test_experiment_status_includes_handle(temp_dir):
1454+
"""status(return_dict=True) should include handle field added in diff."""
1455+
with Experiment("test-exp", base_dir=temp_dir) as exp:
1456+
task = run.Partial(dummy_function, x=1, y=2)
1457+
job_id = exp.add(task, name="job-status")
1458+
# set job launched and handle
1459+
exp.jobs[0].launched = True
1460+
exp.jobs[0].handle = "handle-123"
1461+
exp.jobs[0].status = MagicMock(return_value=AppState.SUCCEEDED)
1462+
1463+
status_dict = exp.status(return_dict=True)
1464+
assert status_dict
1465+
assert status_dict[job_id]["handle"] == "handle-123"
1466+
1467+
1468+
def test_initialize_tunnels_extract_from_executors(temp_dir):
1469+
"""_initialize_tunnels(extract_from_executors=True) should add tunnels from slurm executors and call connect."""
1470+
1471+
# Fake Tunnel
1472+
class FakeTunnel:
1473+
def __init__(self):
1474+
self.key = "t1"
1475+
self.session = None
1476+
self.connected = False
1477+
1478+
def connect(self):
1479+
self.connected = True
1480+
self.session = "sess"
1481+
1482+
def to_config(self):
1483+
return run.Config(FakeTunnel)
1484+
1485+
# Fake SlurmExecutor
1486+
class FakeSlurmExecutor(LocalExecutor):
1487+
def __init__(self):
1488+
super().__init__()
1489+
self.tunnel = FakeTunnel()
1490+
1491+
# override clone to avoid deep copy issues
1492+
def clone(self):
1493+
return self
1494+
1495+
def to_config(self):
1496+
# Minimal config stub acceptable for tests
1497+
return run.Config(FakeSlurmExecutor)
1498+
1499+
with patch("nemo_run.run.experiment.SlurmExecutor", FakeSlurmExecutor):
1500+
with Experiment("test-exp", base_dir=temp_dir) as exp:
1501+
# Create a Job manually to avoid executor.clone
1502+
from nemo_run.run.job import Job
1503+
1504+
job = Job(
1505+
id="slurm-job",
1506+
task=run.Partial(dummy_function, x=1, y=2),
1507+
executor=FakeSlurmExecutor(),
1508+
)
1509+
exp.jobs = [job] # replace jobs list directly
1510+
1511+
# Should pull tunnel and connect
1512+
exp._initialize_tunnels(extract_from_executors=True)
1513+
assert "t1" in exp.tunnels

test/run/test_job.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,3 +707,33 @@ def test_job_dryrun_info_stored_and_reused(simple_task, docker_executor, mock_ru
707707
# Extract kwargs of second call
708708
_, second_kwargs = mock_launch.call_args
709709
assert second_kwargs["dryrun_info"] == "plan"
710+
711+
712+
# Additional tests for serialize_metadata_for_scripts flag
713+
714+
715+
def test_job_prepare_serialize_metadata_flag(simple_task, docker_executor):
716+
"""Job.prepare should forward serialize_metadata_for_scripts to package()."""
717+
job = Job(id="j1", task=simple_task, executor=docker_executor)
718+
719+
with patch("nemo_run.run.job.package") as mock_package:
720+
job.prepare(serialize_metadata_for_scripts=False)
721+
mock_package.assert_called_once()
722+
# get call kwargs to ensure flag propagated
723+
_, kwargs = mock_package.call_args
724+
assert kwargs["serialize_metadata_for_scripts"] is False
725+
726+
727+
def test_job_group_prepare_serialize_metadata_flag(simple_task, docker_executor):
728+
"""JobGroup.prepare should forward serialize_metadata_for_scripts to package() for each task."""
729+
group = JobGroup(id="g1", tasks=[simple_task, simple_task], executors=docker_executor)
730+
group._merge = False
731+
group.executors = [docker_executor] * 2
732+
733+
with patch("nemo_run.run.job.package") as mock_package:
734+
group.prepare(serialize_metadata_for_scripts=False)
735+
# Called for each task (2)
736+
assert mock_package.call_count == 2
737+
# Verify at least one call had flag False
738+
for _args, kwargs in mock_package.call_args_list:
739+
assert kwargs["serialize_metadata_for_scripts"] is False

0 commit comments

Comments
 (0)