Skip to content

Commit 6b34a8e

Browse files
authored
Added Pre-Launch Commands Support to LeptonExecutor (#312)
* update lepton executor to include custom prelaunch commands section Signed-off-by: ansjindal <[email protected]> * add test for prelaunch section Signed-off-by: ansjindal <[email protected]> * add more tests for checking the pre-launch-commands section Signed-off-by: ansjindal <[email protected]> * update lepton executor tests Signed-off-by: ansjindal <[email protected]> --------- Signed-off-by: ansjindal <[email protected]>
1 parent 0a0ed3d commit 6b34a8e

File tree

2 files changed

+222
-1
lines changed

2 files changed

+222
-1
lines changed

nemo_run/core/execution/lepton.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class LeptonExecutor(Executor):
5454
mounts: list[dict[str, Any]] = field(default_factory=list)
5555
lepton_job_dir: str = field(init=False, default="")
5656
custom_spec: dict[str, Any] = field(default_factory=dict)
57+
pre_launch_commands: list[str] = field(default_factory=list) # Custom commands before launch
5758

5859
def stop_job(self, job_id: str):
5960
"""
@@ -244,8 +245,14 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
244245
if len(name) > 35:
245246
logger.warning("length of name exceeds 35 characters. Shortening...")
246247
name = name[:34]
248+
249+
# Build pre-launch commands section
250+
pre_launch_section = ""
251+
if self.pre_launch_commands:
252+
pre_launch_section = "\n".join(self.pre_launch_commands) + "\n"
253+
247254
launch_script = f"""
248-
wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh
255+
{pre_launch_section}wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh
249256
chmod +x init.sh
250257
source init.sh
251258
ln -s {self.lepton_job_dir}/ /nemo_run

test/core/execution/test_lepton.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,3 +641,217 @@ def test_macro_values(self):
641641
result = executor.macro_values()
642642

643643
assert result is None
644+
645+
def test_pre_launch_commands_initialization(self):
646+
"""Test that pre_launch_commands can be initialized and defaults to empty list."""
647+
# Test default initialization
648+
executor = LeptonExecutor(
649+
container_image="test-image",
650+
nemo_run_dir="/test/path",
651+
)
652+
assert executor.pre_launch_commands == []
653+
654+
# Test initialization with commands
655+
commands = ["echo 'Setting up environment'", "export TEST_VAR=value"]
656+
executor_with_commands = LeptonExecutor(
657+
container_image="test-image",
658+
nemo_run_dir="/test/path",
659+
pre_launch_commands=commands,
660+
)
661+
assert executor_with_commands.pre_launch_commands == commands
662+
663+
def test_launch_script_with_pre_launch_commands(self):
664+
"""Test that pre_launch_commands are correctly included in the launch script."""
665+
666+
# Test without pre_launch_commands
667+
executor = LeptonExecutor(
668+
container_image="test-image",
669+
nemo_run_dir="/test/path",
670+
)
671+
672+
# Test script section generation - empty case
673+
pre_launch_section = ""
674+
if executor.pre_launch_commands:
675+
pre_launch_section = "\n".join(executor.pre_launch_commands) + "\n"
676+
assert pre_launch_section == ""
677+
678+
# Test with pre_launch_commands
679+
commands = ["echo 'Custom setup'", "export MY_VAR=test"]
680+
executor_with_commands = LeptonExecutor(
681+
container_image="test-image",
682+
nemo_run_dir="/test/path",
683+
pre_launch_commands=commands,
684+
)
685+
686+
# Test script section generation - with commands
687+
pre_launch_section_with_commands = ""
688+
if executor_with_commands.pre_launch_commands:
689+
pre_launch_section_with_commands = (
690+
"\n".join(executor_with_commands.pre_launch_commands) + "\n"
691+
)
692+
693+
expected_pre_launch = "echo 'Custom setup'\nexport MY_VAR=test\n"
694+
assert pre_launch_section_with_commands == expected_pre_launch
695+
696+
@patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts")
697+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data")
698+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job")
699+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.status")
700+
@patch("builtins.open", new_callable=mock_open)
701+
@patch("os.path.join")
702+
@patch("nemo_run.core.execution.lepton.logger")
703+
def test_launch_method_comprehensive(
704+
self,
705+
mock_logger,
706+
mock_join,
707+
mock_file,
708+
mock_status,
709+
mock_create_job,
710+
mock_move_data,
711+
mock_validate_mounts,
712+
):
713+
"""Test launch method name validation, pre_launch_commands, and script generation."""
714+
# Setup
715+
executor = LeptonExecutor(
716+
container_image="test-image", nemo_run_dir="/test", pre_launch_commands=["echo setup"]
717+
)
718+
executor.job_dir = executor.lepton_job_dir = "/fake"
719+
mock_join.return_value = "/fake/script.sh"
720+
mock_job = MagicMock()
721+
mock_job.metadata.id_ = "job-id"
722+
mock_create_job.return_value = mock_job
723+
mock_status.return_value = LeptonJobState.Running
724+
725+
# Test name transformation and pre_launch_commands
726+
job_id, status = executor.launch("Test_Job.Name", ["python", "script.py"])
727+
assert job_id == "job-id"
728+
729+
# Verify script content includes pre_launch_commands
730+
handle = mock_file.return_value.__enter__.return_value
731+
written_content = handle.write.call_args[0][0]
732+
assert "echo setup\n" in written_content
733+
assert "python script.py" in written_content
734+
735+
# Test long name truncation
736+
long_name = "a" * 50
737+
executor.launch(long_name, ["cmd"])
738+
mock_logger.warning.assert_called_with(
739+
"length of name exceeds 35 characters. Shortening..."
740+
)
741+
742+
@patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts")
743+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data")
744+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job")
745+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.status")
746+
@patch("builtins.open", new_callable=mock_open)
747+
@patch("os.path.join")
748+
@patch("nemo_run.core.execution.lepton.logger")
749+
def test_launch_error_paths(
750+
self,
751+
mock_logger,
752+
mock_join,
753+
mock_file,
754+
mock_status,
755+
mock_create_job,
756+
mock_move_data,
757+
mock_validate_mounts,
758+
):
759+
"""Test launch method error handling and logging."""
760+
executor = LeptonExecutor(container_image="test-image", nemo_run_dir="/test/path")
761+
executor.job_dir = executor.lepton_job_dir = "/fake/dir"
762+
mock_join.return_value = "/fake/launch_script.sh"
763+
764+
# Test job creation failure
765+
mock_create_job.return_value = None
766+
with pytest.raises(RuntimeError, match="Failed to create Lepton job"):
767+
executor.launch("test", ["cmd"])
768+
mock_logger.info.assert_any_call("Creating distributed workload")
769+
770+
# Test missing job ID
771+
mock_job = MagicMock()
772+
mock_job.metadata.id_ = None
773+
mock_create_job.return_value = mock_job
774+
with pytest.raises(RuntimeError, match="Failed to retrieve job information"):
775+
executor.launch("test", ["cmd"])
776+
777+
# Test status failure
778+
mock_job.metadata.id_ = "job-id"
779+
mock_status.return_value = None
780+
with pytest.raises(RuntimeError, match="Failed to retrieve job status"):
781+
executor.launch("test", ["cmd"])
782+
783+
# Test success path with logging
784+
mock_status.return_value = LeptonJobState.Running
785+
job_id, status = executor.launch("test", ["cmd"])
786+
assert job_id == "job-id"
787+
mock_logger.info.assert_any_call("Copying experiment directory to remote filesystem")
788+
789+
@patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts")
790+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data")
791+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job")
792+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.status")
793+
@patch("builtins.open", new_callable=mock_open)
794+
@patch("os.path.join")
795+
@patch("nemo_run.core.execution.lepton.logger")
796+
def test_launch_long_name_truncation(
797+
self,
798+
mock_logger,
799+
mock_join,
800+
mock_file,
801+
mock_status,
802+
mock_create_job,
803+
mock_move_data,
804+
mock_validate_mounts,
805+
):
806+
"""Test name truncation warning and logic (lines 246-247)."""
807+
executor = LeptonExecutor(container_image="test-image", nemo_run_dir="/test/path")
808+
executor.job_dir = executor.lepton_job_dir = "/fake/dir"
809+
mock_join.return_value = "/fake/launch_script.sh"
810+
811+
mock_job = MagicMock()
812+
mock_job.metadata.id_ = "job-id"
813+
mock_create_job.return_value = mock_job
814+
mock_status.return_value = LeptonJobState.Running
815+
816+
# Test long name triggers warning and truncation
817+
long_name = "a" * 50 # 50 characters, exceeds 35
818+
executor.launch(long_name, ["cmd"])
819+
mock_logger.warning.assert_called_with(
820+
"length of name exceeds 35 characters. Shortening..."
821+
)
822+
823+
@patch("nemo_run.core.execution.lepton.LeptonExecutor._validate_mounts")
824+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.move_data")
825+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.create_lepton_job")
826+
@patch("nemo_run.core.execution.lepton.LeptonExecutor.status")
827+
@patch("builtins.open", new_callable=mock_open)
828+
@patch("os.path.join")
829+
def test_launch_prelaunch_commands_join(
830+
self,
831+
mock_join,
832+
mock_file,
833+
mock_status,
834+
mock_create_job,
835+
mock_move_data,
836+
mock_validate_mounts,
837+
):
838+
"""Test pre_launch_commands joining logic (line 252)."""
839+
executor = LeptonExecutor(
840+
container_image="test-image",
841+
nemo_run_dir="/test/path",
842+
pre_launch_commands=["echo setup", "export VAR=1"],
843+
)
844+
executor.job_dir = executor.lepton_job_dir = "/fake/dir"
845+
mock_join.return_value = "/fake/launch_script.sh"
846+
847+
mock_job = MagicMock()
848+
mock_job.metadata.id_ = "job-id"
849+
mock_create_job.return_value = mock_job
850+
mock_status.return_value = LeptonJobState.Running
851+
852+
executor.launch("test", ["cmd"])
853+
854+
# Verify script contains joined pre_launch_commands
855+
handle = mock_file.return_value.__enter__.return_value
856+
written_content = handle.write.call_args[0][0]
857+
assert "echo setup\nexport VAR=1\n" in written_content

0 commit comments

Comments
 (0)