Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions dk-installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,14 +723,14 @@ def start_cmd(self, *cmd, raise_on_non_zero=True, env=None, **popen_args):
]

try:
with (
stream_iterator(proc, "stdout", stdout_path) as stdout_iter,
stream_iterator(proc, "stderr", stderr_path) as stderr_iter,
):
try:
try:
with (
stream_iterator(proc, "stdout", stdout_path) as stdout_iter,
stream_iterator(proc, "stderr", stderr_path) as stderr_iter,
):
yield proc, stdout_iter, stderr_iter
finally:
proc.wait()
finally:
proc.wait()
if raise_on_non_zero and proc.returncode != 0:
raise CommandFailed
# We capture and raise CommandFailed to allow the client code to raise an empty CommandFailed exception
Expand Down
19 changes: 10 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@
from tests.installer import CONSOLE, Action, TESTGEN_DEFAULT_IMAGE


@pytest.fixture
def proc_mock():
proc = Mock()
proc.returncode = 0
proc.wait.return_value = None
proc.poll.return_value = 0
return proc


@pytest.fixture
def stdout_mock():
return Mock(return_value=[])
Expand All @@ -29,6 +20,16 @@ def stderr_mock():
return Mock(return_value=[])


@pytest.fixture
def proc_mock(stdout_mock, stderr_mock):
proc = Mock()
proc.returncode = 0
proc.wait.return_value = None
proc.poll.return_value = 0
proc.communicate.return_value = ["\n".join(stdout_mock()).encode(), "\n".join(stderr_mock()).encode()]
return proc


@pytest.fixture
def console_msg_mock():
with patch.object(CONSOLE, "msg") as mock:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,26 @@ def test_start_cmd_wait_on_exception(action, popen_mock, stream_iter_mock):

assert proc_mock.wait.call_count == 1
assert exc_info.value.args == ("something went wrong",)


@pytest.mark.unit
@pytest.mark.parametrize("raise_", (True, False))
def test_start_cmd_calls_wait_after_streams_consumed(raise_, action, popen_mock):
try:
with action.start_cmd("cmd", "arg") as (proc_mock, _, _):
proc_mock.returncode = 0
if raise_:
raise RuntimeError()
except RuntimeError:
did_raise = True
else:
did_raise = False

assert did_raise == raise_
proc_mock.assert_has_calls(
[
call.communicate(timeout=1.0),
call.communicate(timeout=1.0),
call.wait(),
]
)
15 changes: 11 additions & 4 deletions tests/test_stream_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,28 @@
def popen_stdout_buffer(popen_mock):
buffer = "\n".join(["🔷🔶🔺🔻"[i % 4] + " xxxx" * 20 for i in range(100)]).encode()
popen_mock.communicate.side_effect = [
*[subprocess.TimeoutExpired("cmd", 1, output=buffer[:idx]) for idx in range(0, len(buffer), 38)],
*[subprocess.TimeoutExpired("cmd", 1, output=buffer[:idx], stderr=b"") for idx in range(0, len(buffer), 38)],
(buffer, b""),
(buffer, b""),
]
return buffer


@pytest.mark.unit
def test_stream_iterator(popen_mock, popen_stdout_buffer, tmp_logs_folder):
cmd_log_path = pathlib.Path(tmp_logs_folder) / "cmd-log.txt"
cmd_log_stdout_path = pathlib.Path(tmp_logs_folder) / "cmd-log-out.txt"
cmd_log_stderr_path = pathlib.Path(tmp_logs_folder) / "cmd-log-err.txt"

with stream_iterator(popen_mock, "stdout", cmd_log_path) as stdout_iter:
with (
stream_iterator(popen_mock, "stdout", cmd_log_stdout_path) as stdout_iter,
stream_iterator(popen_mock, "stderr", cmd_log_stderr_path) as stderr_iter,
):
for stdout_line, buffer_line in itertools.zip_longest(stdout_iter, popen_stdout_buffer.splitlines()):
assert stdout_line == buffer_line.decode()
assert list(stderr_iter) == []

assert cmd_log_path.read_bytes() == popen_stdout_buffer
assert cmd_log_stdout_path.read_bytes() == popen_stdout_buffer
assert not cmd_log_stderr_path.exists()


@pytest.mark.unit
Expand Down