11"""This module contains the local :class:`~psij.JobExecutor`."""
22import logging
33import os
4+ import shlex
45import signal
56import subprocess
67import threading
78import time
89from abc import ABC , abstractmethod
10+ from tempfile import mkstemp
911from types import FrameType
1012from typing import Optional , Dict , List , Tuple , Type , cast
1113
1214import psutil
1315
14- from psij import InvalidJobException , SubmitException , Launcher
16+ from psij import InvalidJobException , SubmitException , Launcher , ResourceSpecV1
1517from psij import Job , JobSpec , JobExecutorConfig , JobState , JobStatus
1618from psij import JobExecutor
1719from psij .utils import SingletonThread
1820
1921logger = logging .getLogger (__name__ )
2022
2123
24+ def _format_shell_cmd (args : List [str ]) -> str :
25+ """Formats an argument list in a way that allows it to be pasted in a shell."""
26+ cmd = ''
27+ for arg in args :
28+ cmd += shlex .quote (arg )
29+ cmd += ' '
30+ return cmd
31+
32+
def _handle_sigchld(signum: int, frame: Optional[FrameType]) -> None:
    """Signal-handler style callback that notifies the process reaper.

    Neither `signum` nor `frame` is used; the call is simply forwarded to the
    :class:`_ProcessReaper` singleton.
    """
    reaper = _ProcessReaper.get_instance()
    reaper._handle_sigchld()
2435
@@ -67,6 +78,7 @@ class _ChildProcessEntry(_ProcessEntry):
def __init__(self, job: Job, executor: 'LocalJobExecutor',
             launcher: Optional[Launcher]) -> None:
    """Initialize the entry; all arguments are passed on to the parent initializer.

    :param job: the job this entry tracks.
    :param executor: the executor that owns the job.
    :param launcher: the launcher used for the job, if any.
    """
    super().__init__(job, executor, launcher)
    # Path of the node file generated for this job, if any. When set, poll()
    # removes the file once the process has exited.
    self.nodefile: Optional[str] = None
7082
7183 def kill (self ) -> None :
7284 super ().kill ()
@@ -75,6 +87,8 @@ def poll(self) -> Tuple[Optional[int], Optional[str]]:
7587 assert self .process is not None
7688 exit_code = self .process .poll ()
7789 if exit_code is not None :
90+ if self .nodefile :
91+ os .unlink (self .nodefile )
7892 if self .process .stdout :
7993 return exit_code , self .process .stdout .read ().decode ('utf-8' )
8094 else :
@@ -103,19 +117,30 @@ def poll(self) -> Tuple[Optional[int], Optional[str]]:
103117 return None , None
104118
105119
def _get_env(spec: JobSpec, nodefile: Optional[str]) -> Optional[Dict[str, str]]:
    """Build the environment mapping handed to ``subprocess.Popen`` for a job.

    :param spec: the job specification; its `inherit_environment` and
        `environment` attributes are consulted.
    :param nodefile: path of the generated node file, or `None` if the job has
        none. When given, it is exported to the job as ``PSIJ_NODEFILE``.
    :return: the environment for the child process, or `None` when nothing has
        to be added to or removed from what ``Popen`` does by default.
    """
    if nodefile is None:
        if not spec.inherit_environment:
            # Only the spec environment applies; it is returned as is.
            # NOTE(review): if spec.environment is None this returns None, and
            # Popen treats env=None as "inherit the parent environment", even
            # though inherit_environment is false. Confirm whether an empty
            # mapping is intended in that case.
            return spec.environment
        if spec.environment is None:
            # if env is none in Popen, it inherits env from parent
            return None

    # Start from the parent environment only when inheritance is requested.
    env: Dict[str, str] = os.environ.copy() if spec.inherit_environment else {}
    if spec.environment:
        env.update(spec.environment)
    if nodefile is not None:
        # Set last so the executor-generated value takes precedence in both the
        # inheriting and the non-inheriting case.
        env['PSIJ_NODEFILE'] = nodefile
    return env
119144
120145
121146class _ProcessReaper (SingletonThread ):
@@ -222,6 +247,26 @@ def __init__(self, url: Optional[str] = None,
222247 super ().__init__ (url = url , config = config if config else JobExecutorConfig ())
223248 self ._reaper = _ProcessReaper .get_instance ()
224249
def _generate_nodefile(self, job: Job, p: _ChildProcessEntry) -> Optional[str]:
    """Create the node file for a job, if the job needs one.

    The file contains one ``localhost`` line per process of the job. Its path
    is stored in `p.nodefile` and returned.

    :param job: the job being submitted; its `spec` must be set.
    :param p: the process entry for the job, which records the node file path.
    :return: the path of the node file, or `None` if the job has no resource
        specification or runs a single process.
    :raises SubmitException: if the resource specification version is not 1.
    """
    assert job.spec is not None
    resources = job.spec.resources
    if resources is None:
        return None
    if resources.version != 1:
        raise SubmitException('Cannot handle resource specification with version %s'
                              % resources.version)
    assert isinstance(resources, ResourceSpecV1)
    n = resources.computed_process_count
    if n == 1:
        # as a bit of an optimization, we don't generate a nodefile when doing "single
        # node" jobs on local.
        return None

    fd, path = mkstemp(suffix='.nodelist')
    try:
        # The file object owns the descriptor and closes it even if the write
        # fails; binary mode keeps the bytes identical on every platform.
        with os.fdopen(fd, 'wb') as nodefile:
            nodefile.write(b'localhost\n' * n)
    except BaseException:
        # Do not leave a partial temporary file behind; the error is re-raised.
        os.unlink(path)
        raise
    # Recorded only after the file is complete.
    p.nodefile = path
    return path
269+
225270 def submit (self , job : Job ) -> None :
226271 """
227272 Submits the specified :class:`~psij.Job` to be run locally.
@@ -244,9 +289,12 @@ def submit(self, job: Job) -> None:
244289 with job ._status_cv :
245290 if job .status .state == JobState .CANCELED :
246291 raise SubmitException ('Job canceled' )
247- logger .debug ('Running %s, out=%s, err=%s' , args , spec .stdout_path , spec .stderr_path )
292+ if logger .isEnabledFor (logging .DEBUG ):
293+ logger .debug ('Running %s' , _format_shell_cmd (args ))
294+ nodefile = self ._generate_nodefile (job , p )
295+ env = _get_env (spec , nodefile )
248296 p .process = subprocess .Popen (args , stdout = subprocess .PIPE , stderr = subprocess .STDOUT ,
249- close_fds = True , cwd = spec .directory , env = _get_env ( spec ) )
297+ close_fds = True , cwd = spec .directory , env = env )
250298 self ._reaper .register (p )
251299 job ._native_id = p .process .pid
252300 self ._set_job_status (job , JobStatus (JobState .QUEUED , time = time .time (),
0 commit comments