Skip to content

Commit 3b2517f

Browse files
authored
refact: cleanup bats helper (#3636)
1 parent 981d282 commit 3b2517f

File tree

1 file changed

+41
-23
lines changed

1 file changed

+41
-23
lines changed

test/bin/wait-for

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,31 @@ DEFAULT_TIMEOUT = 30
1515
# TODO: print unmatched patterns
1616

1717

18-
async def terminate(p):
19-
# Terminate the process group (shell, crowdsec plugins)
18+
async def terminate_group(p: asyncio.subprocess.Process):
19+
"""
20+
Terminate the process group (shell, crowdsec plugins)
21+
"""
2022
try:
2123
os.killpg(os.getpgid(p.pid), signal.SIGTERM)
2224
except ProcessLookupError:
2325
pass
2426

2527

26-
async def monitor(cmd, args, want_out, want_err, timeout):
27-
"""Monitor a process and terminate it if a pattern is matched in stdout or stderr.
28+
async def monitor(
29+
cmd: str,
30+
args: list[str],
31+
out_regex: re.Pattern[str] | None,
32+
err_regex: re.Pattern[str] | None,
33+
timeout: float
34+
) -> int:
35+
"""
36+
Run a subprocess, monitor its stdout/stderr for matches, and handle timeouts or pattern hits.
2837
2938
Args:
3039
cmd: The command to run.
3140
args: A list of arguments to pass to the command.
32-
stdout: A regular expression pattern to search for in stdout.
33-
stderr: A regular expression pattern to search for in stderr.
41+
out_regex: A compiled regular expression to search for in stdout.
42+
err_regex: A compiled regular expression to search for in stderr.
3443
timeout: The maximum number of seconds to wait for the process to terminate.
3544
3645
Returns:
@@ -39,17 +48,18 @@ async def monitor(cmd, args, want_out, want_err, timeout):
3948

4049
status = None
4150

42-
async def read_stream(stream, outstream, pattern):
51+
async def read_stream(stream: asyncio.StreamReader | None, out, pattern: re.Pattern[str] | None):
4352
nonlocal status
4453
if stream is None:
4554
return
55+
4656
while True:
4757
line = await stream.readline()
4858
if line:
4959
line = line.decode('utf-8')
50-
outstream.write(line)
60+
out.write(line)
5161
if pattern and pattern.search(line):
52-
await terminate(process)
62+
await terminate_group(process)
5363
# this is nasty.
5464
# if we timeout, we want to return a different exit code
5565
# in case of a match, so that the caller can tell
@@ -76,9 +86,6 @@ async def monitor(cmd, args, want_out, want_err, timeout):
7686
# (required to kill child processes when cmd is a shell)
7787
preexec_fn=os.setsid)
7888

79-
out_regex = re.compile(want_out) if want_out else None
80-
err_regex = re.compile(want_err) if want_err else None
81-
8289
# Apply a timeout
8390
try:
8491
await asyncio.wait_for(
@@ -90,27 +97,38 @@ async def monitor(cmd, args, want_out, want_err, timeout):
9097
if status is None:
9198
status = process.returncode
9299
except asyncio.TimeoutError:
93-
await terminate(process)
100+
await terminate_group(process)
94101
status = 241
95102

96103
# Return the same exit code, stdout and stderr as the spawned process
97-
return status
104+
return status or 0
105+
106+
107+
class Args(argparse.Namespace):
108+
cmd: str = ''
109+
args: list[str] = []
110+
out: str = ''
111+
err: str = ''
112+
timeout: float = DEFAULT_TIMEOUT
98113

99114

100115
async def main():
101116
parser = argparse.ArgumentParser(
102117
description='Monitor a process and terminate it if a pattern is matched in stdout or stderr.')
103-
parser.add_argument('cmd', help='The command to run.')
104-
parser.add_argument('args', nargs=argparse.REMAINDER, help='A list of arguments to pass to the command.')
105-
parser.add_argument('--out', default='', help='A regular expression pattern to search for in stdout.')
106-
parser.add_argument('--err', default='', help='A regular expression pattern to search for in stderr.')
107-
parser.add_argument('--timeout', type=float, default=DEFAULT_TIMEOUT)
108-
args = parser.parse_args()
118+
_ = parser.add_argument('cmd', help='The command to run.')
119+
_ = parser.add_argument('args', nargs=argparse.REMAINDER, help='A list of arguments to pass to the command.')
120+
_ = parser.add_argument('--out', help='A regular expression pattern to search for in stdout.')
121+
_ = parser.add_argument('--err', help='A regular expression pattern to search for in stderr.')
122+
_ = parser.add_argument('--timeout', type=float, default=DEFAULT_TIMEOUT)
123+
args: Args = parser.parse_args(namespace=Args())
124+
125+
out_regex = re.compile(args.out) if args.out else None
126+
err_regex = re.compile(args.err) if args.err else None
109127

110-
exit_code = await monitor(args.cmd, args.args, args.out, args.err, args.timeout)
128+
exit_code = await monitor(args.cmd, args.args, out_regex, err_regex, args.timeout)
111129

112-
sys.exit(exit_code)
130+
return exit_code
113131

114132

115133
if __name__ == '__main__':
116-
asyncio.run(main())
134+
sys.exit(asyncio.run(main()))

0 commit comments

Comments
 (0)