Skip to content

Commit 2806a66

Browse files
authored
Fix multi gpu torchrun in Skypilot (#100)
1 parent 38265c4 commit 2806a66

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

src/nemo_run/core/execution/skypilot.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class SkypilotExecutor(Executor):
8989
region: Optional[Union[str, list[str]]] = None
9090
zone: Optional[Union[str, list[str]]] = None
9191
gpus: Optional[Union[str, list[str]]] = None
92-
gpus_per_node: Optional[Union[int, list[int]]] = None
92+
gpus_per_node: Optional[int] = None
9393
cpus: Optional[Union[int | float, list[int | float]]] = None
9494
memory: Optional[Union[int | float, list[int | float]]] = None
9595
instance_type: Optional[Union[str, list[str]]] = None
@@ -103,6 +103,7 @@ class SkypilotExecutor(Executor):
103103
setup: Optional[str] = None
104104
autodown: bool = False
105105
idle_minutes_to_autostop: Optional[int] = None
106+
torchrun_nproc_per_node: Optional[int] = None
106107
packager: Packager = field(default_factory=lambda: GitArchivePackager()) # type: ignore # noqa: F821
107108

108109
def __post_init__(self):
@@ -126,24 +127,16 @@ def to_resources(self) -> Union[set["sky.Resources"], set["sky.Resources"]]:
126127
resources_cfg = {}
127128
accelerators = None
128129
if self.gpus:
129-
if isinstance(self.gpus, str):
130-
if not self.gpus_per_node:
131-
self.gpus_per_node = 1
132-
133-
assert isinstance(self.gpus_per_node, int)
134-
gpus, gpus_per_node = [self.gpus], [self.gpus_per_node]
130+
if not self.gpus_per_node:
131+
self.gpus_per_node = 1
135132
else:
136-
if not self.gpus_per_node:
137-
self.gpus_per_node = [1 for _ in self.gpus]
133+
assert isinstance(self.gpus_per_node, int)
138134

139-
assert isinstance(self.gpus_per_node, list) and len(self.gpus) == len(
140-
self.gpus_per_node
141-
)
142-
gpus, gpus_per_node = self.gpus, self.gpus_per_node
135+
gpus = [self.gpus] if isinstance(self.gpus, str) else self.gpus
143136

144137
accelerators = {}
145-
for gpu, count in zip(gpus, gpus_per_node):
146-
accelerators[gpu] = count
138+
for gpu in gpus:
139+
accelerators[gpu] = self.gpus_per_node
147140

148141
resources_cfg["accelerators"] = accelerators
149142

@@ -319,7 +312,10 @@ def nnodes(self) -> int:
319312
return self.num_nodes
320313

321314
def nproc_per_node(self) -> int:
322-
return 1
315+
if self.torchrun_nproc_per_node:
316+
return self.torchrun_nproc_per_node
317+
318+
return self.gpus_per_node or 1
323319

324320
def macro_values(self) -> Optional[ExecutorMacros]:
325321
return ExecutorMacros(

0 commit comments

Comments
 (0)