Skip to content

Commit 45f5280

Browse files
committed
feat(cli): run node inside shell command
1 parent a98679e commit 45f5280

File tree

2 files changed

+321
-8
lines changed

2 files changed

+321
-8
lines changed

hathor_cli/_shell_extension.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright 2025 Hathor Labs
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
IPython extension that adapts logging handlers to play nicely with the interactive prompt.
17+
18+
When loaded, all stream-based logging handlers are updated so their output is rendered
19+
through prompt_toolkit's ``run_in_terminal`` helper, which ensures log lines appear
20+
above the current input without corrupting the prompt. The original streams are restored
21+
when the extension is unloaded.
22+
"""
23+
24+
from __future__ import annotations
25+
26+
import io
27+
import logging
28+
import threading
29+
from contextlib import suppress
30+
from typing import Any, Callable, Iterable
31+
32+
get_app_or_none: Callable[[], Any] | None = None
33+
pt_utils: Any | None = None
34+
35+
try:
36+
from prompt_toolkit.application import get_app_or_none as _get_app_or_none
37+
from prompt_toolkit.shortcuts import utils as _pt_utils
38+
except ImportError:
39+
pass
40+
else:
41+
get_app_or_none = _get_app_or_none
42+
pt_utils = _pt_utils
43+
44+
_original_streams: dict[logging.StreamHandler, Any] = {}
45+
_installed = False
46+
47+
48+
class PromptToolkitLogStream(io.TextIOBase):
49+
"""Proxy stream that forwards writes through prompt_toolkit."""
50+
51+
def __init__(self, inner: Any):
52+
super().__init__()
53+
self._inner = inner
54+
self._encoding_override: str | None = None
55+
self._errors_override: str | None = None
56+
57+
def _run_in_terminal(self, func: Callable[[], None]) -> None:
58+
if pt_utils is None:
59+
func()
60+
return
61+
62+
app = get_app_or_none() if get_app_or_none is not None else None
63+
if app is None:
64+
func()
65+
return
66+
loop = getattr(app, 'loop', None)
67+
if loop is None:
68+
func()
69+
return
70+
71+
event = threading.Event()
72+
handled = False
73+
74+
def run_and_signal() -> None:
75+
nonlocal handled
76+
try:
77+
if not handled:
78+
pt_utils.run_in_terminal(func, in_executor=False)
79+
finally:
80+
handled = True
81+
event.set()
82+
83+
loop.call_soon_threadsafe(run_and_signal)
84+
if not event.wait(timeout=5):
85+
handled = True
86+
func()
87+
88+
def write(self, data: str) -> int:
89+
if not data:
90+
return 0
91+
92+
def _write() -> None:
93+
self._inner.write(data)
94+
95+
self._run_in_terminal(_write)
96+
return len(data)
97+
98+
def flush(self) -> None:
99+
def _flush() -> None:
100+
self._inner.flush()
101+
102+
self._run_in_terminal(_flush)
103+
104+
@property
105+
def encoding(self) -> str:
106+
if self._encoding_override is not None:
107+
return self._encoding_override
108+
return getattr(self._inner, 'encoding', 'utf-8') or 'utf-8'
109+
110+
@encoding.setter
111+
def encoding(self, value: str) -> None:
112+
self._encoding_override = value
113+
114+
@property
115+
def errors(self) -> str:
116+
if self._errors_override is not None:
117+
return self._errors_override
118+
return getattr(self._inner, 'errors', 'strict') or 'strict'
119+
120+
@errors.setter
121+
def errors(self, value: str) -> None:
122+
self._errors_override = value
123+
124+
def fileno(self) -> int:
125+
if hasattr(self._inner, 'fileno') and callable(getattr(self._inner, 'fileno')):
126+
return self._inner.fileno()
127+
raise io.UnsupportedOperation('fileno not available')
128+
129+
def isatty(self) -> bool:
130+
if hasattr(self._inner, 'isatty') and callable(getattr(self._inner, 'isatty')):
131+
return self._inner.isatty()
132+
return False
133+
134+
def close(self) -> None:
135+
# Do not close the underlying stream.
136+
pass
137+
138+
@property
139+
def closed(self) -> bool:
140+
return False
141+
142+
def readable(self) -> bool:
143+
return False
144+
145+
def seekable(self) -> bool:
146+
return False
147+
148+
def writable(self) -> bool:
149+
return True
150+
151+
152+
def _iter_stream_handlers() -> Iterable[logging.StreamHandler]:
153+
"""Yield every stream handler currently registered."""
154+
root = logging.getLogger()
155+
for handler in root.handlers:
156+
if isinstance(handler, logging.StreamHandler):
157+
yield handler
158+
159+
for logger in logging.Logger.manager.loggerDict.values():
160+
if isinstance(logger, logging.PlaceHolder):
161+
continue
162+
if not isinstance(logger, logging.Logger):
163+
continue
164+
for handler in logger.handlers:
165+
if isinstance(handler, logging.StreamHandler):
166+
yield handler
167+
168+
169+
def _install_prompt_toolkit_streams() -> None:
170+
if pt_utils is None:
171+
return
172+
173+
for handler in _iter_stream_handlers():
174+
current_stream = getattr(handler, 'stream', None)
175+
if current_stream is None or isinstance(current_stream, PromptToolkitLogStream):
176+
continue
177+
178+
proxy = PromptToolkitLogStream(current_stream)
179+
_original_streams[handler] = current_stream
180+
handler.stream = proxy
181+
182+
183+
def load_ipython_extension(shell: Any) -> None:
184+
"""Called by IPython when the extension is loaded."""
185+
global _installed
186+
187+
if pt_utils is None:
188+
shell.write_err('prompt_toolkit not available; logs will use standard output.\n')
189+
return
190+
191+
if _installed:
192+
return
193+
194+
_install_prompt_toolkit_streams()
195+
_installed = True
196+
197+
198+
def unload_ipython_extension(shell: Any) -> None:
199+
"""Called by IPython when the extension is unloaded."""
200+
restore_logging_streams()
201+
202+
203+
def restore_logging_streams() -> None:
204+
"""Restore the original logging streams."""
205+
global _installed
206+
for handler, stream in list(_original_streams.items()):
207+
with suppress(Exception):
208+
handler.stream = stream
209+
_original_streams.clear()
210+
_installed = False

hathor_cli/shell.py

Lines changed: 111 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,57 +12,160 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import threading
1516
from argparse import Namespace
16-
from typing import Any, Callable
17+
from contextlib import suppress
18+
from typing import Any, Callable, TypeVar, cast
1719

1820
from hathor_cli.run_node import RunNode
1921

22+
T = TypeVar('T')
2023

21-
def get_ipython(extra_args: list[Any], imported_objects: dict[str, Any]) -> Callable[[], None]:
24+
25+
def get_ipython(
26+
extra_args: list[Any],
27+
imported_objects: dict[str, Any],
28+
*,
29+
config: Any | None = None,
30+
) -> Callable[[], None]:
2231
from IPython import start_ipython
2332

2433
def run_ipython():
25-
start_ipython(argv=extra_args, user_ns=imported_objects)
34+
start_ipython(argv=extra_args, user_ns=imported_objects, config=config)
2635

2736
return run_ipython
2837

2938

3039
class Shell(RunNode):
40+
_reactor_thread: threading.Thread | None = None
41+
_shell_run_node: bool = False
42+
43+
@classmethod
44+
def create_parser(cls):
45+
parser = super().create_parser()
46+
parser.add_argument(
47+
'--x-run-node',
48+
action='store_true',
49+
help='Start the full node in the background while keeping the interactive shell open.'
50+
)
51+
return parser
52+
3153
def start_manager(self) -> None:
32-
pass
54+
if not self._shell_run_node:
55+
return
56+
57+
super().start_manager()
58+
self._start_reactor_thread()
3359

3460
def register_signal_handlers(self) -> None:
3561
pass
3662

3763
def prepare(self, *, register_resources: bool = True) -> None:
38-
super().prepare(register_resources=False)
64+
super().prepare(register_resources=self._shell_run_node)
3965

4066
imported_objects: dict[str, Any] = {}
4167
imported_objects['tx_storage'] = self.tx_storage
4268
if self._args.wallet:
4369
imported_objects['wallet'] = self.wallet
4470
imported_objects['manager'] = self.manager
45-
self.shell = get_ipython(self.extra_args, imported_objects)
71+
imported_objects['reactor'] = self.reactor
72+
ipy_config: Any | None = None
73+
74+
if self._shell_run_node:
75+
import asyncio
76+
from twisted.internet.defer import Deferred
77+
from traitlets.config import Config
78+
79+
async def await_deferred(deferred: Deferred[T]) -> T:
80+
loop = asyncio.get_running_loop()
81+
return await deferred.asFuture(loop)
82+
83+
imported_objects['await_deferred'] = await_deferred
84+
imported_objects['asyncio'] = asyncio
85+
ipy_config = Config()
86+
ipy_config.InteractiveShellApp.extra_extensions = ['hathor_cli._shell_extension']
87+
88+
self.shell = get_ipython(self.extra_args, imported_objects, config=ipy_config)
4689

4790
print()
4891
print('--- Injected globals ---')
4992
for name, obj in imported_objects.items():
5093
print(name, obj)
5194
print('------------------------')
5295
print()
96+
if self._shell_run_node:
97+
print('Node reactor started in background. Use await_deferred() for Deferreds.')
5398

5499
def parse_args(self, argv: list[str]) -> Namespace:
55100
# TODO: add help for the `--` extra argument separator
101+
argv = list(argv)
56102
extra_args: list[str] = []
57103
if '--' in argv:
58104
idx = argv.index('--')
59105
extra_args = argv[idx + 1:]
60106
argv = argv[:idx]
61107
self.extra_args = extra_args
62-
return self.parser.parse_args(argv)
108+
namespace = self.parser.parse_args(argv)
109+
self._shell_run_node = bool(getattr(namespace, 'x_run_node', False))
110+
return namespace
63111

64112
def run(self) -> None:
65-
self.shell()
113+
try:
114+
self.shell()
115+
finally:
116+
if self._shell_run_node:
117+
self._shutdown_background()
118+
119+
def _start_reactor_thread(self) -> None:
120+
if self._reactor_thread and self._reactor_thread.is_alive():
121+
return
122+
123+
def run_reactor() -> None:
124+
self.log.info('reactor thread starting')
125+
try:
126+
run = getattr(self.reactor, 'run')
127+
try:
128+
run(installSignalHandlers=False)
129+
except TypeError:
130+
run()
131+
finally:
132+
self.log.info('reactor thread finished')
133+
134+
self._reactor_thread = threading.Thread(
135+
target=run_reactor,
136+
name='hathor-reactor',
137+
daemon=True,
138+
)
139+
self._reactor_thread.start()
140+
141+
def _shutdown_background(self) -> None:
142+
thread = self._reactor_thread
143+
if thread and thread.is_alive():
144+
try:
145+
from twisted.internet.interfaces import IReactorFromThreads
146+
147+
threaded_reactor = cast(IReactorFromThreads, self.reactor)
148+
threaded_reactor.callFromThread(self.reactor.stop)
149+
except Exception:
150+
self.log.exception('failed to schedule reactor shutdown from shell')
151+
thread.join(timeout=30)
152+
if thread.is_alive():
153+
self.log.warning('reactor thread did not finish cleanly')
154+
self._reactor_thread = None
155+
if self._shell_run_node:
156+
self._restore_logging_streams()
157+
158+
def _restore_logging_streams(self) -> None:
159+
try:
160+
from hathor_cli import _shell_extension
161+
except ImportError:
162+
return
163+
with suppress(Exception):
164+
_shell_extension.restore_logging_streams()
165+
166+
def __del__(self):
167+
with suppress(Exception):
168+
self._restore_logging_streams()
66169

67170

68171
def main():

0 commit comments

Comments
 (0)