13
13
14
14
from .base import (SGELikeBatchManagerBase , logger , iflogger , logging )
15
15
16
- from nipype .interfaces .base import CommandLine
16
+ from .. .interfaces .base import CommandLine
17
17
18
18
19
19
class SLURMPlugin (SGELikeBatchManagerBase ):
@@ -38,12 +38,15 @@ def __init__(self, **kwargs):
38
38
self ._max_tries = 2
39
39
self ._template = template
40
40
self ._sbatch_args = None
41
+ self ._jobid_re = "Submitted batch job ([0-9]*)"
41
42
42
43
if 'plugin_args' in kwargs and kwargs ['plugin_args' ]:
43
44
if 'retry_timeout' in kwargs ['plugin_args' ]:
44
45
self ._retry_timeout = kwargs ['plugin_args' ]['retry_timeout' ]
45
46
if 'max_tries' in kwargs ['plugin_args' ]:
46
47
self ._max_tries = kwargs ['plugin_args' ]['max_tries' ]
48
+ if 'jobid_re' in kwargs ['plugin_args' ]:
49
+ self ._jobid_re = kwargs ['plugin_args' ]['jobid_re' ]
47
50
if 'template' in kwargs ['plugin_args' ]:
48
51
self ._template = kwargs ['plugin_args' ]['template' ]
49
52
if os .path .isfile (self ._template ):
@@ -55,17 +58,16 @@ def __init__(self, **kwargs):
55
58
56
59
def _is_pending (self , taskid ):
57
60
# subprocess.Popen requires taskid to be a string
58
- proc = subprocess .Popen (["squeue" , '-j' , '%s' % taskid ],
59
- stdout = subprocess .PIPE ,
60
- stderr = subprocess .PIPE )
61
- o , _ = proc .communicate ()
62
-
63
- return o .find (str (taskid )) > - 1
61
+ res = CommandLine ('squeue' ,
62
+ args = ' ' .join (['-j' , '%s' % taskid ]),
63
+ terminal_output = 'allatonce' ).run ()
64
+ return res .runtime .stdout .find (str (taskid )) > - 1
64
65
65
66
def _submit_batchtask (self , scriptfile , node ):
66
67
"""
67
- This is more or less the _submit_batchtask from sge.py with flipped variable
68
- names, different command line switches, and different output formatting/processing
68
+ This is more or less the _submit_batchtask from sge.py with flipped
69
+ variable names, different command line switches, and different output
70
+ formatting/processing
69
71
"""
70
72
cmd = CommandLine ('sbatch' , environ = dict (os .environ ),
71
73
terminal_output = 'allatonce' )
@@ -118,7 +120,7 @@ def _submit_batchtask(self, scriptfile, node):
118
120
iflogger .setLevel (oldlevel )
119
121
# retrieve taskid
120
122
lines = [line for line in result .runtime .stdout .split ('\n ' ) if line ]
121
- taskid = int (re .match ("Submitted batch job ([0-9]*)" ,
123
+ taskid = int (re .match (self . _jobid_re ,
122
124
lines [- 1 ]).groups ()[0 ])
123
125
self ._pending [taskid ] = node .output_dir ()
124
126
logger .debug ('submitted sbatch task: %d for node %s' % (taskid , node ._id ))
0 commit comments