@@ -89,7 +89,7 @@ class SkypilotExecutor(Executor):
8989 region : Optional [Union [str , list [str ]]] = None
9090 zone : Optional [Union [str , list [str ]]] = None
9191 gpus : Optional [Union [str , list [str ]]] = None
92- gpus_per_node : Optional [Union [ int , list [ int ]] ] = None
92+ gpus_per_node : Optional [int ] = None
9393 cpus : Optional [Union [int | float , list [int | float ]]] = None
9494 memory : Optional [Union [int | float , list [int | float ]]] = None
9595 instance_type : Optional [Union [str , list [str ]]] = None
@@ -103,6 +103,7 @@ class SkypilotExecutor(Executor):
103103 setup : Optional [str ] = None
104104 autodown : bool = False
105105 idle_minutes_to_autostop : Optional [int ] = None
106+ torchrun_nproc_per_node : Optional [int ] = None
106107 packager : Packager = field (default_factory = lambda : GitArchivePackager ()) # type: ignore # noqa: F821
107108
108109 def __post_init__ (self ):
@@ -126,24 +127,16 @@ def to_resources(self) -> Union[set["sky.Resources"], set["sky.Resources"]]:
126127 resources_cfg = {}
127128 accelerators = None
128129 if self .gpus :
129- if isinstance (self .gpus , str ):
130- if not self .gpus_per_node :
131- self .gpus_per_node = 1
132-
133- assert isinstance (self .gpus_per_node , int )
134- gpus , gpus_per_node = [self .gpus ], [self .gpus_per_node ]
130+ if not self .gpus_per_node :
131+ self .gpus_per_node = 1
135132 else :
136- if not self .gpus_per_node :
137- self .gpus_per_node = [1 for _ in self .gpus ]
133+ assert isinstance (self .gpus_per_node , int )
138134
139- assert isinstance (self .gpus_per_node , list ) and len (self .gpus ) == len (
140- self .gpus_per_node
141- )
142- gpus , gpus_per_node = self .gpus , self .gpus_per_node
135+ gpus = [self .gpus ] if isinstance (self .gpus , str ) else self .gpus
143136
144137 accelerators = {}
145- for gpu , count in zip ( gpus , gpus_per_node ) :
146- accelerators [gpu ] = count
138+ for gpu in gpus :
139+ accelerators [gpu ] = self . gpus_per_node
147140
148141 resources_cfg ["accelerators" ] = accelerators
149142
@@ -319,7 +312,10 @@ def nnodes(self) -> int:
319312 return self .num_nodes
320313
321314 def nproc_per_node (self ) -> int :
322- return 1
315+ if self .torchrun_nproc_per_node :
316+ return self .torchrun_nproc_per_node
317+
318+ return self .gpus_per_node or 1
323319
324320 def macro_values (self ) -> Optional [ExecutorMacros ]:
325321 return ExecutorMacros (
0 commit comments