@@ -172,6 +172,12 @@ def _req_keepvars_default(self):
172
172
"specification."
173
173
).tag (config = True )
174
174
175
+ connect_to_job_cmd = Unicode ('' ,
176
+ help = "Command to connect to running batch job and forward the port "
177
+ "of the running notebook to the Hub. If empty, direct connectivity is assumed. "
178
+ "Uses self.job_id as {job_id} and the self.port as {port}."
179
+ ).tag (config = True )
180
+
175
181
# Raw output of job submission command unless overridden
176
182
job_id = Unicode ()
177
183
@@ -200,6 +206,18 @@ def cmd_formatted_for_batch(self):
200
206
"""The command which is substituted inside of the batch script"""
201
207
return ' ' .join ([self .batchspawner_singleuser_cmd ] + self .cmd + self .get_args ())
202
208
209
+ async def connect_to_job (self ):
210
+ """This command ensures the port of the singleuser server is reachable from the
211
+ Batchspawner machine. By default, it does nothing, i.e. direct connectivity
212
+ is assumed.
213
+ """
214
+ subvars = self .get_req_subvars ()
215
+ subvars ['job_id' ] = self .job_id
216
+ subvars ['port' ] = self .port
217
+ cmd = ' ' .join ((format_template (self .exec_prefix , ** subvars ),
218
+ format_template (self .connect_to_job_cmd , ** subvars )))
219
+ await self .run_background_command (cmd )
220
+
203
221
async def run_command (self , cmd , input = None , env = None ):
204
222
proc = await asyncio .create_subprocess_shell (cmd , env = env ,
205
223
stdin = asyncio .subprocess .PIPE ,
@@ -243,6 +261,46 @@ async def run_command(self, cmd, input=None, env=None):
243
261
out = out .decode ().strip ()
244
262
return out
245
263
264
+ # List of running background processes, e.g. used by connect_to_job.
265
+ background_processes = []
266
+
267
+ async def _async_wait_process (self , sleep_time ):
268
+ """Asynchronously sleeping process for delayed checks"""
269
+ await asyncio .sleep (sleep_time )
270
+
271
+ async def run_background_command (self , cmd , startup_check_delay = 1 , input = None , env = None ):
272
+ """Runs the given background command, adds it to background_processes,
273
+ and checks if the command is still running after startup_check_delay."""
274
+ background_process = self .run_command (cmd , input , env )
275
+ success_check_delay = self ._async_wait_process (startup_check_delay )
276
+
277
+ # Start up both the success check process and the actual process.
278
+ done , pending = await asyncio .wait ([background_process , success_check_delay ], return_when = asyncio .FIRST_COMPLETED )
279
+
280
+ # If the success check process is the one which exited first, all is good, else fail.
281
+ if list (done )[0 ]._coro == success_check_delay :
282
+ background_task = list (pending )[0 ]
283
+ self .background_processes .append (background_task )
284
+ return background_task
285
+ else :
286
+ self .log .error ("Background command exited early: %s" % cmd )
287
+ gather_pending = asyncio .gather (* pending )
288
+ gather_pending .cancel ()
289
+ try :
290
+ self .log .debug ("Cancelling pending success check task..." )
291
+ await gather_pending
292
+ except asyncio .CancelledError :
293
+ self .log .debug ("Cancel was successful." )
294
+ pass
295
+
296
+ # Retrieve exception from "done" process.
297
+ try :
298
+ gather_done = asyncio .gather (* done )
299
+ await gather_done
300
+ except :
301
+ self .log .debug ("Retrieving exception from failed background task..." )
302
+ raise RuntimeError ('{} failed!' .format (cmd ))
303
+
246
304
async def _get_batch_script (self , ** subvars ):
247
305
"""Format batch script from vars"""
248
306
# Could be overridden by subclasses, but mainly useful for testing
@@ -270,6 +328,27 @@ async def submit_batch_script(self):
270
328
self .job_id = ''
271
329
return self .job_id
272
330
331
+ def background_tasks_ok (self ):
332
+ # Check background processes.
333
+ if self .background_processes :
334
+ self .log .debug ('Checking background processes...' )
335
+ for background_process in self .background_processes :
336
+ if background_process .done ():
337
+ self .log .debug ('Found a background process in state "done"...' )
338
+ try :
339
+ background_exception = background_process .exception ()
340
+ except asyncio .CancelledError :
341
+ self .log .error ('Background process was cancelled!' )
342
+ if background_exception :
343
+ self .log .error ('Background process exited with an exception:' )
344
+ self .log .error (background_exception )
345
+ self .log .error ('At least one background process exited!' )
346
+ return False
347
+ else :
348
+ self .log .debug ('Found a not-yet-done background process...' )
349
+ self .log .debug ('All background processes still running.' )
350
+ return True
351
+
273
352
# Override if your batch system needs something more elaborate to query the job status
274
353
batch_query_cmd = Unicode ('' ,
275
354
help = "Command to run to query job status. Formatted using req_xyz traits as {xyz} "
@@ -314,6 +393,29 @@ async def cancel_batch_job(self):
314
393
cmd = ' ' .join ((format_template (self .exec_prefix , ** subvars ),
315
394
format_template (self .batch_cancel_cmd , ** subvars )))
316
395
self .log .info ('Cancelling job ' + self .job_id + ': ' + cmd )
396
+
397
+ if self .background_processes :
398
+ self .log .debug ('Job being cancelled, cancelling background processes...' )
399
+ for background_process in self .background_processes :
400
+ if not background_process .cancelled ():
401
+ try :
402
+ background_process .cancel ()
403
+ except :
404
+ self .log .error ('Encountered an exception cancelling background process...' )
405
+ self .log .debug ('Cancelled background process, waiting for it to finish...' )
406
+ try :
407
+ await asyncio .wait ([background_process ])
408
+ except asyncio .CancelledError :
409
+ self .log .error ('Successfully cancelled background process.' )
410
+ pass
411
+ except :
412
+ self .log .error ('Background process exited with another exception!' )
413
+ raise
414
+ else :
415
+ self .log .debug ('Background process already cancelled...' )
416
+ self .background_processes .clear ()
417
+ self .log .debug ('All background processes cancelled.' )
418
+
317
419
await self .run_command (cmd )
318
420
319
421
def load_state (self , state ):
@@ -361,6 +463,13 @@ async def poll(self):
361
463
"""Poll the process"""
362
464
status = await self .query_job_status ()
363
465
if status in (JobStatus .PENDING , JobStatus .RUNNING , JobStatus .UNKNOWN ):
466
+ if not self .background_tasks_ok ():
467
+ self .log .debug ('Going to stop job, since background tasks have failed!' )
468
+ await self .stop (now = True )
469
+ status = await self .query_job_status ()
470
+ if status not in (JobStatus .PENDING , JobStatus .RUNNING , JobStatus .UNKNOWN ):
471
+ self .clear_state ()
472
+ return 1
364
473
return None
365
474
else :
366
475
self .clear_state ()
@@ -420,6 +529,9 @@ async def start(self):
420
529
self .job_id , self .ip , self .port )
421
530
)
422
531
532
+ if self .connect_to_job_cmd :
533
+ await self .connect_to_job ()
534
+
423
535
return self .ip , self .port
424
536
425
537
async def stop (self , now = False ):
0 commit comments