Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 23 additions & 32 deletions src/cockpit/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import struct
import subprocess
import termios
from threading import Thread
from typing import Any, ClassVar, Sequence

from .jsonutil import JsonObject, get_int
Expand All @@ -33,6 +34,8 @@ def prctl(*args: int) -> None:


logger = logging.getLogger(__name__)


IOV_MAX = 1024 # man 2 writev


Expand Down Expand Up @@ -305,50 +308,38 @@ def get_stderr(self, *, reset: bool = False) -> str:
return ''

def watch_exit(self, process: 'subprocess.Popen[bytes]') -> None:
def flag_exit() -> None:
def child_exited(pid: int, status: int) -> None:
assert pid == process.pid
# os.waitstatus_to_exitcode() is only available since Python 3.9
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That could become a polyfill, for easier future clean-up.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. This would fit nicely in polyfills.py with the rest of them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if os.WIFEXITED(status):
self._returncode = os.WEXITSTATUS(status)
elif os.WIFSIGNALED(status):
self._returncode = -os.WTERMSIG(status)
else:
self._returncode = status

assert isinstance(self._protocol, SubprocessProtocol)
logger.debug('Process exited with status %d', self._returncode)
if not self._closing:
self._protocol.process_exited()

def pidfd_ready() -> None:
pid, status = os.waitpid(process.pid, 0)
assert pid == process.pid
try:
self._returncode = os.waitstatus_to_exitcode(status)
except ValueError:
self._returncode = status
pid, status = os.waitpid(process.pid, 0) # should never block
self._loop.remove_reader(pidfd)
os.close(pidfd)
flag_exit()

def child_watch_fired(pid: int, code: int) -> None:
assert process.pid == pid
self._returncode = code
flag_exit()

# We first try to create a pidfd to track the process manually. If
# that does work, we need to create a SafeChildWatcher, which has been
# deprecated and removed in Python 3.14. This effectively means that
# using Python 3.14 requires that we're running on a kernel with pidfd
# support, which is fine: the only place we still care about such old
# kernels is on RHEL8 and we have Python 3.6 there.
child_exited(pid, status)

def waitpid_thread() -> None:
pid, status = os.waitpid(process.pid, 0) # will block
self._loop.call_soon_threadsafe(child_exited, pid, status)

# We first try to create a pidfd to track the process. If that doesn't
# work, we spawn a thread to do a blocking waitpid().
try:
pidfd = os.pidfd_open(process.pid)
self._loop.add_reader(pidfd, pidfd_ready)
except (AttributeError, OSError):
quark = '_cockpit_transports_child_watcher'
watcher = getattr(self._loop, quark, None)

if watcher is None:
try:
watcher = asyncio.SafeChildWatcher() # type: ignore[attr-defined]
except AttributeError as e:
raise RuntimeError('pidfd support required on Python 3.14+') from e
watcher.attach_loop(self._loop)
setattr(self._loop, quark, watcher)

watcher.add_child_handler(process.pid, child_watch_fired)
Thread(name=f'cockpit-waitpid-{process.pid}', target=waitpid_thread, daemon=True).start()

def __init__(self,
loop: asyncio.AbstractEventLoop,
Expand Down
87 changes: 80 additions & 7 deletions test/pytest/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import os
import signal
import subprocess
import sys
import unittest.mock
from typing import Any, List, Optional, Tuple

Expand Down Expand Up @@ -289,17 +288,91 @@ async def test_stderr(self) -> None:
assert transport.get_stderr(reset=True) == ''

@pytest.mark.asyncio
async def test_safe_watcher_ENOSYS(self, monkeypatch: pytest.MonkeyPatch) -> None:
async def test_pidfd_ENOSYS(self, monkeypatch: pytest.MonkeyPatch) -> None:
# this test disables pidfd support in order to force the fallback path
# which creates a SafeChildWatcher. That's deprecated since 3.12 and
# removed in 3.14, so skip this test on those versions to avoid issues.
if sys.version_info >= (3, 12, 0):
pytest.skip()

monkeypatch.setattr(os, 'pidfd_open', unittest.mock.Mock(side_effect=OSError), raising=False)
protocol, _transport = self.subprocess(['true'])
await protocol.eof_and_exited_with_code(0)

@pytest.mark.asyncio
async def test_pidfd_ENOSYS_nonzero_exit(self, monkeypatch: pytest.MonkeyPatch) -> None:
# test that non-zero exit codes are correctly reported via the threaded fallback path
monkeypatch.setattr(os, 'pidfd_open', unittest.mock.Mock(side_effect=OSError), raising=False)
protocol, _transport = self.subprocess(['false'])
await protocol.eof_and_exited_with_code(1)

@pytest.mark.asyncio
async def test_pidfd_ENOSYS_exit_code(self, monkeypatch: pytest.MonkeyPatch) -> None:
# test that specific exit codes are correctly reported via the threaded fallback path
monkeypatch.setattr(os, 'pidfd_open', unittest.mock.Mock(side_effect=OSError), raising=False)
protocol, _transport = self.subprocess(['sh', '-c', 'exit 42'])
await protocol.eof_and_exited_with_code(42)

@pytest.mark.asyncio
async def test_pidfd_ENOSYS_signal(self, monkeypatch: pytest.MonkeyPatch) -> None:
# test that signal termination is correctly reported via the threaded fallback path
monkeypatch.setattr(os, 'pidfd_open', unittest.mock.Mock(side_effect=OSError), raising=False)
protocol, transport = self.subprocess(['cat'])
transport.send_signal(signal.SIGTERM)
await protocol.eof_and_exited_with_code(-signal.SIGTERM)

@pytest.mark.asyncio
async def test_pidfd_ENOSYS_kill(self, monkeypatch: pytest.MonkeyPatch) -> None:
# test that SIGKILL is correctly reported via the threaded fallback path
monkeypatch.setattr(os, 'pidfd_open', unittest.mock.Mock(side_effect=OSError), raising=False)
protocol, transport = self.subprocess(['cat'])
transport.kill()
await protocol.eof_and_exited_with_code(-signal.SIGKILL)

@pytest.mark.asyncio
async def test_pidfd_ENOSYS_concurrent(self, monkeypatch: pytest.MonkeyPatch) -> None:
# test multiple concurrent subprocesses with different exit scenarios
# using the threaded fallback path, to ensure exit statuses don't get mixed up
monkeypatch.setattr(os, 'pidfd_open', unittest.mock.Mock(side_effect=OSError), raising=False)

# start processes that block on stdin - we control when they exit
proto_0, transport_0 = self.subprocess(['sh', '-c', 'read a; exit 0'])
proto_1, transport_1 = self.subprocess(['sh', '-c', 'read a; exit 1'])
proto_42, transport_42 = self.subprocess(['sh', '-c', 'read a; exit 42'])
proto_term, transport_term = self.subprocess(['sh', '-c', 'read a; exit 99'])
proto_kill, transport_kill = self.subprocess(['sh', '-c', 'read a; exit 99'])

# create tasks for each process
task_0 = asyncio.create_task(proto_0.eof_and_exited_with_code(0))
task_1 = asyncio.create_task(proto_1.eof_and_exited_with_code(1))
task_42 = asyncio.create_task(proto_42.eof_and_exited_with_code(42))
task_term = asyncio.create_task(proto_term.eof_and_exited_with_code(-signal.SIGTERM))
task_kill = asyncio.create_task(proto_kill.eof_and_exited_with_code(-signal.SIGKILL))
pending = {task_0, task_1, task_42, task_term, task_kill}

# exit them one by one in a specific order and verify each time
transport_kill.kill()
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
assert done == {task_kill}
task_kill.result()

transport_term.send_signal(signal.SIGTERM)
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
assert done == {task_term}
task_term.result()

transport_42.write_eof()
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
assert done == {task_42}
task_42.result()

transport_1.write_eof()
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
assert done == {task_1}
task_1.result()

transport_0.write_eof()
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
assert done == {task_0}
task_0.result()

assert not pending

@pytest.mark.asyncio
async def test_true_pty(self) -> None:
loop = asyncio.get_running_loop()
Expand Down