#!/usr/local/bin/python3-login
# note: must run on Python >= 3.5, which mainly means no f-strings

# goals:
# - load environment variables from a login shell (bash -l)
#   (NOTE(review): presumably provided by the python3-login interpreter
#   named in the shebang; nothing in this file loads it -- confirm)
# - preserve signal handling of subprocess (kill -TERM and friends)
# - tee output to a log file

import fcntl
import os
import select
import signal
import subprocess
import sys

# output chunk size to read
CHUNK_SIZE = 1024

# how long to wait for child output before re-checking whether it has exited
POLL_TIMEOUT_MS = 1000

# signals to be forwarded to the child
# everything catchable, excluding SIGCHLD
SIGNALS = set(signal.Signals) - {signal.SIGKILL, signal.SIGSTOP, signal.SIGCHLD}


def _exit_status(returncode):
    """Translate a Popen returncode into a status suitable for sys.exit().

    Popen reports death-by-signal as ``-signum``, which is not a valid exit
    status. Report it the way shells do, as ``128 + signum``.
    """
    if returncode < 0:
        return 128 - returncode
    return returncode


def _relay_output(child, sinks):
    """Copy the child's output pipe to every sink until the child has exited.

    ``child`` is a Popen created with ``stdout=PIPE`` and ``bufsize=0``;
    ``sinks`` are binary file objects. Each chunk is written to and flushed
    on every sink as soon as it is read.
    """

    def tee(chunk):
        """Write one chunk to every sink, flushing immediately."""
        for f in sinks:
            f.write(chunk)
            f.flush()

    output = child.stdout

    # make the pipe non-blocking: on the unbuffered (raw) stream,
    # read(n) returns None when nothing is available yet
    # and b"" only at end-of-file
    flags = fcntl.fcntl(output, fcntl.F_GETFL)
    fcntl.fcntl(output, fcntl.F_SETFL, flags | os.O_NONBLOCK)
    poller = select.poll()
    poller.register(output, select.POLLIN)

    # while child is running, constantly relay output
    output_open = True
    while child.poll() is None:
        if not output_open:
            # the child closed its output but is still running:
            # just wait for it, polling a closed pipe would spin
            try:
                child.wait(timeout=POLL_TIMEOUT_MS / 1000)
            except subprocess.TimeoutExpired:
                pass
            continue

        chunk = output.read(CHUNK_SIZE)
        if chunk:
            tee(chunk)
        elif chunk is None:
            # nothing to read right now
            # wait for output on the pipe (timeout is in milliseconds)
            poller.poll(POLL_TIMEOUT_MS)
        else:
            # b"" means end-of-file
            poller.unregister(output)
            output_open = False

    # child has exited, continue relaying any remaining output
    # read() returns b"" at end-of-file, or None if some other process
    # still holds the pipe open but has nothing to say; stop either way
    chunk = output.read(CHUNK_SIZE)
    while chunk:
        tee(chunk)
        chunk = output.read(CHUNK_SIZE)
    output.close()


def main():
    """Run the command from argv, relaying signals and teeing its output.

    The command is ``sys.argv[1:]``, prefixed with ``$R2D_ENTRYPOINT`` when
    that variable is set. Its combined stdout/stderr is copied to our stdout
    and appended to ``.jupyter-server-log.txt`` in ``$REPO_DIR`` (default: the
    current directory). Every signal in ``SIGNALS`` that we receive is
    forwarded to the child. Always ends by raising SystemExit with a status
    derived from the child's exit.
    """
    # build the command
    # like `exec "$@"`
    command = sys.argv[1:]
    # load entrypoint override from env
    r2d_entrypoint = os.environ.get("R2D_ENTRYPOINT")
    if r2d_entrypoint:
        command.insert(0, r2d_entrypoint)
    if not command:
        sys.exit("usage: {prog} command [args...]".format(prog=sys.argv[0]))

    # open log file to send output
    log_path = os.path.join(
        os.environ.get("REPO_DIR", "."), ".jupyter-server-log.txt"
    )
    with open(log_path, "ab") as log_file:
        # launch the subprocess
        # bufsize=0: line buffering (bufsize=1) only exists in text mode,
        # and the raw stream has well-defined non-blocking read semantics
        child = subprocess.Popen(
            command,
            bufsize=0,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )

        # hook up ~all signals so that every signal the parent gets,
        # the children also get

        def relay_signal(sig, frame):
            """Relay a signal to children"""
            try:
                child.send_signal(sig)
            except ProcessLookupError:
                # the child is already gone, nothing to relay to
                pass

        for signum in SIGNALS:
            signal.signal(signum, relay_signal)

        # tee output from child to both our stdout and the log file
        _relay_output(child, [sys.stdout.buffer, log_file])

    # make our returncode match the child's returncode
    sys.exit(_exit_status(child.returncode))


if __name__ == "__main__":
    main()