Skip to content

Commit 7cba1dc

Browse files
committed
implement initial support for running shell commands asynchronously using run_shell_cmd
1 parent 8aaaec2 commit 7cba1dc

File tree

3 files changed

+107
-32
lines changed

3 files changed

+107
-32
lines changed

easybuild/tools/run.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import subprocess
4646
import sys
4747
import tempfile
48+
import threading
4849
import time
4950
from collections import namedtuple
5051
from datetime import datetime
@@ -79,7 +80,7 @@
7980

8081

8182
RunShellCmdResult = namedtuple('RunShellCmdResult', ('cmd', 'exit_code', 'output', 'stderr', 'work_dir',
82-
'out_file', 'err_file'))
83+
'out_file', 'err_file', 'thread_id'))
8384

8485

8586
class RunShellCmdError(BaseException):
@@ -199,7 +200,7 @@ def run_shell_cmd(cmd, fail_on_error=True, split_stderr=False, stdin=None, env=N
199200
:param use_bash: execute command through bash shell (enabled by default)
200201
:param output_file: collect command output in temporary output file
201202
:param stream_output: stream command output to stdout (auto-enabled with --logtostdout if None)
202-
:param asynchronous: run command asynchronously
203+
:param asynchronous: indicate that command is being run asynchronously
203204
:param with_hooks: trigger pre/post run_shell_cmd hooks (if defined)
204205
:param qa_patterns: list of 2-tuples with patterns for questions + corresponding answers
205206
:param qa_wait_patterns: list of 2-tuples with patterns for non-questions
@@ -223,9 +224,6 @@ def to_cmd_str(cmd):
223224
return cmd_str
224225

225226
# temporarily raise a NotImplementedError until all options are implemented
226-
if asynchronous:
227-
raise NotImplementedError
228-
229227
if qa_patterns or qa_wait_patterns:
230228
raise NotImplementedError
231229

@@ -235,6 +233,11 @@ def to_cmd_str(cmd):
235233
cmd_str = to_cmd_str(cmd)
236234
cmd_name = os.path.basename(cmd_str.split(' ')[0])
237235

236+
thread_id = None
237+
if asynchronous:
238+
thread_id = threading.get_native_id()
239+
_log.info(f"Initiating running of shell command '{cmd_str}' via thread with ID {thread_id}")
240+
238241
# auto-enable streaming of command output under --logtostdout/-l, unless it was disabled explicitely
239242
if stream_output is None and build_option('logtostdout'):
240243
_log.info(f"Auto-enabling streaming output of '{cmd_str}' command because logging to stdout is enabled")
@@ -259,16 +262,16 @@ def to_cmd_str(cmd):
259262
if not in_dry_run and build_option('extended_dry_run'):
260263
if not hidden or verbose_dry_run:
261264
silent = build_option('silent')
262-
msg = f" running command \"{cmd_str}\"\n"
265+
msg = f" running shell command \"{cmd_str}\"\n"
263266
msg += f" (in {work_dir})"
264267
dry_run_msg(msg, silent=silent)
265268

266269
return RunShellCmdResult(cmd=cmd_str, exit_code=0, output='', stderr=None, work_dir=work_dir,
267-
out_file=cmd_out_fp, err_file=cmd_err_fp)
270+
out_file=cmd_out_fp, err_file=cmd_err_fp, thread_id=thread_id)
268271

269272
start_time = datetime.now()
270273
if not hidden:
271-
cmd_trace_msg(cmd_str, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp)
274+
_cmd_trace_msg(cmd_str, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thread_id)
272275

273276
if stream_output:
274277
print_msg(f"(streaming) output for command '{cmd_str}':")
@@ -293,7 +296,11 @@ def to_cmd_str(cmd):
293296

294297
stderr = subprocess.PIPE if split_stderr else subprocess.STDOUT
295298

296-
_log.info(f"Running command '{cmd_str}' in {work_dir}")
299+
log_msg = f"Running shell command '{cmd_str}' in {work_dir}"
300+
if thread_id:
301+
log_msg += f" (via thread with ID {thread_id})"
302+
_log.info(log_msg)
303+
297304
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=stderr, stdin=subprocess.PIPE,
298305
cwd=work_dir, env=env, shell=shell, executable=executable)
299306

@@ -337,7 +344,7 @@ def to_cmd_str(cmd):
337344
raise EasyBuildError(f"Failed to dump command output to temporary file: {err}")
338345

339346
res = RunShellCmdResult(cmd=cmd_str, exit_code=proc.returncode, output=output, stderr=stderr, work_dir=work_dir,
340-
out_file=cmd_out_fp, err_file=cmd_err_fp)
347+
out_file=cmd_out_fp, err_file=cmd_err_fp, thread_id=thread_id)
341348

342349
# always log command output
343350
cmd_name = cmd_str.split(' ')[0]
@@ -370,7 +377,7 @@ def to_cmd_str(cmd):
370377
return res
371378

372379

373-
def cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp):
380+
def _cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thread_id):
374381
"""
375382
Helper function to construct and print trace message for command being run
376383
@@ -380,11 +387,18 @@ def cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp):
380387
:param stdin: stdin input value for command
381388
:param cmd_out_fp: path to output file for command
382389
:param cmd_err_fp: path to errors/warnings output file for command
390+
:param thread_id: thread ID (None when not running shell command asynchronously)
383391
"""
384392
start_time = start_time.strftime('%Y-%m-%d %H:%M:%S')
385393

394+
if thread_id:
395+
run_cmd_msg = f"running shell command (asynchronously, thread ID: {thread_id}):"
396+
else:
397+
run_cmd_msg = "running shell command:"
398+
386399
lines = [
387-
"running command:",
400+
run_cmd_msg,
401+
f"\t{cmd}",
388402
f"\t[started at: {start_time}]",
389403
f"\t[working dir: {work_dir}]",
390404
]
@@ -395,8 +409,6 @@ def cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp):
395409
if cmd_err_fp:
396410
lines.append(f"\t[errors/warnings saved to {cmd_err_fp}]")
397411

398-
lines.append('\t' + cmd)
399-
400412
trace_msg('\n'.join(lines))
401413

402414

test/framework/run.py

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import tempfile
4242
import textwrap
4343
import time
44+
from concurrent.futures import ThreadPoolExecutor
4445
from test.framework.utilities import EnhancedTestCase, TestLoaderFiltered, init_config
4546
from unittest import TextTestRunner
4647
from easybuild.base.fancylogger import setLogLevelDebug
@@ -248,7 +249,7 @@ def test_run_shell_cmd_log(self):
248249
fd, logfile = tempfile.mkstemp(suffix='.log', prefix='eb-test-')
249250
os.close(fd)
250251

251-
regex_start_cmd = re.compile("Running command 'echo hello' in /")
252+
regex_start_cmd = re.compile("Running shell command 'echo hello' in /")
252253
regex_cmd_exit = re.compile(r"Shell command completed successfully \(see output above\): echo hello")
253254

254255
# command output is always logged
@@ -448,7 +449,7 @@ def test_run_cmd_work_dir(self):
448449

449450
def test_run_shell_cmd_work_dir(self):
450451
"""
451-
Test running command in specific directory with run_shell_cmd function.
452+
Test running shell command in specific directory with run_shell_cmd function.
452453
"""
453454
orig_wd = os.getcwd()
454455
self.assertFalse(os.path.samefile(orig_wd, self.test_prefix))
@@ -615,11 +616,11 @@ def test_run_shell_cmd_trace(self):
615616
"""Test run_shell_cmd function in trace mode, and with tracing disabled."""
616617

617618
pattern = [
618-
r"^ >> running command:",
619+
r"^ >> running shell command:",
620+
r"\techo hello",
619621
r"\t\[started at: .*\]",
620622
r"\t\[working dir: .*\]",
621623
r"\t\[output saved to .*\]",
622-
r"\techo hello",
623624
r" >> command completed: exit 0, ran in .*",
624625
]
625626

@@ -675,11 +676,11 @@ def test_run_shell_cmd_trace_stdin(self):
675676
init_config(build_options={'trace': True})
676677

677678
pattern = [
678-
r"^ >> running command:",
679+
r"^ >> running shell command:",
680+
r"\techo hello",
679681
r"\t\[started at: [0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9] [0-9][0-9]:[0-9][0-9]:[0-9][0-9]\]",
680682
r"\t\[working dir: .*\]",
681683
r"\t\[output saved to .*\]",
682-
r"\techo hello",
683684
r" >> command completed: exit 0, ran in .*",
684685
]
685686

@@ -707,8 +708,8 @@ def test_run_shell_cmd_trace_stdin(self):
707708
self.assertEqual(res.output, 'hello')
708709
self.assertEqual(res.exit_code, 0)
709710
self.assertEqual(stderr, '')
710-
pattern.insert(3, r"\t\[input: hello\]")
711-
pattern[-2] = "\tcat"
711+
pattern.insert(4, r"\t\[input: hello\]")
712+
pattern[1] = "\tcat"
712713
regex = re.compile('\n'.join(pattern))
713714
self.assertTrue(regex.search(stdout), "Pattern '%s' found in: %s" % (regex.pattern, stdout))
714715

@@ -909,7 +910,8 @@ def test_run_shell_cmd_cache(self):
909910
# inject value into cache to check whether executing command again really returns cached value
910911
with self.mocked_stdout_stderr():
911912
cached_res = RunShellCmdResult(cmd=cmd, output="123456", exit_code=123, stderr=None,
912-
work_dir='/test_ulimit', out_file='/tmp/foo.out', err_file=None)
913+
work_dir='/test_ulimit', out_file='/tmp/foo.out', err_file=None,
914+
thread_id=None)
913915
run_shell_cmd.update_cache({(cmd, None): cached_res})
914916
res = run_shell_cmd(cmd)
915917
self.assertEqual(res.cmd, cmd)
@@ -928,7 +930,8 @@ def test_run_shell_cmd_cache(self):
928930
# inject different output for cat with 'foo' as stdin to check whether cached value is used
929931
with self.mocked_stdout_stderr():
930932
cached_res = RunShellCmdResult(cmd=cmd, output="bar", exit_code=123, stderr=None,
931-
work_dir='/test_cat', out_file='/tmp/cat.out', err_file=None)
933+
work_dir='/test_cat', out_file='/tmp/cat.out', err_file=None,
934+
thread_id=None)
932935
run_shell_cmd.update_cache({(cmd, 'foo'): cached_res})
933936
res = run_shell_cmd(cmd, stdin='foo')
934937
self.assertEqual(res.cmd, cmd)
@@ -1006,7 +1009,7 @@ def test_run_shell_cmd_dry_run(self):
10061009
self.assertEqual(res.output, '')
10071010
self.assertEqual(res.stderr, None)
10081011
# check dry run output
1009-
expected = """ running command "somecommand foo 123 bar"\n"""
1012+
expected = """ running shell command "somecommand foo 123 bar"\n"""
10101013
self.assertIn(expected, stdout)
10111014

10121015
# check enabling 'hidden'
@@ -1029,7 +1032,7 @@ def test_run_shell_cmd_dry_run(self):
10291032
fail_on_error=False, in_dry_run=True)
10301033
stdout = self.get_stdout()
10311034
self.mock_stdout(False)
1032-
self.assertNotIn('running command "', stdout)
1035+
self.assertNotIn('running shell command "', stdout)
10331036
self.assertNotEqual(res.exit_code, 0)
10341037
self.assertEqual(res.output, 'done\n')
10351038
self.assertEqual(res.stderr, None)
@@ -1207,7 +1210,7 @@ def test_run_cmd_async(self):
12071210
"for i in $(seq 1 50)",
12081211
"do sleep 0.1",
12091212
"for j in $(seq 1000)",
1210-
"do echo foo",
1213+
"do echo foo${i}${j}",
12111214
"done",
12121215
"done",
12131216
"echo done",
@@ -1257,8 +1260,68 @@ def test_run_cmd_async(self):
12571260
res = check_async_cmd(*cmd_info, output=res['output'])
12581261
self.assertEqual(res['done'], True)
12591262
self.assertEqual(res['exit_code'], 0)
1260-
self.assertTrue(res['output'].startswith('start\n'))
1261-
self.assertTrue(res['output'].endswith('\ndone\n'))
1263+
self.assertEqual(len(res['output']), 435661)
1264+
self.assertTrue(res['output'].startswith('start\nfoo11\nfoo12\n'))
1265+
self.assertTrue('\nfoo49999\nfoo491000\nfoo501\n' in res['output'])
1266+
self.assertTrue(res['output'].endswith('\nfoo501000\ndone\n'))
1267+
1268+
def test_run_shell_cmd_async(self):
1269+
"""Test asynchronously running of a shell command via run_shell_cmd """
1270+
1271+
thread_pool = ThreadPoolExecutor()
1272+
1273+
os.environ['TEST'] = 'test123'
1274+
env = os.environ.copy()
1275+
1276+
test_cmd = "echo 'sleeping...'; sleep 2; echo $TEST"
1277+
task = thread_pool.submit(run_shell_cmd, test_cmd, hidden=True, asynchronous=True, env=env)
1278+
1279+
# change value of $TEST to check that command is completed with correct environment
1280+
os.environ['TEST'] = 'some_other_value'
1281+
1282+
# initial poll should result in None, since it takes a while for the command to complete
1283+
self.assertEqual(task.done(), False)
1284+
1285+
# wait until command is done
1286+
while not task.done():
1287+
time.sleep(1)
1288+
res = task.result()
1289+
1290+
self.assertEqual(res.exit_code, 0)
1291+
self.assertEqual(res.output, 'sleeping...\ntest123\n')
1292+
1293+
# check asynchronous running of failing command
1294+
error_test_cmd = "echo 'FAIL!' >&2; exit 123"
1295+
task = thread_pool.submit(run_shell_cmd, error_test_cmd, hidden=True, fail_on_error=False, asynchronous=True)
1296+
time.sleep(1)
1297+
res = task.result()
1298+
self.assertEqual(res.exit_code, 123)
1299+
self.assertEqual(res.output, "FAIL!\n")
1300+
self.assertTrue(res.thread_id)
1301+
1302+
# also test with a command that produces a lot of output,
1303+
# since that tends to lock up things unless we frequently grab some output...
1304+
verbose_test_cmd = ';'.join([
1305+
"echo start",
1306+
"for i in $(seq 1 50)",
1307+
"do sleep 0.1",
1308+
"for j in $(seq 1000)",
1309+
"do echo foo${i}${j}",
1310+
"done",
1311+
"done",
1312+
"echo done",
1313+
])
1314+
task = thread_pool.submit(run_shell_cmd, verbose_test_cmd, hidden=True, asynchronous=True)
1315+
1316+
while not task.done():
1317+
time.sleep(1)
1318+
res = task.result()
1319+
1320+
self.assertEqual(res.exit_code, 0)
1321+
self.assertEqual(len(res.output), 435661)
1322+
self.assertTrue(res.output.startswith('start\nfoo11\nfoo12\n'))
1323+
self.assertTrue('\nfoo49999\nfoo491000\nfoo501\n' in res.output)
1324+
self.assertTrue(res.output.endswith('\nfoo501000\ndone\n'))
12621325

12631326
def test_check_log_for_errors(self):
12641327
fd, logfile = tempfile.mkstemp(suffix='.log', prefix='eb-test-')
@@ -1373,7 +1436,7 @@ def post_run_shell_cmd_hook(cmd, *args, **kwargs):
13731436

13741437
def test_run_shell_cmd_with_hooks(self):
13751438
"""
1376-
Test running command with run_shell_cmd function with pre/post run_shell_cmd hooks in place.
1439+
Test running shell command with run_shell_cmd function with pre/post run_shell_cmd hooks in place.
13771440
"""
13781441
cwd = os.getcwd()
13791442

test/framework/systemtools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def mocked_run_shell_cmd(cmd, **kwargs):
341341
}
342342
if cmd in known_cmds:
343343
return RunShellCmdResult(cmd=cmd, exit_code=0, output=known_cmds[cmd], stderr=None, work_dir=os.getcwd(),
344-
out_file=None, err_file=None)
344+
out_file=None, err_file=None, thread_id=None)
345345
else:
346346
return run_shell_cmd(cmd, **kwargs)
347347

@@ -774,7 +774,7 @@ def test_gcc_version_darwin(self):
774774
out = "Apple LLVM version 7.0.0 (clang-700.1.76)"
775775
cwd = os.getcwd()
776776
mocked_run_res = RunShellCmdResult(cmd="gcc --version", exit_code=0, output=out, stderr=None, work_dir=cwd,
777-
out_file=None, err_file=None)
777+
out_file=None, err_file=None, thread_id=None)
778778
st.run_shell_cmd = lambda *args, **kwargs: mocked_run_res
779779
self.assertEqual(get_gcc_version(), None)
780780

0 commit comments

Comments
 (0)