Skip to content

Commit 7384ac5

Browse files
committed
send: Allow cancellation with Ctrl+c
1 parent 4eb05e7 commit 7384ac5

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

src/gptcmd/cli.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import argparse
1111
import cmd
12+
import concurrent.futures
1213
import dataclasses
1314
import json
1415
import os
@@ -76,6 +77,9 @@ def __init__(
7677
self._threads = {}
7778
self._session_cost_in_cents = 0
7879
self._session_cost_incomplete = False
80+
self._future_executor = concurrent.futures.ThreadPoolExecutor(
81+
max_workers=1
82+
)
7983
super().__init__(*args, **kwargs)
8084

8185
@property
@@ -183,6 +187,22 @@ def _shlex_path(path: str) -> List[str]:
183187
lexer.whitespace_split = True
184188
return list(lexer)
185189

190+
@staticmethod
191+
def _await_future_interruptible(
192+
future: concurrent.futures.Future, interval: float = 0.25
193+
):
194+
"""
195+
Block until the future finishes, waking up
196+
at the supplied interval so the main thread can raise
197+
interrupts immediately.
198+
Returns future.result().
199+
"""
200+
while True:
201+
try:
202+
return future.result(timeout=interval)
203+
except concurrent.futures.TimeoutError:
204+
continue
205+
186206
KNOWN_ROLES = tuple(MessageRole)
187207

188208
@classmethod
@@ -374,13 +394,24 @@ def do_send(self, arg):
374394
This command takes no arguments.
375395
"""
376396
print("...")
397+
# Run the potentially long-running provider call in a background
398+
# thread so Ctrl+c can interrupt immediately.
399+
future = self._future_executor.submit(
400+
self._account.provider.complete, self._current_thread
401+
)
402+
377403
try:
378-
res = self._account.provider.complete(self._current_thread)
404+
res = self.__class__._await_future_interruptible(future)
379405
except KeyboardInterrupt:
406+
future.cancel()
407+
print("\nCancelled")
408+
# This API request may have incurred cost
409+
self._session_cost_incomplete = True
380410
return
381411
except (CompletionError, NotImplementedError, ValueError) as e:
382412
print(str(e))
383413
return
414+
384415
try:
385416
for chunk in res:
386417
print(chunk, end="")
@@ -1246,6 +1277,8 @@ def do_quit(self, arg):
12461277
)
12471278
else:
12481279
can_exit = True
1280+
if can_exit:
1281+
self._future_executor.shutdown(wait=False)
12491282
return can_exit # Truthy return values cause the cmdloop to stop
12501283

12511284

0 commit comments

Comments
 (0)