66import threading
77import time
88from abc import ABC , abstractmethod
9+ from tempfile import mkstemp
910from types import FrameType
1011from typing import Optional , Dict , List , Tuple , Type , cast
1112
1213import psutil
1314
14- from psij import InvalidJobException , SubmitException , Launcher
15+ from psij import InvalidJobException , SubmitException , Launcher , ResourceSpecV1
1516from psij import Job , JobSpec , JobExecutorConfig , JobState , JobStatus
1617from psij import JobExecutor
1718from psij .utils import SingletonThread
@@ -67,6 +68,7 @@ class _ChildProcessEntry(_ProcessEntry):
6768 def __init__ (self , job : Job , executor : 'LocalJobExecutor' ,
6869 launcher : Optional [Launcher ]) -> None :
6970 super ().__init__ (job , executor , launcher )
71+ self .nodefile : Optional [str ] = None
7072
7173 def kill (self ) -> None :
7274 super ().kill ()
@@ -75,6 +77,8 @@ def poll(self) -> Tuple[Optional[int], Optional[str]]:
7577 assert self .process is not None
7678 exit_code = self .process .poll ()
7779 if exit_code is not None :
80+ if self .nodefile :
81+ os .unlink (self .nodefile )
7882 if self .process .stdout :
7983 return exit_code , self .process .stdout .read ().decode ('utf-8' )
8084 else :
@@ -103,19 +107,30 @@ def poll(self) -> Tuple[Optional[int], Optional[str]]:
103107 return None , None
104108
105109
106- def _get_env (spec : JobSpec ) -> Optional [Dict [str , str ]]:
110+ def _get_env (spec : JobSpec , nodefile : Optional [str ]) -> Optional [Dict [str , str ]]:
111+ env : Optional [Dict [str , str ]] = None
107112 if spec .inherit_environment :
108- if not spec .environment :
113+ if spec .environment is None and nodefile is None :
109114 # if env is none in Popen, it inherits env from parent
110115 return None
111116 else :
112117 # merge current env with spec env
113118 env = os .environ .copy ()
114- env .update (spec .environment )
119+ if spec .environment :
120+ env .update (spec .environment )
121+ if nodefile is not None :
122+ env ['PSIJ_NODEFILE' ] = nodefile
115123 return env
116124 else :
117125 # only spec env
118- return spec .environment
126+ if nodefile is None :
127+ env = spec .environment
128+ else :
129+ env = {'PSIJ_NODEFILE' : nodefile }
130+ if spec .environment :
131+ env .update (spec .environment )
132+
133+ return env
119134
120135
121136class _ProcessReaper (SingletonThread ):
@@ -222,6 +237,26 @@ def __init__(self, url: Optional[str] = None,
222237 super ().__init__ (url = url , config = config if config else JobExecutorConfig ())
223238 self ._reaper = _ProcessReaper .get_instance ()
224239
240+ def _generate_nodefile (self , job : Job , p : _ChildProcessEntry ) -> Optional [str ]:
241+ assert job .spec is not None
242+ if job .spec .resources is None :
243+ return None
244+ if job .spec .resources .version == 1 :
245+ assert isinstance (job .spec .resources , ResourceSpecV1 )
246+ n = job .spec .resources .computed_process_count
247+ if n == 1 :
248+ # as a bit of an optimization, we don't generate a nodefile when doing "single
249+ # node" jobs on local.
250+ return None
251+ (file , p .nodefile ) = mkstemp (suffix = '.nodelist' )
252+ for i in range (n ):
253+ os .write (file , 'localhost\n ' .encode ())
254+ os .close (file )
255+ return p .nodefile
256+ else :
257+ raise SubmitException ('Cannot handle resource specification with version %s'
258+ % job .spec .resources .version )
259+
225260 def submit (self , job : Job ) -> None :
226261 """
227262 Submits the specified :class:`~psij.Job` to be run locally.
@@ -245,8 +280,10 @@ def submit(self, job: Job) -> None:
245280 if job .status .state == JobState .CANCELED :
246281 raise SubmitException ('Job canceled' )
247282 logger .debug ('Running %s, out=%s, err=%s' , args , spec .stdout_path , spec .stderr_path )
283+ nodefile = self ._generate_nodefile (job , p )
284+ env = _get_env (spec , nodefile )
248285 p .process = subprocess .Popen (args , stdout = subprocess .PIPE , stderr = subprocess .STDOUT ,
249- close_fds = True , cwd = spec .directory , env = _get_env ( spec ) )
286+ close_fds = True , cwd = spec .directory , env = env )
250287 self ._reaper .register (p )
251288 job ._native_id = p .process .pid
252289 self ._set_job_status (job , JobStatus (JobState .QUEUED , time = time .time (),
0 commit comments