44# This source code is licensed under the BSD-style license found in the 
55# LICENSE file in the root directory of this source tree. 
66
7- """Remote resource  allocation and provisioning.""" 
7+ """Resource  allocation and provisioning for both local and remote .""" 
88import  asyncio 
99import  functools 
1010import  logging 
@@ -160,20 +160,40 @@ async def get_proc_mesh(
160160        mesh_name : Optional [str ] =  None ,
161161        host_mesh : HostMesh  |  None  =  None ,
162162        env_vars : dict [str , str ] |  None  =  None ,
163+         addr : str  |  None  =  None ,
164+         port : str  |  None  =  None ,
163165    ):
164166        """Gets a proc mesh. 
165167
166-         num_hosts = None implies that you want a local allocation, this may change. 
168+         Args: 
169+             num_procs: The number of processes to allocate. 
170+             with_gpus: Whether to include GPU allocations. 
171+                 This only adds the CUDA_VISIBLE_DEVICES environment variable. 
172+             num_hosts: The number of hosts to allocate. 
173+                 If this is set, a remote allocation is created. 
174+                 If this is None, it uses the local host. 
175+                 This behavior may change in the future. 
176+             host_mesh: The host mesh to allocate the process on. 
177+                 If None, a new host mesh will be created. 
178+             port: The distributed port to use. 
179+                 If None, a port will be detected. 
180+             addr: The distributed address to use. 
181+                 If None, an address will be detected. 
182+ 
183+         Returns: 
184+             A proc mesh. 
167185
168186        """ 
169187        if  env_vars  is  None :
170188            env_vars  =  {}
171189
190+         is_remote  =  num_hosts  is  not None  and  num_hosts  >  0 
191+ 
172192        async  with  self ._lock :
173193            server_name  =  None 
174-             if  num_hosts  is  not None  and  num_hosts  >  0 :
175-                 created_hosts  =  len (self ._server_names )
194+             if  is_remote :
176195                if  mesh_name  is  None :
196+                     created_hosts  =  len (self ._server_names )
177197                    mesh_name  =  f"alloc_{ created_hosts }  
178198                if  host_mesh  is  None :
179199                    host_mesh , server_name  =  await  self .create_host_mesh (
@@ -188,18 +208,22 @@ async def get_proc_mesh(
188208                    host_id  =  host_mesh ._host_id 
189209                    gpu_manager  =  self ._host_gpu_map [host_id ]
190210            else :
211+                 # fallback to local 
191212                host_mesh  =  this_host ()
192213                gpu_manager  =  self ._host_gpu_map [self ._this_host_id ]
193214                host_mesh ._host_id  =  self ._this_host_id 
194215
195216            def  bootstrap (env : dict [str , str ]):
217+                 # bootstrap is run on all processes. We use this 
218+                 # to set environment variables like CUDA etc. 
196219                import  os 
197220
198221                for  k , v  in  env .items ():
199222                    os .environ [k ] =  v 
200223
201224            if  with_gpus :
202-                 addr , port  =  await  get_remote_info (host_mesh )
225+                 if  not  addr  or  not  port :
226+                     addr , port  =  await  get_remote_info (host_mesh )
203227                gpu_ids  =  gpu_manager .get_gpus (num_procs )
204228
205229                env_vars ["MASTER_ADDR" ] =  addr 
@@ -213,7 +237,9 @@ def bootstrap(env: dict[str, str]):
213237                per_host = {"gpus" : num_procs },
214238                bootstrap = functools .partial (bootstrap , env = env_vars ),
215239            )
216-             await  self .launcher .remote_setup (procs )
240+ 
241+             if  is_remote :
242+                 await  self .launcher .remote_setup (procs )
217243
218244            # Tag the proc mesh with additional metadata for our own cleanup later 
219245            if  with_gpus :
@@ -284,8 +310,24 @@ async def get_proc_mesh(
284310    process_config : ProcessConfig ,
285311    host_mesh : HostMesh  |  None  =  None ,
286312    env_vars : dict [str , str ] |  None  =  None ,
313+     port : str  |  None  =  None ,
314+     addr : str  |  None  =  None ,
287315) ->  ProcMesh :
288-     """Returns a proc mesh from the provisioner.""" 
316+     """Returns a proc mesh from the provisioner. 
317+ 
318+     Args: 
319+         process_config: The process config. 
320+         host_mesh: The host mesh to allocate the process on. 
321+             If None, a new host mesh will be created. 
322+         port: The distributed port to use. 
323+             If None, a port will be detected. 
324+         addr: The distributed address to use. 
325+             If None, an address will be detected. 
326+ 
327+     Returns: 
328+         A proc mesh. 
329+ 
330+     """ 
289331    provisioner  =  await  _get_provisioner ()
290332    return  await  provisioner .get_proc_mesh (
291333        num_procs = process_config .procs ,
@@ -294,6 +336,8 @@ async def get_proc_mesh(
294336        mesh_name = process_config .mesh_name ,
295337        host_mesh = host_mesh ,
296338        env_vars = env_vars ,
339+         port = port ,
340+         addr = addr ,
297341    )
298342
299343
0 commit comments