Skip to content

Commit c92f2c3

Browse files
committed
test: Add unit tests for storage_mounts functionality
- Test storage_mounts parameter initialization - Test to_task() method with storage_mounts configurations - Test combined file_mounts and storage_mounts usage - Verify Storage.from_yaml_config() integration - Ensure backward compatibility when storage_mounts is None Signed-off-by: Andy Lee <[email protected]>
1 parent 35f4930 commit c92f2c3

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

test/core/execution/test_skypilot.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,160 @@ def test_to_task(self, mock_task, mock_skypilot_imports, executor):
561561

562562
# Verify the returned task is our mock
563563
assert result == mock_task_instance
564+
565+
@patch("sky.task.Task")
566+
def test_to_task_with_storage_mounts(self, mock_task, mock_skypilot_imports):
567+
# Create a mock task instance
568+
mock_task_instance = MagicMock()
569+
mock_task.return_value = mock_task_instance
570+
mock_task_instance.set_file_mounts = MagicMock()
571+
mock_task_instance.set_storage_mounts = MagicMock()
572+
mock_task_instance.set_resources = MagicMock()
573+
574+
# Mock sky.data.Storage
575+
mock_storage_class = MagicMock()
576+
mock_storage_obj = MagicMock()
577+
mock_storage_class.from_yaml_config.return_value = mock_storage_obj
578+
579+
executor = SkypilotExecutor(
580+
container_image="test:latest",
581+
storage_mounts={
582+
"/workspace/outputs": {
583+
"name": "my-outputs",
584+
"store": "gcs",
585+
"mode": "MOUNT",
586+
"persistent": True,
587+
}
588+
},
589+
)
590+
591+
with tempfile.TemporaryDirectory() as tmp_dir:
592+
executor.job_dir = tmp_dir
593+
594+
with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
595+
mock_to_resources.return_value = MagicMock()
596+
597+
with patch("sky.data.Storage", mock_storage_class):
598+
result = executor.to_task("test_task")
599+
600+
# Verify Storage.from_yaml_config was called with the config
601+
mock_storage_class.from_yaml_config.assert_called_once_with(
602+
{
603+
"name": "my-outputs",
604+
"store": "gcs",
605+
"mode": "MOUNT",
606+
"persistent": True,
607+
}
608+
)
609+
610+
# Verify set_storage_mounts was called with Storage objects
611+
mock_task_instance.set_storage_mounts.assert_called_once()
612+
storage_mounts_call = mock_task_instance.set_storage_mounts.call_args[0][0]
613+
assert "/workspace/outputs" in storage_mounts_call
614+
assert storage_mounts_call["/workspace/outputs"] == mock_storage_obj
615+
616+
@patch("sky.task.Task")
617+
def test_to_task_with_both_file_and_storage_mounts(self, mock_task, mock_skypilot_imports):
618+
# Create a mock task instance
619+
mock_task_instance = MagicMock()
620+
mock_task.return_value = mock_task_instance
621+
mock_task_instance.set_file_mounts = MagicMock()
622+
mock_task_instance.set_storage_mounts = MagicMock()
623+
mock_task_instance.set_resources = MagicMock()
624+
625+
# Mock sky.data.Storage
626+
mock_storage_class = MagicMock()
627+
mock_storage_obj = MagicMock()
628+
mock_storage_class.from_yaml_config.return_value = mock_storage_obj
629+
630+
executor = SkypilotExecutor(
631+
container_image="test:latest",
632+
file_mounts={
633+
"/workspace/code": "/local/path/to/code",
634+
},
635+
storage_mounts={
636+
"/workspace/outputs": {
637+
"name": "my-outputs",
638+
"store": "s3",
639+
"mode": "MOUNT",
640+
},
641+
"/workspace/checkpoints": {
642+
"name": "my-checkpoints",
643+
"store": "gcs",
644+
"mode": "MOUNT_CACHED",
645+
},
646+
},
647+
)
648+
649+
with tempfile.TemporaryDirectory() as tmp_dir:
650+
executor.job_dir = tmp_dir
651+
652+
with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
653+
mock_to_resources.return_value = MagicMock()
654+
655+
with patch("sky.data.Storage", mock_storage_class):
656+
result = executor.to_task("test_task")
657+
658+
# Verify file_mounts includes both user files and nemo_run
659+
file_mounts_call = mock_task_instance.set_file_mounts.call_args[0][0]
660+
assert "/workspace/code" in file_mounts_call
661+
assert file_mounts_call["/workspace/code"] == "/local/path/to/code"
662+
assert "/nemo_run" in file_mounts_call
663+
assert file_mounts_call["/nemo_run"] == tmp_dir
664+
665+
# Verify Storage.from_yaml_config was called for both storage mounts
666+
assert mock_storage_class.from_yaml_config.call_count == 2
667+
668+
# Verify set_storage_mounts was called with both Storage objects
669+
mock_task_instance.set_storage_mounts.assert_called_once()
670+
storage_mounts_call = mock_task_instance.set_storage_mounts.call_args[0][0]
671+
assert "/workspace/outputs" in storage_mounts_call
672+
assert "/workspace/checkpoints" in storage_mounts_call
673+
assert len(storage_mounts_call) == 2
674+
675+
@patch("sky.task.Task")
676+
def test_to_task_without_storage_mounts(self, mock_task, mock_skypilot_imports):
677+
# Test that set_storage_mounts is not called when storage_mounts is None
678+
mock_task_instance = MagicMock()
679+
mock_task.return_value = mock_task_instance
680+
mock_task_instance.set_file_mounts = MagicMock()
681+
mock_task_instance.set_storage_mounts = MagicMock()
682+
mock_task_instance.set_resources = MagicMock()
683+
684+
executor = SkypilotExecutor(
685+
container_image="test:latest",
686+
file_mounts={"/workspace/code": "/local/path"},
687+
storage_mounts=None, # Explicitly set to None
688+
)
689+
690+
with tempfile.TemporaryDirectory() as tmp_dir:
691+
executor.job_dir = tmp_dir
692+
693+
with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
694+
mock_to_resources.return_value = MagicMock()
695+
696+
result = executor.to_task("test_task")
697+
698+
# Verify set_storage_mounts was NOT called
699+
mock_task_instance.set_storage_mounts.assert_not_called()
700+
701+
# Verify file_mounts still works
702+
mock_task_instance.set_file_mounts.assert_called_once()
703+
704+
def test_init_with_storage_mounts(self, mock_skypilot_imports):
705+
# Test initialization with storage_mounts parameter
706+
executor = SkypilotExecutor(
707+
container_image="test:latest",
708+
storage_mounts={
709+
"/workspace/data": {
710+
"name": "training-data",
711+
"store": "s3",
712+
"mode": "MOUNT",
713+
}
714+
},
715+
)
716+
717+
assert executor.storage_mounts is not None
718+
assert "/workspace/data" in executor.storage_mounts
719+
assert executor.storage_mounts["/workspace/data"]["name"] == "training-data"
720+
assert executor.storage_mounts["/workspace/data"]["store"] == "s3"

0 commit comments

Comments
 (0)