Skip to content

Commit 80f1569

Browse files
committed
BatchSpawnerBase: Add background_tasks, connect_to_job feature.
This adds the possibility to start a "connect_to_job" background task on the hub when a job starts, which establishes connectivity to the actual single-user server. An example for HTCondor batch systems is "condor_ssh_to_job". Additionally, the background tasks are monitored: for successful startup, the background task is given some time to establish connectivity; and during job runtime, poll() checks the tasks and terminates the job if they fail.
1 parent 9ecccb6 commit 80f1569

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed

batchspawner/batchspawner.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ def _req_keepvars_default(self):
172172
"specification."
173173
).tag(config=True)
174174

175+
# Optional command used by connect_to_job() to establish connectivity from
# the hub to the single-user server inside the batch job (e.g. an SSH port
# forward such as HTCondor's "condor_ssh_to_job"). Formatted with the
# req_xyz traits plus {job_id} and {port}; when empty, start() skips
# connect_to_job() entirely and direct connectivity is assumed.
connect_to_job_cmd = Unicode('',
    help="Command to connect to running batch job and forward the port "
    "of the running notebook to the Hub. If empty, direct connectivity is assumed. "
    "Uses self.job_id as {job_id} and the self.port as {port}."
    ).tag(config=True)
180+
175181
# Raw output of job submission command unless overridden
176182
job_id = Unicode()
177183

@@ -200,6 +206,18 @@ def cmd_formatted_for_batch(self):
200206
"""The command which is substituted inside of the batch script"""
201207
return ' '.join([self.batchspawner_singleuser_cmd] + self.cmd + self.get_args())
202208

209+
async def connect_to_job(self):
    """Make the single-user server's port reachable from the hub machine.

    Formats ``connect_to_job_cmd`` (prefixed by ``exec_prefix``) with the
    req_xyz substitution variables plus ``{job_id}`` and ``{port}``, then
    launches it as a monitored background command. ``start()`` only invokes
    this when ``connect_to_job_cmd`` is configured; otherwise direct
    connectivity is assumed.
    """
    substitutions = self.get_req_subvars()
    substitutions.update(job_id=self.job_id, port=self.port)
    command = ' '.join((
        format_template(self.exec_prefix, **substitutions),
        format_template(self.connect_to_job_cmd, **substitutions),
    ))
    await self.run_background_command(command)
220+
203221
async def run_command(self, cmd, input=None, env=None):
204222
proc = await asyncio.create_subprocess_shell(cmd, env=env,
205223
stdin=asyncio.subprocess.PIPE,
@@ -243,6 +261,46 @@ async def run_command(self, cmd, input=None, env=None):
243261
out = out.decode().strip()
244262
return out
245263

264+
# List of running background processes, e.g. used by connect_to_job.
# NOTE(review): this is a mutable class-level attribute, so it is shared by
# every spawner instance in the same hub process — confirm that is intended;
# per-instance state would normally be initialized in __init__.
background_processes = []
266+
267+
async def _async_wait_process(self, sleep_time):
268+
"""Asynchronously sleeping process for delayed checks"""
269+
await asyncio.sleep(sleep_time)
270+
271+
async def run_background_command(self, cmd, startup_check_delay=1, input=None, env=None):
    """Run *cmd* as a long-lived background command.

    The command is raced against a timer of ``startup_check_delay`` seconds.
    If the timer elapses first, the command is considered successfully
    started: its task is appended to ``self.background_processes`` and
    returned. If the command finishes first, the timer is cancelled and,
    if the command failed, a ``RuntimeError`` is raised.

    Parameters:
        cmd: shell command line, passed to ``self.run_command``.
        startup_check_delay: seconds the command must survive to count
            as successfully started.
        input, env: forwarded to ``self.run_command``.
    """
    # Wrap both coroutines in tasks explicitly: passing bare coroutines to
    # asyncio.wait() is deprecated since Python 3.8 and removed in 3.12,
    # and holding the task objects lets us identify which one finished by
    # membership in `done` instead of poking the private Task._coro
    # attribute (the original `list(done)[0]._coro` check was also
    # nondeterministic when both tasks completed in the same iteration).
    background_task = asyncio.ensure_future(self.run_command(cmd, input, env))
    success_check_task = asyncio.ensure_future(
        self._async_wait_process(startup_check_delay))

    done, pending = await asyncio.wait(
        [background_task, success_check_task],
        return_when=asyncio.FIRST_COMPLETED)

    if background_task not in done:
        # The grace period elapsed while the command kept running: startup
        # looks good. Keep the task so poll()/cancel can monitor it.
        self.background_processes.append(background_task)
        return background_task

    self.log.error("Background command exited early: %s" % cmd)

    # Cancel the still-pending success check timer and let the
    # cancellation be processed before moving on.
    if not success_check_task.done():
        success_check_task.cancel()
        try:
            self.log.debug("Cancelling pending success check task...")
            await success_check_task
        except asyncio.CancelledError:
            self.log.debug("Cancel was successful.")

    # Await the finished command to retrieve (and thereby consume) any
    # exception it raised; surface the failure to the caller.
    try:
        await background_task
    except (Exception, asyncio.CancelledError):
        self.log.debug("Retrieving exception from failed background task...")
        raise RuntimeError('{} failed!'.format(cmd))
303+
246304
async def _get_batch_script(self, **subvars):
247305
"""Format batch script from vars"""
248306
# Could be overridden by subclasses, but mainly useful for testing
@@ -270,6 +328,27 @@ async def submit_batch_script(self):
270328
self.job_id = ''
271329
return self.job_id
272330

331+
def background_tasks_ok(self):
    """Return True if all monitored background tasks are still running.

    Background tasks (e.g. the connect_to_job tunnel) are expected to live
    for the whole lifetime of the job, so any task in the "done" state —
    whether it exited, failed, or was cancelled — counts as a failure and
    makes this return False. With no background tasks, returns True.
    """
    if self.background_processes:
        self.log.debug('Checking background processes...')
        for background_process in self.background_processes:
            if background_process.done():
                self.log.debug('Found a background process in state "done"...')
                # Initialize to None: Task.exception() raises CancelledError
                # for cancelled tasks, and the original code then read an
                # unbound local in the `if` below (NameError).
                background_exception = None
                try:
                    background_exception = background_process.exception()
                except asyncio.CancelledError:
                    self.log.error('Background process was cancelled!')
                if background_exception:
                    self.log.error('Background process exited with an exception:')
                    self.log.error(background_exception)
                self.log.error('At least one background process exited!')
                return False
            else:
                self.log.debug('Found a not-yet-done background process...')
        self.log.debug('All background processes still running.')
    return True
351+
273352
# Override if your batch system needs something more elaborate to query the job status
274353
batch_query_cmd = Unicode('',
275354
help="Command to run to query job status. Formatted using req_xyz traits as {xyz} "
@@ -314,6 +393,29 @@ async def cancel_batch_job(self):
314393
cmd = ' '.join((format_template(self.exec_prefix, **subvars),
315394
format_template(self.batch_cancel_cmd, **subvars)))
316395
self.log.info('Cancelling job ' + self.job_id + ': ' + cmd)
396+
397+
if self.background_processes:
398+
self.log.debug('Job being cancelled, cancelling background processes...')
399+
for background_process in self.background_processes:
400+
if not background_process.cancelled():
401+
try:
402+
background_process.cancel()
403+
except:
404+
self.log.error('Encountered an exception cancelling background process...')
405+
self.log.debug('Cancelled background process, waiting for it to finish...')
406+
try:
407+
await asyncio.wait([background_process])
408+
except asyncio.CancelledError:
409+
self.log.error('Successfully cancelled background process.')
410+
pass
411+
except:
412+
self.log.error('Background process exited with another exception!')
413+
raise
414+
else:
415+
self.log.debug('Background process already cancelled...')
416+
self.background_processes.clear()
417+
self.log.debug('All background processes cancelled.')
418+
317419
await self.run_command(cmd)
318420

319421
def load_state(self, state):
@@ -361,6 +463,13 @@ async def poll(self):
361463
"""Poll the process"""
362464
status = await self.query_job_status()
363465
if status in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN):
466+
if not self.background_tasks_ok():
467+
self.log.debug('Going to stop job, since background tasks have failed!')
468+
await self.stop(now=True)
469+
status = await self.query_job_status()
470+
if status not in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN):
471+
self.clear_state()
472+
return 1
364473
return None
365474
else:
366475
self.clear_state()
@@ -420,6 +529,9 @@ async def start(self):
420529
self.job_id, self.ip, self.port)
421530
)
422531

532+
if self.connect_to_job_cmd:
533+
await self.connect_to_job()
534+
423535
return self.ip, self.port
424536

425537
async def stop(self, now=False):

0 commit comments

Comments
 (0)