Skip to content

Commit 0b33d98

Browse files
committed
update skypilot to use volume mounts and launcher
Signed-off-by: ansjindal <[email protected]>
1 parent 6d6c0b2 commit 0b33d98

File tree

3 files changed

+127
-4
lines changed

3 files changed

+127
-4
lines changed

docs/source/guides/execution.md

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,23 @@ def your_skypilot_executor(nodes: int, devices: int, container_image: str):
205205
return SkypilotExecutor(
206206
gpus="RTX5880-ADA-GENERATION",
207207
gpus_per_node=devices,
208-
nodes = nodes
209-
env_vars=common_envs()
208+
num_nodes = nodes,
209+
env_vars=common_envs(),
210210
container_image=container_image,
211-
cloud="kubernetes",
211+
infra="k8s/mycontext",
212212
# Optional to reuse Skypilot cluster
213213
cluster_name="tester",
214+
volumes={
215+
"nemo-workspace": "nemo-workspace"
216+
},
217+
volume_mounts=[
218+
{
219+
"path": "/data",
220+
"volume_name": "nemo-workspace",
221+
"size": "50Gi",
222+
"type": "k8s-pvc"
223+
}
224+
],
214225
setup="""
215226
conda deactivate
216227
nvidia-smi

nemo_run/core/execution/skypilot.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import subprocess
1919
from dataclasses import dataclass, field
2020
from pathlib import Path
21-
from typing import Any, Optional, Type, Union
21+
from typing import Any, Dict, List, Optional, Type, Union
2222

2323
from invoke.context import Context
2424

@@ -37,6 +37,8 @@
3737
import sky.task as skyt
3838
from sky import backends
3939
from sky.utils import status_lib
40+
from sky.volumes import volume as volume_lib
41+
from sky import models
4042

4143
_SKYPILOT_AVAILABLE = True
4244
except ImportError:
@@ -95,6 +97,8 @@ class SkypilotExecutor(Executor):
9597
memory: Optional[Union[int | float, list[int | float]]] = None
9698
instance_type: Optional[Union[str, list[str]]] = None
9799
num_nodes: int = 1
100+
volumes: Optional[Dict[str, str]] = None
101+
volume_mounts: Optional[List[Any]] = None
98102
use_spot: Optional[Union[bool, list[bool]]] = None
99103
disk_size: Optional[Union[int, list[int]]] = None
100104
disk_tier: Optional[Union[str, list[str]]] = None
@@ -343,6 +347,14 @@ def macro_values(self) -> Optional[ExecutorMacros]:
343347
)
344348

345349
def _setup_launcher(self):
350+
# Auto-enable torchrun for distributed training scenarios:
351+
# 1. Multi-node training (num_nodes > 1)
352+
# 2. Single-node multi-GPU training (gpus_per_node > 1)
353+
if self.launcher is None and (
354+
self.num_nodes > 1 or (self.gpus_per_node and self.gpus_per_node > 1)
355+
):
356+
self.launcher = "torchrun"
357+
346358
super()._setup_launcher()
347359
launcher = self.launcher
348360
# Dynamic rendezvous has an error in Skypilot Kubernetes currently
@@ -354,6 +366,53 @@ def _setup_launcher(self):
354366
launcher.rdzv_backend = "static"
355367
launcher.rdzv_port = 49500
356368

369+
def supports_launcher_transform(self) -> bool:
    """Report that this executor supports launcher transforms.

    Returns:
        Always ``True``.
    """
    return True
371+
372+
def _parse_infra_for_volume_config(self) -> dict:
    """Build the cloud/region/zone keyword arguments for a volume config.

    When ``self.infra`` is set it takes precedence and is parsed as a
    ``/``-separated string in one of the forms ``cloud``, ``cloud/region``,
    ``cloud/region/zone`` or ``k8s/context``. Otherwise the individual
    ``self.cloud``, ``self.region`` and ``self.zone`` attributes are used.

    Returns:
        A dict containing any of the keys ``"cloud"``, ``"region"`` and
        ``"zone"``. Keys whose value is missing, empty or the wildcard
        ``"*"`` are omitted.
    """
    config = {}

    if self.infra is None:
        # No infra string: fall back to the individual parameters.
        if self.cloud:
            config["cloud"] = self.cloud
        if self.region:
            config["region"] = self.region
        if self.zone:
            config["zone"] = self.zone
        return config

    # Tolerate stray whitespace around each component ("aws / us-east-1").
    parts = [part.strip() for part in self.infra.split("/")]
    cloud = parts[0]
    if not cloud:
        return config

    # Accept both spellings of Kubernetes, case-insensitively, so that
    # "kubernetes/ctx" is handled the same way as "k8s/ctx".
    if cloud.lower() in ("k8s", "kubernetes"):
        # VolumeConfig region and zone are required even though they are
        # marked as optional; validation fails otherwise.
        config["cloud"] = "kubernetes"
        config["region"] = "kubernetes"
        config["zone"] = "kubernetes"
        return config

    # A wildcard means "unspecified" and must not be forwarded as a value.
    if cloud != "*":
        config["cloud"] = cloud

    # parts[1] is the region and parts[2] the zone, when present.
    for key, value in zip(("region", "zone"), parts[1:3]):
        if value and value != "*":
            config[key] = value

    return config
415+
357416
def to_task(
358417
self,
359418
name: str,
@@ -377,16 +436,43 @@ def to_task(
377436
378437
{" ".join(cmd)}
379438
"""
439+
380440
task = Task(
381441
name=name,
382442
setup=self.setup if self.setup else "",
383443
run=run_cmd,
384444
envs=self.env_vars,
385445
num_nodes=self.num_nodes,
446+
volumes=self.volumes,
386447
)
448+
387449
file_mounts = self.file_mounts or {}
388450
file_mounts["/nemo_run"] = self.job_dir
389451
task.set_file_mounts(file_mounts)
452+
task.set_volumes(self.volumes)
453+
454+
volume_mounts = []
455+
if self.volume_mounts:
456+
for volume_mount in self.volume_mounts:
457+
# Configure volume based on infra if specified, otherwise use cloud/region/zone
458+
volume_config_kwargs = {
459+
"name": volume_mount["volume_name"],
460+
"type": volume_mount["type"],
461+
"name_on_cloud": volume_mount["volume_name"],
462+
"size": volume_mount["size"],
463+
}
464+
465+
# Add parsed infra configuration
466+
volume_config_kwargs.update(self._parse_infra_for_volume_config())
467+
468+
volume_mounts.append(
469+
volume_lib.VolumeMount(
470+
path=volume_mount["path"],
471+
volume_name=volume_mount["volume_name"],
472+
volume_config=models.VolumeConfig(**volume_config_kwargs),
473+
)
474+
)
475+
task.volume_mounts = volume_mounts
390476
task.set_resources(self.to_resources())
391477

392478
if env_vars:

test/core/execution/test_skypilot.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,29 @@ def test_to_task(self, mock_task, mock_skypilot_imports, executor):
561561

562562
# Verify the returned task is our mock
563563
assert result == mock_task_instance
564+
565+
def test_parse_infra_for_volume_config(self, mock_skypilot_imports):
    """Check _parse_infra_for_volume_config for k8s, AWS and fallback inputs."""
    scenarios = [
        # Kubernetes infra string
        (
            {"infra": "k8s/my-context"},
            [("cloud", "kubernetes"), ("region", "kubernetes"), ("zone", "kubernetes")],
        ),
        # AWS infra string carrying both region and zone
        (
            {"infra": "aws/us-east-1/us-east-1a"},
            [("cloud", "aws"), ("region", "us-east-1"), ("zone", "us-east-1a")],
        ),
        # No infra: individual cloud/region/zone parameters are used
        (
            {"cloud": "gcp", "region": "us-central1", "zone": "us-central1-a"},
            [("cloud", "gcp"), ("region", "us-central1"), ("zone", "us-central1-a")],
        ),
    ]

    for executor_kwargs, expected_items in scenarios:
        parsed = SkypilotExecutor(**executor_kwargs)._parse_infra_for_volume_config()
        for key, expected_value in expected_items:
            assert parsed[key] == expected_value

0 commit comments

Comments
 (0)