Skip to content

Commit 28d508d

Browse files
committed
feat(cli): run node inside shell command
1 parent cf5fc97 commit 28d508d

File tree

2 files changed

+307
-8
lines changed

2 files changed

+307
-8
lines changed

hathor_cli/_shell_extension.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""
2+
IPython extension that adapts logging handlers to play nicely with the interactive prompt.
3+
4+
When loaded, all stream-based logging handlers are updated so their output is rendered
5+
through prompt_toolkit's ``run_in_terminal`` helper, which ensures log lines appear
6+
above the current input without corrupting the prompt. The original streams are restored
7+
when the extension is unloaded.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import io
13+
import logging
14+
import threading
15+
from contextlib import suppress
16+
from typing import Any, Callable, Iterable
17+
18+
get_app_or_none: Callable[[], Any] | None = None
19+
pt_utils: Any | None = None
20+
21+
try: # pragma: no cover - IPython dependency
22+
from prompt_toolkit.application import get_app_or_none as _get_app_or_none
23+
from prompt_toolkit.shortcuts import utils as _pt_utils
24+
except ImportError: # pragma: no cover - handled at runtime
25+
pass
26+
else:
27+
get_app_or_none = _get_app_or_none
28+
pt_utils = _pt_utils
29+
30+
_original_streams: dict[logging.StreamHandler, Any] = {}
31+
_installed = False
32+
33+
34+
class PromptToolkitLogStream(io.TextIOBase):
35+
"""Proxy stream that forwards writes through prompt_toolkit."""
36+
37+
def __init__(self, inner: Any):
38+
super().__init__()
39+
self._inner = inner
40+
self._encoding_override: str | None = None
41+
self._errors_override: str | None = None
42+
43+
def _run_in_terminal(self, func: Callable[[], None]) -> None:
44+
if pt_utils is None:
45+
func()
46+
return
47+
48+
app = get_app_or_none() if get_app_or_none is not None else None
49+
if app is None:
50+
func()
51+
return
52+
loop = getattr(app, 'loop', None)
53+
if loop is None:
54+
func()
55+
return
56+
57+
event = threading.Event()
58+
handled = False
59+
60+
def run_and_signal() -> None:
61+
nonlocal handled
62+
try:
63+
if not handled:
64+
pt_utils.run_in_terminal(func, in_executor=False)
65+
finally:
66+
handled = True
67+
event.set()
68+
69+
loop.call_soon_threadsafe(run_and_signal)
70+
if not event.wait(timeout=5):
71+
handled = True
72+
func()
73+
74+
def write(self, data: str) -> int:
75+
if not data:
76+
return 0
77+
78+
def _write() -> None:
79+
self._inner.write(data)
80+
81+
self._run_in_terminal(_write)
82+
return len(data)
83+
84+
def flush(self) -> None:
85+
def _flush() -> None:
86+
self._inner.flush()
87+
88+
self._run_in_terminal(_flush)
89+
90+
@property
91+
def encoding(self) -> str:
92+
if self._encoding_override is not None:
93+
return self._encoding_override
94+
return getattr(self._inner, 'encoding', 'utf-8') or 'utf-8'
95+
96+
@encoding.setter
97+
def encoding(self, value: str) -> None:
98+
self._encoding_override = value
99+
100+
@property
101+
def errors(self) -> str:
102+
if self._errors_override is not None:
103+
return self._errors_override
104+
return getattr(self._inner, 'errors', 'strict') or 'strict'
105+
106+
@errors.setter
107+
def errors(self, value: str) -> None:
108+
self._errors_override = value
109+
110+
def fileno(self) -> int:
111+
if hasattr(self._inner, 'fileno') and callable(getattr(self._inner, 'fileno')):
112+
return self._inner.fileno()
113+
raise io.UnsupportedOperation('fileno not available')
114+
115+
def isatty(self) -> bool:
116+
if hasattr(self._inner, 'isatty') and callable(getattr(self._inner, 'isatty')):
117+
return self._inner.isatty()
118+
return False
119+
120+
def close(self) -> None:
121+
# Do not close the underlying stream.
122+
pass
123+
124+
@property
125+
def closed(self) -> bool:
126+
return False
127+
128+
def readable(self) -> bool:
129+
return False
130+
131+
def seekable(self) -> bool:
132+
return False
133+
134+
def writable(self) -> bool:
135+
return True
136+
137+
138+
def _iter_stream_handlers() -> Iterable[logging.StreamHandler]:
139+
"""Yield every stream handler currently registered."""
140+
root = logging.getLogger()
141+
for handler in root.handlers:
142+
if isinstance(handler, logging.StreamHandler):
143+
yield handler
144+
145+
for logger in logging.Logger.manager.loggerDict.values():
146+
if isinstance(logger, logging.PlaceHolder):
147+
continue
148+
if not isinstance(logger, logging.Logger):
149+
continue
150+
for handler in logger.handlers:
151+
if isinstance(handler, logging.StreamHandler):
152+
yield handler
153+
154+
155+
def _install_prompt_toolkit_streams() -> None:
156+
if pt_utils is None:
157+
return
158+
159+
for handler in _iter_stream_handlers():
160+
current_stream = getattr(handler, 'stream', None)
161+
if current_stream is None or isinstance(current_stream, PromptToolkitLogStream):
162+
continue
163+
164+
proxy = PromptToolkitLogStream(current_stream)
165+
_original_streams[handler] = current_stream
166+
handler.stream = proxy
167+
168+
169+
def load_ipython_extension(shell: Any) -> None: # pragma: no cover - requires IPython
170+
"""Called by IPython when the extension is loaded."""
171+
global _installed
172+
173+
if pt_utils is None:
174+
shell.write_err('prompt_toolkit not available; logs will use standard output.\n')
175+
return
176+
177+
if _installed:
178+
return
179+
180+
_install_prompt_toolkit_streams()
181+
_installed = True
182+
183+
184+
def unload_ipython_extension(shell: Any) -> None: # pragma: no cover - requires IPython
185+
"""Called by IPython when the extension is unloaded."""
186+
restore_logging_streams()
187+
188+
189+
def restore_logging_streams() -> None:
190+
"""Restore the original logging streams."""
191+
global _installed
192+
for handler, stream in list(_original_streams.items()):
193+
with suppress(Exception):
194+
handler.stream = stream
195+
_original_streams.clear()
196+
_installed = False

hathor_cli/shell.py

Lines changed: 111 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,57 +12,160 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import threading
1516
from argparse import Namespace
16-
from typing import Any, Callable
17+
from contextlib import suppress
18+
from typing import Any, Callable, TypeVar, cast
1719

1820
from hathor_cli.run_node import RunNode
1921

22+
T = TypeVar('T')
2023

21-
def get_ipython(extra_args: list[Any], imported_objects: dict[str, Any]) -> Callable[[], None]:
24+
25+
def get_ipython(
26+
extra_args: list[Any],
27+
imported_objects: dict[str, Any],
28+
*,
29+
config: Any | None = None,
30+
) -> Callable[[], None]:
2231
from IPython import start_ipython
2332

2433
def run_ipython():
25-
start_ipython(argv=extra_args, user_ns=imported_objects)
34+
start_ipython(argv=extra_args, user_ns=imported_objects, config=config)
2635

2736
return run_ipython
2837

2938

3039
class Shell(RunNode):
40+
_reactor_thread: threading.Thread | None = None
41+
_shell_run_node: bool = False
42+
43+
@classmethod
44+
def create_parser(cls):
45+
parser = super().create_parser()
46+
parser.add_argument(
47+
'--x-run-node',
48+
action='store_true',
49+
help='Start the full node in the background while keeping the interactive shell open.'
50+
)
51+
return parser
52+
3153
def start_manager(self) -> None:
32-
pass
54+
if not self._shell_run_node:
55+
return
56+
57+
super().start_manager()
58+
self._start_reactor_thread()
3359

3460
def register_signal_handlers(self) -> None:
3561
pass
3662

3763
def prepare(self, *, register_resources: bool = True) -> None:
38-
super().prepare(register_resources=False)
64+
super().prepare(register_resources=self._shell_run_node)
3965

4066
imported_objects: dict[str, Any] = {}
4167
imported_objects['tx_storage'] = self.tx_storage
4268
if self._args.wallet:
4369
imported_objects['wallet'] = self.wallet
4470
imported_objects['manager'] = self.manager
45-
self.shell = get_ipython(self.extra_args, imported_objects)
71+
imported_objects['reactor'] = self.reactor
72+
ipy_config: Any | None = None
73+
74+
if self._shell_run_node:
75+
import asyncio
76+
from twisted.internet.defer import Deferred
77+
from traitlets.config import Config
78+
79+
async def await_deferred(deferred: Deferred[T]) -> T:
80+
loop = asyncio.get_running_loop()
81+
return await deferred.asFuture(loop)
82+
83+
imported_objects['await_deferred'] = await_deferred
84+
imported_objects['asyncio'] = asyncio
85+
ipy_config = Config()
86+
ipy_config.InteractiveShellApp.extra_extensions = ['hathor_cli._shell_extension']
87+
88+
self.shell = get_ipython(self.extra_args, imported_objects, config=ipy_config)
4689

4790
print()
4891
print('--- Injected globals ---')
4992
for name, obj in imported_objects.items():
5093
print(name, obj)
5194
print('------------------------')
5295
print()
96+
if self._shell_run_node:
97+
print('Node reactor started in background. Use await_deferred() for Deferreds.')
5398

5499
def parse_args(self, argv: list[str]) -> Namespace:
55100
# TODO: add help for the `--` extra argument separator
101+
argv = list(argv)
56102
extra_args: list[str] = []
57103
if '--' in argv:
58104
idx = argv.index('--')
59105
extra_args = argv[idx + 1:]
60106
argv = argv[:idx]
61107
self.extra_args = extra_args
62-
return self.parser.parse_args(argv)
108+
namespace = self.parser.parse_args(argv)
109+
self._shell_run_node = bool(getattr(namespace, 'x_run_node', False))
110+
return namespace
63111

64112
def run(self) -> None:
65-
self.shell()
113+
try:
114+
self.shell()
115+
finally:
116+
if self._shell_run_node:
117+
self._shutdown_background()
118+
119+
def _start_reactor_thread(self) -> None:
120+
if self._reactor_thread and self._reactor_thread.is_alive():
121+
return
122+
123+
def run_reactor() -> None:
124+
self.log.info('reactor thread starting')
125+
try:
126+
run = getattr(self.reactor, 'run')
127+
try:
128+
run(installSignalHandlers=False)
129+
except TypeError:
130+
run()
131+
finally:
132+
self.log.info('reactor thread finished')
133+
134+
self._reactor_thread = threading.Thread(
135+
target=run_reactor,
136+
name='hathor-reactor',
137+
daemon=True,
138+
)
139+
self._reactor_thread.start()
140+
141+
def _shutdown_background(self) -> None:
142+
thread = self._reactor_thread
143+
if thread and thread.is_alive():
144+
try:
145+
from twisted.internet.interfaces import IReactorFromThreads
146+
147+
threaded_reactor = cast(IReactorFromThreads, self.reactor)
148+
threaded_reactor.callFromThread(self.reactor.stop)
149+
except Exception:
150+
self.log.exception('failed to schedule reactor shutdown from shell')
151+
thread.join(timeout=30)
152+
if thread.is_alive():
153+
self.log.warning('reactor thread did not finish cleanly')
154+
self._reactor_thread = None
155+
if self._shell_run_node:
156+
self._restore_logging_streams()
157+
158+
def _restore_logging_streams(self) -> None:
159+
try:
160+
from hathor_cli import _shell_extension
161+
except ImportError:
162+
return
163+
with suppress(Exception):
164+
_shell_extension.restore_logging_streams()
165+
166+
def __del__(self):
167+
with suppress(Exception):
168+
self._restore_logging_streams()
66169

67170

68171
def main():

0 commit comments

Comments
 (0)