@@ -561,3 +561,160 @@ def test_to_task(self, mock_task, mock_skypilot_imports, executor):
561561
562562 # Verify the returned task is our mock
563563 assert result == mock_task_instance
564+
565+ @patch ("sky.task.Task" )
566+ def test_to_task_with_storage_mounts (self , mock_task , mock_skypilot_imports ):
567+ # Create a mock task instance
568+ mock_task_instance = MagicMock ()
569+ mock_task .return_value = mock_task_instance
570+ mock_task_instance .set_file_mounts = MagicMock ()
571+ mock_task_instance .set_storage_mounts = MagicMock ()
572+ mock_task_instance .set_resources = MagicMock ()
573+
574+ # Mock sky.data.Storage
575+ mock_storage_class = MagicMock ()
576+ mock_storage_obj = MagicMock ()
577+ mock_storage_class .from_yaml_config .return_value = mock_storage_obj
578+
579+ executor = SkypilotExecutor (
580+ container_image = "test:latest" ,
581+ storage_mounts = {
582+ "/workspace/outputs" : {
583+ "name" : "my-outputs" ,
584+ "store" : "gcs" ,
585+ "mode" : "MOUNT" ,
586+ "persistent" : True ,
587+ }
588+ },
589+ )
590+
591+ with tempfile .TemporaryDirectory () as tmp_dir :
592+ executor .job_dir = tmp_dir
593+
594+ with patch .object (SkypilotExecutor , "to_resources" ) as mock_to_resources :
595+ mock_to_resources .return_value = MagicMock ()
596+
597+ with patch ("sky.data.Storage" , mock_storage_class ):
598+ result = executor .to_task ("test_task" )
599+
600+ # Verify Storage.from_yaml_config was called with the config
601+ mock_storage_class .from_yaml_config .assert_called_once_with (
602+ {
603+ "name" : "my-outputs" ,
604+ "store" : "gcs" ,
605+ "mode" : "MOUNT" ,
606+ "persistent" : True ,
607+ }
608+ )
609+
610+ # Verify set_storage_mounts was called with Storage objects
611+ mock_task_instance .set_storage_mounts .assert_called_once ()
612+ storage_mounts_call = mock_task_instance .set_storage_mounts .call_args [0 ][0 ]
613+ assert "/workspace/outputs" in storage_mounts_call
614+ assert storage_mounts_call ["/workspace/outputs" ] == mock_storage_obj
615+
616+ @patch ("sky.task.Task" )
617+ def test_to_task_with_both_file_and_storage_mounts (self , mock_task , mock_skypilot_imports ):
618+ # Create a mock task instance
619+ mock_task_instance = MagicMock ()
620+ mock_task .return_value = mock_task_instance
621+ mock_task_instance .set_file_mounts = MagicMock ()
622+ mock_task_instance .set_storage_mounts = MagicMock ()
623+ mock_task_instance .set_resources = MagicMock ()
624+
625+ # Mock sky.data.Storage
626+ mock_storage_class = MagicMock ()
627+ mock_storage_obj = MagicMock ()
628+ mock_storage_class .from_yaml_config .return_value = mock_storage_obj
629+
630+ executor = SkypilotExecutor (
631+ container_image = "test:latest" ,
632+ file_mounts = {
633+ "/workspace/code" : "/local/path/to/code" ,
634+ },
635+ storage_mounts = {
636+ "/workspace/outputs" : {
637+ "name" : "my-outputs" ,
638+ "store" : "s3" ,
639+ "mode" : "MOUNT" ,
640+ },
641+ "/workspace/checkpoints" : {
642+ "name" : "my-checkpoints" ,
643+ "store" : "gcs" ,
644+ "mode" : "MOUNT_CACHED" ,
645+ },
646+ },
647+ )
648+
649+ with tempfile .TemporaryDirectory () as tmp_dir :
650+ executor .job_dir = tmp_dir
651+
652+ with patch .object (SkypilotExecutor , "to_resources" ) as mock_to_resources :
653+ mock_to_resources .return_value = MagicMock ()
654+
655+ with patch ("sky.data.Storage" , mock_storage_class ):
656+ result = executor .to_task ("test_task" )
657+
658+ # Verify file_mounts includes both user files and nemo_run
659+ file_mounts_call = mock_task_instance .set_file_mounts .call_args [0 ][0 ]
660+ assert "/workspace/code" in file_mounts_call
661+ assert file_mounts_call ["/workspace/code" ] == "/local/path/to/code"
662+ assert "/nemo_run" in file_mounts_call
663+ assert file_mounts_call ["/nemo_run" ] == tmp_dir
664+
665+ # Verify Storage.from_yaml_config was called for both storage mounts
666+ assert mock_storage_class .from_yaml_config .call_count == 2
667+
668+ # Verify set_storage_mounts was called with both Storage objects
669+ mock_task_instance .set_storage_mounts .assert_called_once ()
670+ storage_mounts_call = mock_task_instance .set_storage_mounts .call_args [0 ][0 ]
671+ assert "/workspace/outputs" in storage_mounts_call
672+ assert "/workspace/checkpoints" in storage_mounts_call
673+ assert len (storage_mounts_call ) == 2
674+
675+ @patch ("sky.task.Task" )
676+ def test_to_task_without_storage_mounts (self , mock_task , mock_skypilot_imports ):
677+ # Test that set_storage_mounts is not called when storage_mounts is None
678+ mock_task_instance = MagicMock ()
679+ mock_task .return_value = mock_task_instance
680+ mock_task_instance .set_file_mounts = MagicMock ()
681+ mock_task_instance .set_storage_mounts = MagicMock ()
682+ mock_task_instance .set_resources = MagicMock ()
683+
684+ executor = SkypilotExecutor (
685+ container_image = "test:latest" ,
686+ file_mounts = {"/workspace/code" : "/local/path" },
687+ storage_mounts = None , # Explicitly set to None
688+ )
689+
690+ with tempfile .TemporaryDirectory () as tmp_dir :
691+ executor .job_dir = tmp_dir
692+
693+ with patch .object (SkypilotExecutor , "to_resources" ) as mock_to_resources :
694+ mock_to_resources .return_value = MagicMock ()
695+
696+ result = executor .to_task ("test_task" )
697+
698+ # Verify set_storage_mounts was NOT called
699+ mock_task_instance .set_storage_mounts .assert_not_called ()
700+
701+ # Verify file_mounts still works
702+ mock_task_instance .set_file_mounts .assert_called_once ()
703+
704+ def test_init_with_storage_mounts (self , mock_skypilot_imports ):
705+ # Test initialization with storage_mounts parameter
706+ executor = SkypilotExecutor (
707+ container_image = "test:latest" ,
708+ storage_mounts = {
709+ "/workspace/data" : {
710+ "name" : "training-data" ,
711+ "store" : "s3" ,
712+ "mode" : "MOUNT" ,
713+ }
714+ },
715+ )
716+
717+ assert executor .storage_mounts is not None
718+ assert "/workspace/data" in executor .storage_mounts
719+ assert executor .storage_mounts ["/workspace/data" ]["name" ] == "training-data"
720+ assert executor .storage_mounts ["/workspace/data" ]["store" ] == "s3"
0 commit comments