Skip to content

Commit c5becaa

Browse files
committed
Merge remote-tracking branch 'upstream/pr/90' into dev
Conflicts: batchspawner/batchspawner.py
2 parents 625c495 + ad03efc commit c5becaa

File tree

1 file changed

+48
-50
lines changed

1 file changed

+48
-50
lines changed

batchspawner/batchspawner.py

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818
import pwd
1919
import os
20+
import asyncio
2021
import re
2122

2223
import xml.etree.ElementTree as ET
@@ -183,56 +184,58 @@ def parse_job_id(self, output):
183184
def cmd_formatted_for_batch(self):
    """Return the single-user launch command as one space-separated string.

    The result is 'batchspawner-singleuser' followed by the entries of
    ``self.cmd`` and then the entries returned by ``self.get_args()``.
    """
    argv = ['batchspawner-singleuser'] + self.cmd
    argv = argv + self.get_args()
    return ' '.join(argv)
185186

186-
async def run_command(self, cmd, input=None, env=None):
    """Run ``cmd`` through the shell and return its stripped standard output.

    Args:
        cmd: Shell command line to execute.
        input: Optional text to feed to the command's stdin. A falsy value
            (``None`` or ``''``) sends nothing.
        env: Optional environment mapping handed to the subprocess; ``None``
            is passed straight through to asyncio.

    Returns:
        The command's stdout, decoded and stripped of surrounding whitespace.

    Raises:
        RuntimeError: If communicating with the process fails (the process is
            killed first and its output logged), or if the command exits with
            a non-zero status (the message is the stripped stderr).
    """
    proc = await asyncio.create_subprocess_shell(
        cmd, env=env,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE)
    inbytes = None

    if input:
        inbytes = input.encode()

    try:
        out, eout = await proc.communicate(input=inbytes)
    except BaseException as exc:
        # Deliberately catch-all (same reach as the previous bare ``except``):
        # if the awaiting task is cancelled, e.g. by a caller's timeout, the
        # child process must still be killed rather than left running.
        self.log.debug("Exception raised when trying to run command: %s" % cmd)
        proc.kill()
        self.log.debug("Running command failed done kill")
        # Reap the killed process and collect whatever output it produced.
        out, eout = await proc.communicate()
        out = out.decode().strip()
        eout = eout.decode().strip()
        self.log.error("Subprocess returned exitcode %s" % proc.returncode)
        self.log.error('Stdout:')
        self.log.error(out)
        self.log.error('Stderr:')
        self.log.error(eout)
        raise RuntimeError('{} exit status {}: {}'.format(cmd, proc.returncode, eout)) from exc
    else:
        eout = eout.decode().strip()
        err = proc.returncode
        if err != 0:
            self.log.error("Subprocess returned exitcode %s" % err)
            self.log.error(eout)
            raise RuntimeError(eout)

    out = out.decode().strip()
    return out
222+
223+
async def _get_batch_script(self, **subvars):
    """Render ``self.batch_script`` with the given substitution variables.

    Kept as its own coroutine so subclasses can override it, though that is
    mainly useful for testing.
    """
    template = self.batch_script
    rendered = format_template(template, **subvars)
    return rendered
223227

224-
@gen.coroutine
225-
def submit_batch_script(self):
228+
async def submit_batch_script(self):
226229
subvars = self.get_req_subvars()
227230
cmd = ' '.join((format_template(self.exec_prefix, **subvars),
228231
format_template(self.batch_submit_cmd, **subvars)))
229232
subvars['cmd'] = self.cmd_formatted_for_batch()
230233
if hasattr(self, 'user_options'):
231234
subvars.update(self.user_options)
232-
script = yield self._get_batch_script(**subvars)
235+
script = await self._get_batch_script(**subvars)
233236
self.log.info('Spawner submitting job using ' + cmd)
234237
self.log.info('Spawner submitted script:\n' + script)
235-
out = yield self.run_command(cmd, input=script, env=self.get_env())
238+
out = await self.run_command(cmd, input=script, env=self.get_env())
236239
try:
237240
self.log.info('Job submitted. cmd: ' + cmd + ' output: ' + out)
238241
self.job_id = self.parse_job_id(out)
@@ -247,8 +250,7 @@ def submit_batch_script(self):
247250
"and self.job_id as {job_id}."
248251
).tag(config=True)
249252

250-
@gen.coroutine
251-
def read_job_state(self):
253+
async def read_job_state(self):
252254
if self.job_id is None or len(self.job_id) == 0:
253255
# job not running
254256
self.job_status = ''
@@ -259,7 +261,7 @@ def read_job_state(self):
259261
format_template(self.batch_query_cmd, **subvars)))
260262
self.log.debug('Spawner querying job: ' + cmd)
261263
try:
262-
out = yield self.run_command(cmd, env=self.get_env())
264+
out = await self.run_command(cmd)
263265
self.job_status = out
264266
except Exception as e:
265267
self.log.error('Error querying job ' + self.job_id)
@@ -271,14 +273,13 @@ def read_job_state(self):
271273
help="Command to stop/cancel a previously submitted job. Formatted like batch_query_cmd."
272274
).tag(config=True)
273275

274-
async def cancel_batch_job(self):
    """Run the configured cancel command for the current job.

    The command line is ``exec_prefix`` followed by ``batch_cancel_cmd``,
    each formatted with the request subvars plus ``job_id``.

    Raises:
        RuntimeError: Propagated from ``run_command`` when the command fails
            or exits with a non-zero status.
    """
    subvars = self.get_req_subvars()
    subvars['job_id'] = self.job_id
    cmd = ' '.join((format_template(self.exec_prefix, **subvars),
                    format_template(self.batch_cancel_cmd, **subvars)))
    self.log.info('Cancelling job ' + self.job_id + ': ' + cmd)
    # Run with the spawner environment, as the job submission command does;
    # the env argument was lost here when the coroutine was converted to
    # async/await.
    await self.run_command(cmd, env=self.get_env())
282283

283284
def load_state(self, state):
284285
"""load job_id from state"""
@@ -317,11 +318,10 @@ def state_gethost(self):
317318
"Return string, hostname or addr of running job, likely by parsing self.job_status"
318319
raise NotImplementedError("Subclass must provide implementation")
319320

320-
@gen.coroutine
321-
def poll(self):
321+
async def poll(self):
322322
"""Poll the process"""
323323
if self.job_id is not None and len(self.job_id) > 0:
324-
yield self.read_job_state()
324+
await self.read_job_state()
325325
if self.state_isrunning() or self.state_ispending():
326326
return None
327327
else:
@@ -337,16 +337,15 @@ def poll(self):
337337
help="Polling interval (seconds) to check job state during startup"
338338
).tag(config=True)
339339

340-
@gen.coroutine
341-
def start(self):
340+
async def start(self):
342341
"""Start the process"""
343342
self.ip = self.traits()['ip'].default_value
344343
self.port = self.traits()['port'].default_value
345344

346345
if jupyterhub.version_info >= (0,8) and self.server:
347346
self.server.port = self.port
348347

349-
job = yield self.submit_batch_script()
348+
job = await self.submit_batch_script()
350349

351350
# We are called with a timeout, and if the timeout expires this function will
352351
# be interrupted at the next await, and self.stop() will be called.
@@ -355,7 +354,7 @@ def start(self):
355354
if len(self.job_id) == 0:
356355
raise RuntimeError("Jupyter batch job submission failure (no jobid in output)")
357356
while True:
358-
yield self.poll()
357+
await self.poll()
359358
if self.state_isrunning():
360359
break
361360
else:
@@ -367,11 +366,11 @@ def start(self):
367366
raise RuntimeError('The Jupyter batch job has disappeared'
368367
' while pending in the queue or died immediately'
369368
' after starting.')
370-
yield gen.sleep(self.startup_poll_interval)
369+
await gen.sleep(self.startup_poll_interval)
371370

372371
self.ip = self.state_gethost()
373372
while self.port == 0:
374-
yield gen.sleep(self.startup_poll_interval)
373+
await gen.sleep(self.startup_poll_interval)
375374
# Test framework: For testing, mock_port is set because we
376375
# don't actually run the single-user server yet.
377376
if hasattr(self, 'mock_port'):
@@ -388,22 +387,21 @@ def start(self):
388387

389388
return self.ip, self.port
390389

391-
@gen.coroutine
392-
def stop(self, now=False):
390+
async def stop(self, now=False):
393391
"""Stop the singleuser server job.
394392
395393
Returns immediately after sending job cancellation command if now=True, otherwise
396394
tries to confirm that job is no longer running."""
397395

398396
self.log.info("Stopping server job " + self.job_id)
399-
yield self.cancel_batch_job()
397+
await self.cancel_batch_job()
400398
if now:
401399
return
402400
for i in range(10):
403-
yield self.poll()
401+
await self.poll()
404402
if not self.state_isrunning():
405403
return
406-
yield gen.sleep(1.0)
404+
await gen.sleep(1.0)
407405
if self.job_id:
408406
self.log.warn("Notebook server job {0} at {1}:{2} possibly failed to terminate".format(
409407
self.job_id, self.ip, self.port)

0 commit comments

Comments
 (0)