|
9 | 9 |
|
10 | 10 | import argparse
|
11 | 11 | import cmd
|
| 12 | +import concurrent.futures |
12 | 13 | import dataclasses
|
13 | 14 | import json
|
14 | 15 | import os
|
@@ -76,6 +77,9 @@ def __init__(
|
76 | 77 | self._threads = {}
|
77 | 78 | self._session_cost_in_cents = 0
|
78 | 79 | self._session_cost_incomplete = False
|
| 80 | + self._future_executor = concurrent.futures.ThreadPoolExecutor( |
| 81 | + max_workers=1 |
| 82 | + ) |
79 | 83 | super().__init__(*args, **kwargs)
|
80 | 84 |
|
81 | 85 | @property
|
@@ -183,6 +187,22 @@ def _shlex_path(path: str) -> List[str]:
|
183 | 187 | lexer.whitespace_split = True
|
184 | 188 | return list(lexer)
|
185 | 189 |
|
| 190 | + @staticmethod |
| 191 | + def _await_future_interruptible( |
| 192 | + future: concurrent.futures.Future, interval: float = 0.25 |
| 193 | + ): |
| 194 | + """ |
| 195 | + Block until the future finishes, waking up |
| 196 | + at the supplied interval so the main thread can raise |
| 197 | + interrupts immediately. |
| 198 | + Returns future.result(). |
| 199 | + """ |
| 200 | + while True: |
| 201 | + try: |
| 202 | + return future.result(timeout=interval) |
| 203 | + except concurrent.futures.TimeoutError: |
| 204 | + continue |
| 205 | + |
186 | 206 | KNOWN_ROLES = tuple(MessageRole)
|
187 | 207 |
|
188 | 208 | @classmethod
|
@@ -374,13 +394,24 @@ def do_send(self, arg):
|
374 | 394 | This command takes no arguments.
|
375 | 395 | """
|
376 | 396 | print("...")
|
| 397 | + # Run the potentially long-running provider call in a background |
| 398 | + # thread so Ctrl+c can interrupt immediately. |
| 399 | + future = self._future_executor.submit( |
| 400 | + self._account.provider.complete, self._current_thread |
| 401 | + ) |
| 402 | + |
377 | 403 | try:
|
378 |
| - res = self._account.provider.complete(self._current_thread) |
| 404 | + res = self.__class__._await_future_interruptible(future) |
379 | 405 | except KeyboardInterrupt:
|
| 406 | + future.cancel() |
| 407 | + print("\nCancelled") |
| 408 | + # This API request may have incurred cost |
| 409 | + self._session_cost_incomplete = True |
380 | 410 | return
|
381 | 411 | except (CompletionError, NotImplementedError, ValueError) as e:
|
382 | 412 | print(str(e))
|
383 | 413 | return
|
| 414 | + |
384 | 415 | try:
|
385 | 416 | for chunk in res:
|
386 | 417 | print(chunk, end="")
|
@@ -1246,6 +1277,8 @@ def do_quit(self, arg):
|
1246 | 1277 | )
|
1247 | 1278 | else:
|
1248 | 1279 | can_exit = True
|
| 1280 | + if can_exit: |
| 1281 | + self._future_executor.shutdown(wait=False) |
1249 | 1282 | return can_exit # Truthy return values cause the cmdloop to stop
|
1250 | 1283 |
|
1251 | 1284 |
|
|
0 commit comments