Skip to content

Commit afcd979

Browse files
andylizfhemildesai
andauthored
Support SkyPilot Storage configurations in file_mounts for automatic cloud sync (#335)
* fix: support for SkyPilot Storage configurations in file_mounts - Modified SkypilotExecutor to handle both string paths and dict configs in file_mounts - Dictionary configs are automatically converted to sky.Storage objects - Enables automatic cloud storage mounting (GCS, S3, etc.) for outputs This change allows users to specify cloud storage backends directly in file_mounts, enabling automatic synchronization of training outputs to cloud storage without manual rsync operations. Signed-off-by: Andy Lee <[email protected]> * refactor: Separate storage_mounts from file_mounts for cleaner API Signed-off-by: Andy Lee <[email protected]> * test: Add unit tests for storage_mounts functionality - Test storage_mounts parameter initialization - Test to_task() method with storage_mounts configurations - Test combined file_mounts and storage_mounts usage - Verify Storage.from_yaml_config() integration - Ensure backward compatibility when storage_mounts is None Signed-off-by: Andy Lee <[email protected]> * fix tests Signed-off-by: Hemil Desai <[email protected]> --------- Signed-off-by: Andy Lee <[email protected]> Signed-off-by: Hemil Desai <[email protected]> Co-authored-by: Hemil Desai <[email protected]>
1 parent c16a572 commit afcd979

File tree

2 files changed

+187
-1
lines changed

2 files changed

+187
-1
lines changed

nemo_run/core/execution/skypilot.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,22 @@ class SkypilotExecutor(Executor):
6565
network_tier="best",
6666
cluster_name="nemo_tester",
6767
file_mounts={
68-
"nemo_run.whl": "nemo_run.whl"
68+
"nemo_run.whl": "nemo_run.whl",
69+
"/workspace/code": "/local/path/to/code",
70+
},
71+
storage_mounts={
72+
"/workspace/outputs": {
73+
"name": "my-training-outputs",
74+
"store": "gcs", # or "s3", "azure", etc.
75+
"mode": "MOUNT",
76+
"persistent": True,
77+
},
78+
"/workspace/checkpoints": {
79+
"name": "model-checkpoints",
80+
"store": "s3",
81+
"mode": "MOUNT",
82+
"persistent": True,
83+
}
6984
},
7085
setup=\"\"\"
7186
conda deactivate
@@ -99,6 +114,7 @@ class SkypilotExecutor(Executor):
99114
disk_tier: Optional[Union[str, list[str]]] = None
100115
ports: Optional[tuple[str]] = None
101116
file_mounts: Optional[dict[str, str]] = None
117+
storage_mounts: Optional[dict[str, dict[str, Any]]] = None # Can be str or dict configs
102118
cluster_name: Optional[str] = None
103119
setup: Optional[str] = None
104120
autodown: bool = False
@@ -372,9 +388,22 @@ def to_task(
372388
envs=self.env_vars,
373389
num_nodes=self.num_nodes,
374390
)
391+
# Handle regular file mounts
375392
file_mounts = self.file_mounts or {}
376393
file_mounts["/nemo_run"] = self.job_dir
377394
task.set_file_mounts(file_mounts)
395+
396+
# Handle storage mounts separately
397+
if self.storage_mounts:
398+
from sky.data import Storage
399+
400+
storage_objects = {}
401+
for mount_path, config in self.storage_mounts.items():
402+
# Create Storage object from config dict
403+
storage_obj = Storage.from_yaml_config(config)
404+
storage_objects[mount_path] = storage_obj
405+
task.set_storage_mounts(storage_objects)
406+
378407
task.set_resources(self.to_resources())
379408

380409
if env_vars:

test/core/execution/test_skypilot.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,160 @@ def test_to_task(self, mock_task, mock_skypilot_imports, executor):
561561

562562
# Verify the returned task is our mock
563563
assert result == mock_task_instance
564+
565+
@patch("sky.task.Task")
566+
def test_to_task_with_storage_mounts(self, mock_task, mock_skypilot_imports):
567+
# Create a mock task instance
568+
mock_task_instance = MagicMock()
569+
mock_task.return_value = mock_task_instance
570+
mock_task_instance.set_file_mounts = MagicMock()
571+
mock_task_instance.set_storage_mounts = MagicMock()
572+
mock_task_instance.set_resources = MagicMock()
573+
574+
# Mock sky.data.Storage
575+
mock_storage_class = MagicMock()
576+
mock_storage_obj = MagicMock()
577+
mock_storage_class.from_yaml_config.return_value = mock_storage_obj
578+
579+
executor = SkypilotExecutor(
580+
container_image="test:latest",
581+
storage_mounts={
582+
"/workspace/outputs": {
583+
"name": "my-outputs",
584+
"store": "gcs",
585+
"mode": "MOUNT",
586+
"persistent": True,
587+
}
588+
},
589+
)
590+
591+
with tempfile.TemporaryDirectory() as tmp_dir:
592+
executor.job_dir = tmp_dir
593+
594+
with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
595+
mock_to_resources.return_value = MagicMock()
596+
597+
with patch("sky.data.Storage", mock_storage_class):
598+
executor.to_task("test_task")
599+
600+
# Verify Storage.from_yaml_config was called with the config
601+
mock_storage_class.from_yaml_config.assert_called_once_with(
602+
{
603+
"name": "my-outputs",
604+
"store": "gcs",
605+
"mode": "MOUNT",
606+
"persistent": True,
607+
}
608+
)
609+
610+
# Verify set_storage_mounts was called with Storage objects
611+
mock_task_instance.set_storage_mounts.assert_called_once()
612+
storage_mounts_call = mock_task_instance.set_storage_mounts.call_args[0][0]
613+
assert "/workspace/outputs" in storage_mounts_call
614+
assert storage_mounts_call["/workspace/outputs"] == mock_storage_obj
615+
616+
@patch("sky.task.Task")
617+
def test_to_task_with_both_file_and_storage_mounts(self, mock_task, mock_skypilot_imports):
618+
# Create a mock task instance
619+
mock_task_instance = MagicMock()
620+
mock_task.return_value = mock_task_instance
621+
mock_task_instance.set_file_mounts = MagicMock()
622+
mock_task_instance.set_storage_mounts = MagicMock()
623+
mock_task_instance.set_resources = MagicMock()
624+
625+
# Mock sky.data.Storage
626+
mock_storage_class = MagicMock()
627+
mock_storage_obj = MagicMock()
628+
mock_storage_class.from_yaml_config.return_value = mock_storage_obj
629+
630+
executor = SkypilotExecutor(
631+
container_image="test:latest",
632+
file_mounts={
633+
"/workspace/code": "/local/path/to/code",
634+
},
635+
storage_mounts={
636+
"/workspace/outputs": {
637+
"name": "my-outputs",
638+
"store": "s3",
639+
"mode": "MOUNT",
640+
},
641+
"/workspace/checkpoints": {
642+
"name": "my-checkpoints",
643+
"store": "gcs",
644+
"mode": "MOUNT_CACHED",
645+
},
646+
},
647+
)
648+
649+
with tempfile.TemporaryDirectory() as tmp_dir:
650+
executor.job_dir = tmp_dir
651+
652+
with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
653+
mock_to_resources.return_value = MagicMock()
654+
655+
with patch("sky.data.Storage", mock_storage_class):
656+
executor.to_task("test_task")
657+
658+
# Verify file_mounts includes both user files and nemo_run
659+
file_mounts_call = mock_task_instance.set_file_mounts.call_args[0][0]
660+
assert "/workspace/code" in file_mounts_call
661+
assert file_mounts_call["/workspace/code"] == "/local/path/to/code"
662+
assert "/nemo_run" in file_mounts_call
663+
assert file_mounts_call["/nemo_run"] == tmp_dir
664+
665+
# Verify Storage.from_yaml_config was called for both storage mounts
666+
assert mock_storage_class.from_yaml_config.call_count == 2
667+
668+
# Verify set_storage_mounts was called with both Storage objects
669+
mock_task_instance.set_storage_mounts.assert_called_once()
670+
storage_mounts_call = mock_task_instance.set_storage_mounts.call_args[0][0]
671+
assert "/workspace/outputs" in storage_mounts_call
672+
assert "/workspace/checkpoints" in storage_mounts_call
673+
assert len(storage_mounts_call) == 2
674+
675+
@patch("sky.task.Task")
676+
def test_to_task_without_storage_mounts(self, mock_task, mock_skypilot_imports):
677+
# Test that set_storage_mounts is not called when storage_mounts is None
678+
mock_task_instance = MagicMock()
679+
mock_task.return_value = mock_task_instance
680+
mock_task_instance.set_file_mounts = MagicMock()
681+
mock_task_instance.set_storage_mounts = MagicMock()
682+
mock_task_instance.set_resources = MagicMock()
683+
684+
executor = SkypilotExecutor(
685+
container_image="test:latest",
686+
file_mounts={"/workspace/code": "/local/path"},
687+
storage_mounts=None, # Explicitly set to None
688+
)
689+
690+
with tempfile.TemporaryDirectory() as tmp_dir:
691+
executor.job_dir = tmp_dir
692+
693+
with patch.object(SkypilotExecutor, "to_resources") as mock_to_resources:
694+
mock_to_resources.return_value = MagicMock()
695+
696+
executor.to_task("test_task")
697+
698+
# Verify set_storage_mounts was NOT called
699+
mock_task_instance.set_storage_mounts.assert_not_called()
700+
701+
# Verify file_mounts still works
702+
mock_task_instance.set_file_mounts.assert_called_once()
703+
704+
def test_init_with_storage_mounts(self, mock_skypilot_imports):
705+
# Test initialization with storage_mounts parameter
706+
executor = SkypilotExecutor(
707+
container_image="test:latest",
708+
storage_mounts={
709+
"/workspace/data": {
710+
"name": "training-data",
711+
"store": "s3",
712+
"mode": "MOUNT",
713+
}
714+
},
715+
)
716+
717+
assert executor.storage_mounts is not None
718+
assert "/workspace/data" in executor.storage_mounts
719+
assert executor.storage_mounts["/workspace/data"]["name"] == "training-data"
720+
assert executor.storage_mounts["/workspace/data"]["store"] == "s3"

0 commit comments

Comments
 (0)