Skip to content

Commit 9b49818

Browse files
committed
Implement Endpoint manager hot restart (#1867)
This implements a restart-in-place mechanism via `exec()`. The logic serializes and re-establishes state: - The child processes do not lose their parent; the new endpoint instance maintains the exact same PID - In flight audit records are protected, and the new instance setups writing to the MEP file. A hot restart will be visible in the audit log file as the MEP shutting down and restarting. - A log record is emitted to the "normal" logs indicating that hot restart is happening - Until we have a proper MEP pid file, the only access to hot-restart an MEP is to send the process SIGHUP. [sc-40933]
1 parent 8e6d96a commit 9b49818

File tree

5 files changed

+330
-15
lines changed

5 files changed

+330
-15
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
New Functionality
2+
^^^^^^^^^^^^^^^^^
3+
4+
- Implement hot-restart functionality for Multi-user endpoint. See
5+
:ref:`hot-restart` for full documentation, but the synopsis is send the
6+
``SIGHUP`` signal to the MEP (parent) process. Currently, there is no
7+
equivalent built-in sub-command to ``globus-compute-endpoint``.

compute_endpoint/globus_compute_endpoint/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ def _do_start_endpoint(
577577
reg_info = {}
578578
config_str: str | None = None
579579
audit_fd: int | None = None
580+
restart_fd: int | None = None
580581
fn_allow_list: list[str] | None | int = _no_fn_list_canary
581582
if sys.stdin and not (sys.stdin.closed or sys.stdin.isatty()):
582583
try:
@@ -593,6 +594,7 @@ def _do_start_endpoint(
593594
reg_info = stdin_data.get("amqp_creds", {})
594595
config_str = stdin_data.get("config")
595596
audit_fd = stdin_data.get("audit_fd")
597+
restart_fd = stdin_data.get("restart_fd")
596598
fn_allow_list = stdin_data.get("allowed_functions", _no_fn_list_canary)
597599

598600
del stdin_data # clarity for intended scope
@@ -639,7 +641,10 @@ def _do_start_endpoint(
639641
raise ClickException(
640642
"multi-user endpoints are not supported on this system"
641643
)
644+
642645
epm = EndpointManager(ep_dir, endpoint_uuid, ep_config, reg_info)
646+
if restart_fd:
647+
epm._finish_hot_restart(restart_fd)
643648
epm.start()
644649
else:
645650
assert isinstance(ep_config, UserEndpointConfig)

compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import os
88
import pathlib
9+
import pickle
910
import platform
1011
import pwd
1112
import queue
@@ -157,8 +158,8 @@ def __init__(
157158
else:
158159
_import_pyprctl()
159160

160-
self._reload_requested = False
161161
self._time_to_stop = False
162+
self._restart = False
162163

163164
self._heartbeat_period: float = max(MINIMUM_HEARTBEAT, config.heartbeat_period)
164165

@@ -174,7 +175,7 @@ def __init__(
174175
self._cached_cmd_start_args: TTLCache[int, T_CMD_START_ARGS] = TTLCache(
175176
maxsize=32768, ttl=config.mu_child_ep_grace_period_s
176177
)
177-
self._audit_pipes: dict[int, t.Any] = {}
178+
self._audit_pipes: dict[int, dict[str, int | str]] = {}
178179
self._audit_log_handler_stop = not (
179180
self._config.high_assurance and bool(self._config.audit_log_path)
180181
)
@@ -372,6 +373,9 @@ def get_metadata(self, config: ManagerEndpointConfig) -> dict:
372373
"user_config_schema": user_config_schema,
373374
}
374375

376+
def request_restart(self, sig_num, curr_stack_frame):
377+
self._restart = True
378+
375379
def request_shutdown(self, sig_num, curr_stack_frame):
376380
self._time_to_stop = True
377381

@@ -488,12 +492,13 @@ def _audit_log_write(self, fd: int, fpath: io.BytesIO):
488492
uid = uep_audit_info.get("uid")
489493
eid = uep_audit_info.get("endpoint_id")
490494
try:
491-
msg = (
492-
os.read(fd, self._audit_buf_size)
493-
.replace(b"\n", b" ")
494-
.replace(b"\r", b"")
495-
.replace(b"\0", b"")
496-
)
495+
with self._audit_log_lock:
496+
msg = (
497+
os.read(fd, self._audit_buf_size)
498+
.replace(b"\n", b" ")
499+
.replace(b"\r", b"")
500+
.replace(b"\0", b"")
501+
)
497502
if not msg:
498503
self._audit_log_close_reader(fd)
499504
return
@@ -511,6 +516,7 @@ def _audit_log_write(self, fd: int, fpath: io.BytesIO):
511516
log.error(f"Failed to write audit log message: [{uid=}, {eid=}] - {e_str}")
512517

513518
def _install_signal_handlers(self):
519+
signal.signal(signal.SIGHUP, self.request_restart)
514520
signal.signal(signal.SIGTERM, self.request_shutdown)
515521
signal.signal(signal.SIGINT, self.request_shutdown)
516522
signal.signal(signal.SIGQUIT, self.request_shutdown)
@@ -629,6 +635,74 @@ def start(self):
629635
# re-enable cursor visibility
630636
print("\033[?25h", end="", file=msg_out)
631637

638+
def hot_restart(self):
639+
log.info("Manager hot hot_restart requested")
640+
r_fd = os.memfd_create("hot_restart", flags=0) # 0 == *not* CLOEXEC
641+
642+
stdin_data = {
643+
"amqp_creds": {
644+
"endpoint_id": self._endpoint_uuid_str,
645+
"command_queue_info": self._command.queue_info,
646+
"heartbeat_queue_info": self._heartbeat_publisher.queue_info,
647+
},
648+
"restart_fd": r_fd,
649+
}
650+
self._command_stop_event.set()
651+
self._heartbeat_publisher.stop()
652+
self._command.join()
653+
654+
r, w = os.pipe()
655+
os.dup2(r, 0)
656+
os.write(w, json.dumps(stdin_data).encode())
657+
os.close(w)
658+
os.close(r)
659+
660+
with self._audit_log_lock:
661+
if not self._audit_log_handler_stop:
662+
nowtz = datetime.now().astimezone().isoformat()
663+
uid = os.getuid()
664+
pid = os.getpid()
665+
eid = self._endpoint_uuid_str
666+
msg = (
667+
f"{nowtz} uid={uid} pid={pid} eid={eid} End MEP session"
668+
f" [hot restart] .....\n"
669+
)
670+
with open(self._config.audit_log_path, "ab", buffering=0) as audit_f:
671+
audit_f.write(msg.encode())
672+
673+
# only thread of consequence that we block; will be restarted in new exec();
674+
# AMQP will resend any interim received tasks because we won't ACK them.
675+
state = {
676+
"_audit_pipes": self._audit_pipes,
677+
"_children": self._children,
678+
"_cached_cmd_start_args": self._cached_cmd_start_args,
679+
}
680+
os.write(r_fd, pickle.dumps(state))
681+
os.fsync(r_fd)
682+
os.lseek(r_fd, 0, os.SEEK_SET)
683+
args = [sys.executable, *sys.argv]
684+
685+
num_children = len(self._children)
686+
log.info(
687+
f"\n.......... Manager hot restarting {self._endpoint_uuid_str}"
688+
f" (task processors: {num_children})\n"
689+
)
690+
os.execvpe(args[0], args=args, env=os.environ)
691+
692+
def _finish_hot_restart(self, fd: int):
693+
with os.fdopen(fd, "rb") as f:
694+
restart_data: dict = pickle.loads(f.read())
695+
696+
self._audit_pipes.update(restart_data.get("_audit_pipes", {}))
697+
self._children.update(restart_data.get("_children", {}))
698+
self._cached_cmd_start_args.update(
699+
restart_data.get("_cached_cmd_start_args", {})
700+
)
701+
for audit_r in self._audit_pipes:
702+
self._audit_selector.register(
703+
audit_r, selectors.EVENT_READ, self._audit_log_write
704+
)
705+
632706
def _event_loop(self):
633707
parent_identities: set[str] = set()
634708
if not is_privileged():
@@ -668,6 +742,11 @@ def _event_loop(self):
668742
if self._wait_for_child:
669743
self.wait_for_children()
670744

745+
if self._restart:
746+
# not protected; if exec() fails, then this raises and we shutdown
747+
# ... "Failure is not an option!"
748+
self.hot_restart()
749+
671750
if time.monotonic() - last_heartbeat >= self._heartbeat_period:
672751
self.send_heartbeat()
673752
last_heartbeat = time.monotonic()

compute_endpoint/tests/unit/test_endpointmanager_unit.py

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import logging
55
import os
66
import pathlib
7+
import pickle
78
import pwd
89
import queue
910
import random
1011
import re
1112
import resource
13+
import selectors
1214
import signal
1315
import sys
1416
import time
@@ -49,6 +51,7 @@
4951
EndpointManager,
5052
InvalidUserError,
5153
MappedPosixIdentity,
54+
UserEndpointRecord,
5255
)
5356

5457

@@ -267,6 +270,7 @@ def epmanager_as_root(
267270
mock_os.pipe.return_value = 40, 41
268271
mock_os.dup2.side_effect = (0, 1, 2, AssertionError("dup2: unexpected?"))
269272
mock_os.open.side_effect = (4, 5, AssertionError("open: unexpected?"))
273+
mock_os.memfd_create.return_value = random.randint(50, 10000)
270274

271275
mock_pwd = mocker.patch(f"{_MOCK_BASE}pwd")
272276
mock_pwd.getpwnam.side_effect = (
@@ -295,8 +299,8 @@ def epmanager_as_root(
295299
mock_auth_client.userinfo.return_value = {"identity_set": [{"sub": ident}]}
296300

297301
em = EndpointManager(conf_dir, ep_uuid, mock_conf_root)
298-
em._command = mock.Mock(spec=CommandQueueSubscriber)
299-
em._heartbeat_publisher = mock.Mock(spec=ResultPublisher)
302+
em._command = mock.Mock(spec=CommandQueueSubscriber, queue_info={})
303+
em._heartbeat_publisher = mock.Mock(spec=ResultPublisher, queue_info={})
300304

301305
yield conf_dir, mock_conf_root, mock_client, mock_os, mock_pwd, em
302306
if em.identity_mapper:
@@ -2543,3 +2547,143 @@ def _called(fn_name):
25432547

25442548
assert pyexc.value.code == _GOOD_EC, "Q&D: verify we exec'ed, based on '+= 1'"
25452549
assert pamh.pam_close_session.called
2550+
2551+
2552+
def test_restart_signal(successful_exec_from_mocked_root, reset_signals):
2553+
mock_os, *_, em = successful_exec_from_mocked_root
2554+
2555+
em.hot_restart = mock.Mock(side_effect=MemoryError)
2556+
em._install_signal_handlers()
2557+
assert not em._restart, "Verify test setup"
2558+
os.kill(os.getpid(), signal.SIGHUP)
2559+
2560+
with pytest.raises(MemoryError):
2561+
em._event_loop()
2562+
2563+
assert em._restart, "Ensure class state, but main thing is .hot_restart() invoked"
2564+
2565+
2566+
def test_restart_restarts(successful_exec_from_mocked_root, randomstring):
2567+
mock_os, *_, em = successful_exec_from_mocked_root
2568+
2569+
canary = randomstring()
2570+
mock_os.environ = {"canary": canary}
2571+
2572+
em.hot_restart()
2573+
2574+
assert mock_os.execvpe.called, "Basic correctness"
2575+
a, k = mock_os.execvpe.call_args
2576+
exp_args = [sys.executable, *sys.argv]
2577+
assert (exp_args[0],) == a, "Expect repeat of initial args"
2578+
assert k["args"] == exp_args, "Expect repeat of initial args"
2579+
assert k["env"]["canary"] == canary, "Expect to relay environment variables"
2580+
2581+
2582+
def test_restart_conveys_state(successful_exec_from_mocked_root, randomstring):
2583+
mock_os, *_, em = successful_exec_from_mocked_root
2584+
2585+
em._audit_pipes[123] = {"pid": random.randint(1, 1000000)}
2586+
em._children[123] = UserEndpointRecord(ep_name="abc", arguments="some_args")
2587+
em._cached_cmd_start_args[123] = randomstring()
2588+
em._command.queue_info = {"canary": randomstring()}
2589+
em._heartbeat_publisher.queue_info = {"canary": randomstring()}
2590+
em.hot_restart()
2591+
2592+
assert mock_os.execvpe.called, "Basic correctness"
2593+
assert mock_os.write.call_count == 2, "Verify test setup, expected writes"
2594+
2595+
pipe_r, pipe_w = mock_os.pipe.return_value
2596+
(stdin_fd, stdin_bytes), _ = mock_os.write.call_args_list[0]
2597+
(mem_fd, conveyed), _ = mock_os.write.call_args_list[1]
2598+
2599+
assert stdin_fd == pipe_w, "Expect write to new proc stdin"
2600+
mock_os.dup2.assert_called_with(pipe_r, 0), "Expect write to new proc stdin"
2601+
stdin = json.loads(stdin_bytes)
2602+
creds = stdin.get("amqp_creds")
2603+
assert creds, "Expect reconnection credentials; no need to relogin"
2604+
assert creds["endpoint_id"] == em._endpoint_uuid_str
2605+
assert creds["command_queue_info"] == em._command.queue_info
2606+
assert creds["heartbeat_queue_info"] == em._heartbeat_publisher.queue_info
2607+
assert stdin.get("restart_fd") == mem_fd, "Hot restarted requires a state file"
2608+
2609+
assert mem_fd == mock_os.memfd_create.return_value, "Should write *anonymous* file"
2610+
2611+
state = pickle.loads(conveyed)
2612+
assert state["_audit_pipes"] == em._audit_pipes
2613+
assert state["_children"] == em._children
2614+
assert state["_cached_cmd_start_args"] == em._cached_cmd_start_args
2615+
2616+
2617+
def test_restart_repopulates_state(successful_exec_from_mocked_root, randomstring):
2618+
mock_os, *_, em = successful_exec_from_mocked_root
2619+
2620+
canary = randomstring()
2621+
audit_pipes = {123: {"pid": random.randint(1, 1000000)}}
2622+
children = {123: UserEndpointRecord(ep_name="abc", arguments="some_args")}
2623+
cached_args = {123: randomstring()}
2624+
em._audit_selector = mock.Mock(spec=selectors.DefaultSelector)
2625+
em._audit_pipes = audit_pipes
2626+
em._children = children
2627+
em._cached_cmd_start_args = cached_args
2628+
2629+
em.hot_restart()
2630+
em._audit_pipes = {10000: canary}
2631+
em._children = {10000: canary}
2632+
em._cached_cmd_start_args = {10000: canary}
2633+
2634+
(mem_fd, conveyed), _ = mock_os.write.call_args_list[1]
2635+
mem_f = io.BytesIO(conveyed)
2636+
mem_f.seek(0)
2637+
mock_os.fdopen.return_value = mem_f
2638+
2639+
em._finish_hot_restart(mem_fd)
2640+
mock_os.fdopen.assert_called_with(mem_fd, "rb"), "Expect passed fd opened"
2641+
assert em._audit_pipes[10000] == canary, "Expect updated, not overwritten"
2642+
assert em._children[10000] == canary, "Expect updated, not overwritten"
2643+
assert em._cached_cmd_start_args[10000] == canary, "Expect updated, not overwritten"
2644+
del em._audit_pipes[10000], em._children[10000], em._cached_cmd_start_args[10000]
2645+
2646+
assert em._audit_pipes == audit_pipes
2647+
assert em._children == children
2648+
assert em._cached_cmd_start_args == cached_args
2649+
2650+
all_args = {
2651+
fd: (evt, cb) for (fd, evt, cb), _ in em._audit_selector.register.call_args_list
2652+
}
2653+
2654+
exp_args = (selectors.EVENT_READ, em._audit_log_write)
2655+
for audit_fd in em._audit_pipes:
2656+
assert all_args[audit_fd] == exp_args, "Expect reregistration of audit pipes"
2657+
2658+
2659+
def test_restart_audit_pipes_protected(successful_exec_from_mocked_root):
2660+
mock_os, *_, em = successful_exec_from_mocked_root
2661+
2662+
em._audit_pipes[123] = {"pid": 1235}
2663+
em._audit_log_lock = mock.MagicMock()
2664+
2665+
def lock_test(*a, **k):
2666+
assert em._audit_log_lock.__enter__.called
2667+
assert not em._audit_log_lock.__exit__.called, "Expect locked at during call"
2668+
return b"some audit bytes"
2669+
2670+
mock_os.execvpe.side_effect = lock_test
2671+
em.hot_restart()
2672+
assert em._audit_log_lock.__enter__.called, "Verify test setup"
2673+
2674+
mock_os.read.side_effect = lock_test
2675+
em._audit_log_lock.reset_mock()
2676+
em._audit_log_write(123, mock.Mock())
2677+
assert em._audit_log_lock.__enter__.called, "Verify test setup"
2678+
2679+
2680+
def test_restart_logs(successful_exec_from_mocked_root, mock_log):
2681+
mock_os, *_, em = successful_exec_from_mocked_root
2682+
2683+
em.hot_restart()
2684+
2685+
i_logs = "\n".join(f"{a}" for (a,), k in mock_log.info.call_args_list)
2686+
2687+
assert "hot hot_restart requested" in i_logs, "Expect initial signal acknowledged"
2688+
assert ".......... Manager hot restarting" in i_logs, "Expect last message"
2689+
assert " (task processors: 0)" in i_logs, "Expect friendly count for admin"

0 commit comments

Comments
 (0)