1818import subprocess
1919from dataclasses import dataclass , field
2020from pathlib import Path
21- from typing import Any , Optional , Type , Union
21+ from typing import Any , Dict , List , Optional , Type , Union
2222
2323from invoke .context import Context
2424
3737 import sky .task as skyt
3838 from sky .utils import status_lib
3939 from sky import backends
40+ from sky .volumes import volume as volume_lib
41+ from sky import models
4042
4143 _SKYPILOT_AVAILABLE = True
4244except ImportError :
@@ -94,6 +96,8 @@ class SkypilotExecutor(Executor):
9496 memory : Optional [Union [int | float , list [int | float ]]] = None
9597 instance_type : Optional [Union [str , list [str ]]] = None
9698 num_nodes : int = 1
99+ volumes : Optional [Dict [str , str ]] = None
100+ volume_mounts : Optional [List [Any ]] = None
97101 use_spot : Optional [Union [bool , list [bool ]]] = None
98102 disk_size : Optional [Union [int , list [int ]]] = None
99103 disk_tier : Optional [Union [str , list [str ]]] = None
@@ -108,12 +112,12 @@ class SkypilotExecutor(Executor):
108112 packager : Packager = field (default_factory = lambda : GitArchivePackager ()) # type: ignore # noqa: F821
109113
110114 def __post_init__ (self ):
111- assert _SKYPILOT_AVAILABLE , (
112- 'Skypilot is not installed. Please install it using `pip install "nemo_run[skypilot]"`.'
113- )
114- assert isinstance (self . packager , GitArchivePackager ), (
115- "Only GitArchivePackager is currently supported for SkypilotExecutor."
116- )
115+ assert (
116+ _SKYPILOT_AVAILABLE
117+ ), 'Skypilot is not installed. Please install it using `pip install "nemo_run[skypilot]"`.'
118+ assert isinstance (
119+ self . packager , GitArchivePackager
120+ ), "Only GitArchivePackager is currently supported for SkypilotExecutor."
117121
118122 @classmethod
119123 def parse_app (cls : Type ["SkypilotExecutor" ], app_id : str ) -> tuple [str , str , int ]:
@@ -331,6 +335,14 @@ def macro_values(self) -> Optional[ExecutorMacros]:
331335 )
332336
333337 def _setup_launcher (self ):
338+ # Auto-enable torchrun for distributed training scenarios:
339+ # 1. Multi-node training (num_nodes > 1)
340+ # 2. Single-node multi-GPU training (gpus_per_node > 1)
341+ if self .launcher is None and (
342+ self .num_nodes > 1 or (self .gpus_per_node and self .gpus_per_node > 1 )
343+ ):
344+ self .launcher = "torchrun"
345+
334346 super ()._setup_launcher ()
335347 launcher = self .launcher
336348 # Dynamic rendezvous has an error in Skypilot Kubernetes currently
@@ -342,6 +354,9 @@ def _setup_launcher(self):
342354 launcher .rdzv_backend = "static"
343355 launcher .rdzv_port = 49500
344356
357+ def supports_launcher_transform (self ) -> bool :
358+ return True
359+
345360 def to_task (
346361 self ,
347362 name : str ,
@@ -361,20 +376,48 @@ def to_task(
361376head_node_ip=`echo "$SKYPILOT_NODE_IPS" | head -n1`
362377echo "head_node_ip=$head_node_ip"
363378
379+ export MASTER_ADDR=$head_node_ip
380+ export WORLD_SIZE=$num_nodes
381+ export RANK=`echo "$SKYPILOT_NODE_RANK"`
382+ echo "MASTER_ADDR=$MASTER_ADDR"
364383cd /nemo_run/code
365384
366385{ " " .join (cmd )}
367386"""
387+
368388 task = Task (
369389 name = name ,
370390 setup = self .setup if self .setup else "" ,
371391 run = run_cmd ,
372392 envs = self .env_vars ,
373393 num_nodes = self .num_nodes ,
394+ volumes = self .volumes ,
374395 )
396+
375397 file_mounts = self .file_mounts or {}
376398 file_mounts ["/nemo_run" ] = self .job_dir
377399 task .set_file_mounts (file_mounts )
400+ task .set_volumes (self .volumes )
401+
402+ volume_mounts = []
403+ if self .volume_mounts :
404+ for volume_mount in self .volume_mounts :
405+ volume_mounts .append (
406+ volume_lib .VolumeMount (
407+ path = volume_mount ["path" ],
408+ volume_name = volume_mount ["volume_name" ],
409+ volume_config = models .VolumeConfig (
410+ name = volume_mount ["volume_name" ],
411+ type = volume_mount ["type" ],
412+ cloud = self .cloud ,
413+ region = self .region ,
414+ zone = self .zone ,
415+ name_on_cloud = volume_mount ["volume_name" ],
416+ size = volume_mount ["size" ],
417+ ),
418+ )
419+ )
420+ task .volume_mounts = volume_mounts
378421 task .set_resources (self .to_resources ())
379422
380423 if env_vars :
0 commit comments