1818import subprocess
1919from dataclasses import dataclass , field
2020from pathlib import Path
21- from typing import Any , Optional , Type , Union
21+ from typing import Any , Dict , List , Optional , Type , Union
2222
2323from invoke .context import Context
2424
3737 import sky .task as skyt
3838 from sky import backends
3939 from sky .utils import status_lib
40+ from sky .volumes import volume as volume_lib
41+ from sky import models
4042
4143 _SKYPILOT_AVAILABLE = True
4244except ImportError :
@@ -95,6 +97,8 @@ class SkypilotExecutor(Executor):
9597 memory : Optional [Union [int | float , list [int | float ]]] = None
9698 instance_type : Optional [Union [str , list [str ]]] = None
9799 num_nodes : int = 1
100+ volumes : Optional [Dict [str , str ]] = None
101+ volume_mounts : Optional [List [Any ]] = None
98102 use_spot : Optional [Union [bool , list [bool ]]] = None
99103 disk_size : Optional [Union [int , list [int ]]] = None
100104 disk_tier : Optional [Union [str , list [str ]]] = None
@@ -343,6 +347,14 @@ def macro_values(self) -> Optional[ExecutorMacros]:
343347 )
344348
345349 def _setup_launcher (self ):
350+ # Auto-enable torchrun for distributed training scenarios:
351+ # 1. Multi-node training (num_nodes > 1)
352+ # 2. Single-node multi-GPU training (gpus_per_node > 1)
353+ if self .launcher is None and (
354+ self .num_nodes > 1 or (self .gpus_per_node and self .gpus_per_node > 1 )
355+ ):
356+ self .launcher = "torchrun"
357+
346358 super ()._setup_launcher ()
347359 launcher = self .launcher
348360 # Dynamic rendezvous has an error in Skypilot Kubernetes currently
@@ -354,6 +366,53 @@ def _setup_launcher(self):
354366 launcher .rdzv_backend = "static"
355367 launcher .rdzv_port = 49500
356368
369+ def supports_launcher_transform (self ) -> bool :
370+ return True
371+
372+ def _parse_infra_for_volume_config (self ) -> dict :
373+ """Parse infra string and return volume config parameters."""
374+ config = {}
375+
376+ if self .infra is not None :
377+ # Parse infra string to extract cloud, region, zone components
378+ # Format: cloud, cloud/region, cloud/region/zone, k8s/context
379+ infra_parts = self .infra .split ("/" )
380+ cloud = infra_parts [0 ] if infra_parts else None
381+
382+ if cloud :
383+ # Special handling for Kubernetes
384+ if cloud == "k8s" :
385+ # VolumeConfig region and zone required even though they are marked as optional
386+ # validation fails otherwise
387+ config ["cloud" ] = "kubernetes"
388+ config ["region" ] = "kubernetes"
389+ config ["zone" ] = "kubernetes"
390+ else :
391+ # Handle regular cloud providers
392+ config ["cloud" ] = cloud
393+
394+ # Handle region for non-k8s clouds
395+ if len (infra_parts ) >= 2 :
396+ region = infra_parts [1 ]
397+ if region and region != "*" : # Skip wildcards
398+ config ["region" ] = region
399+
400+ # Handle zone for non-k8s clouds
401+ if len (infra_parts ) >= 3 :
402+ zone = infra_parts [2 ]
403+ if zone and zone != "*" : # Skip wildcards
404+ config ["zone" ] = zone
405+ else :
406+ # Fall back to individual cloud, region, zone parameters
407+ if self .cloud :
408+ config ["cloud" ] = self .cloud
409+ if self .region :
410+ config ["region" ] = self .region
411+ if self .zone :
412+ config ["zone" ] = self .zone
413+
414+ return config
415+
357416 def to_task (
358417 self ,
359418 name : str ,
@@ -377,16 +436,43 @@ def to_task(
377436
378437{ " " .join (cmd )}
379438"""
439+
380440 task = Task (
381441 name = name ,
382442 setup = self .setup if self .setup else "" ,
383443 run = run_cmd ,
384444 envs = self .env_vars ,
385445 num_nodes = self .num_nodes ,
446+ volumes = self .volumes ,
386447 )
448+
387449 file_mounts = self .file_mounts or {}
388450 file_mounts ["/nemo_run" ] = self .job_dir
389451 task .set_file_mounts (file_mounts )
452+ task .set_volumes (self .volumes )
453+
454+ volume_mounts = []
455+ if self .volume_mounts :
456+ for volume_mount in self .volume_mounts :
457+ # Configure volume based on infra if specified, otherwise use cloud/region/zone
458+ volume_config_kwargs = {
459+ "name" : volume_mount ["volume_name" ],
460+ "type" : volume_mount ["type" ],
461+ "name_on_cloud" : volume_mount ["volume_name" ],
462+ "size" : volume_mount ["size" ],
463+ }
464+
465+ # Add parsed infra configuration
466+ volume_config_kwargs .update (self ._parse_infra_for_volume_config ())
467+
468+ volume_mounts .append (
469+ volume_lib .VolumeMount (
470+ path = volume_mount ["path" ],
471+ volume_name = volume_mount ["volume_name" ],
472+ volume_config = models .VolumeConfig (** volume_config_kwargs ),
473+ )
474+ )
475+ task .volume_mounts = volume_mounts
390476 task .set_resources (self .to_resources ())
391477
392478 if env_vars :
0 commit comments