Skip to content

Commit bd6441a

Browse files
committed
Better Python API, better KeyboardInturrupt
1 parent bdfe9aa commit bd6441a

File tree

5 files changed

+79
-77
lines changed

5 files changed

+79
-77
lines changed

interpreter_1/cli.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,17 @@ def load_interpreter_thread(args):
226226
spinner.stop()
227227
print()
228228
global_interpreter.messages = [{"role": "user", "content": message}]
229-
async for _ in global_interpreter.async_respond():
230-
pass
229+
try:
230+
async for _ in global_interpreter.async_respond():
231+
pass
232+
except KeyboardInterrupt:
233+
global_interpreter._spinner.stop()
234+
except asyncio.CancelledError:
235+
global_interpreter._spinner.stop()
231236
print()
232237

233-
return global_interpreter
238+
if global_interpreter.interactive:
239+
await global_interpreter.async_chat()
234240

235241

236242
def parse_args():
@@ -297,17 +303,10 @@ def main():
297303
global_interpreter.server()
298304
return
299305

300-
# Run async portion
301-
interpreter = asyncio.run(async_main(args))
302-
# If we got an interpreter back and it's interactive, start chat in sync context
303-
if interpreter and interpreter.interactive:
304-
interpreter.chat()
305-
306+
asyncio.run(async_main(args))
306307
except KeyboardInterrupt:
307-
print("KeyboardInterrupt")
308308
sys.exit(0)
309309
except asyncio.CancelledError:
310-
print("CancelledError")
311310
sys.exit(0)
312311

313312

interpreter_1/interpreter.py

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from readchar import readchar
1414

15-
from .misc.get_input import get_input
15+
from .misc.get_input import async_get_input
1616

1717
# Third-party imports
1818
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
@@ -227,11 +227,14 @@ def default_system_message(self):
227227

228228
return system_message
229229

230-
async def async_respond(self):
230+
async def async_respond(self, user_input=None):
231231
"""
232232
Agentic sampling loop for the assistant/tool interaction.
233233
Yields chunks and maintains message history on the interpreter instance.
234234
"""
235+
if user_input:
236+
self.messages.append({"role": "user", "content": user_input})
237+
235238
tools = []
236239
if "interpreter" in self.tools:
237240
tools.append(BashTool())
@@ -925,14 +928,26 @@ def _handle_command(self, cmd: str, parts: list[str]) -> bool:
925928
return self._command_handler.handle_command(cmd, parts)
926929

927930
def chat(self):
928-
"""
929-
Interactive mode
930-
"""
931+
"""Chat with the interpreter. Handles both sync and async contexts."""
932+
try:
933+
loop = asyncio.get_running_loop()
934+
# If we get here, there is a running event loop
935+
loop.create_task(self.async_chat())
936+
except RuntimeError:
937+
# No running event loop, create one
938+
asyncio.run(self.async_chat())
939+
940+
async def async_chat(self):
941+
original_message_length = len(self.messages)
942+
931943
try:
932944
message_count = 0
933945
while True:
934-
user_input = get_input()
935-
print("")
946+
try:
947+
user_input = await async_get_input()
948+
except KeyboardInterrupt:
949+
print()
950+
return self.messages[original_message_length:]
936951

937952
message_count += 1 # Increment counter after each message
938953

@@ -943,23 +958,23 @@ def chat(self):
943958
continue
944959

945960
if user_input == "":
946-
if message_count in range(4, 7):
961+
if message_count in range(8, 11):
947962
print("Error: Cat is asleep on Enter key\n")
948963
else:
949964
print("Error: No input provided\n")
950965
continue
951966

952-
self.messages.append({"role": "user", "content": user_input})
953-
954-
for _ in self.respond():
955-
pass
967+
try:
968+
print()
969+
async for _ in self.async_respond(user_input):
970+
pass
971+
except KeyboardInterrupt:
972+
self._spinner.stop()
973+
except asyncio.CancelledError:
974+
self._spinner.stop()
956975

957976
print()
958-
except KeyboardInterrupt:
959-
self._spinner.stop()
960-
print()
961-
pass
962-
except Exception as e:
977+
except:
963978
self._spinner.stop()
964979
print(traceback.format_exc())
965980
print("\n\n\033[91mAn error has occurred.\033[0m")
@@ -976,35 +991,34 @@ def chat(self):
976991
self._report_error("".join(traceback.format_exc()))
977992
exit(1)
978993

979-
async def _consume_generator(self, generator):
980-
"""Consume the async generator from async_respond"""
981-
async for chunk in generator:
982-
yield chunk
994+
def respond(self, user_input=None, stream=False):
995+
"""Sync method to respond to user input if provided, or to the messages in self.messages."""
996+
if user_input:
997+
self.messages.append({"role": "user", "content": user_input})
983998

984-
def respond(self):
985-
"""
986-
Synchronous wrapper around async_respond.
987-
Yields chunks from the async generator.
988-
"""
999+
if stream:
1000+
return self._sync_respond_stream()
1001+
else:
1002+
original_message_length = len(self.messages)
1003+
for _ in self._sync_respond_stream():
1004+
pass
1005+
return self.messages[original_message_length:]
1006+
1007+
def _sync_respond_stream(self):
1008+
"""Synchronous generator that yields responses. Only use in synchronous contexts."""
1009+
loop = asyncio.new_event_loop()
1010+
asyncio.set_event_loop(loop)
9891011
try:
990-
loop = asyncio.get_event_loop()
991-
except RuntimeError:
992-
loop = asyncio.new_event_loop()
993-
asyncio.set_event_loop(loop)
994-
995-
async def run():
996-
async for chunk in self.async_respond():
997-
yield chunk
998-
999-
agen = run()
1000-
while True:
1001-
try:
1002-
chunk = loop.run_until_complete(anext(agen))
1003-
yield chunk
1004-
except StopAsyncIteration:
1005-
break
1006-
1007-
return self.messages
1012+
# Convert async generator to sync generator
1013+
async_gen = self.async_respond()
1014+
while True:
1015+
try:
1016+
chunk = loop.run_until_complete(async_gen.__anext__())
1017+
yield chunk
1018+
except StopAsyncIteration:
1019+
break
1020+
finally:
1021+
loop.close()
10081022

10091023
def server(self):
10101024
"""

interpreter_1/misc/get_input.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def _(event):
5151
include_default_pygments_style=False,
5252
input_processors=[],
5353
enable_system_prompt=False,
54-
wrap_lines=False,
5554
)
5655
return result
5756

@@ -101,6 +100,5 @@ def _(event):
101100
include_default_pygments_style=False,
102101
input_processors=[],
103102
enable_system_prompt=False,
104-
wrap_lines=False,
105103
)
106104
return result

interpreter_1/server.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,25 +56,7 @@ async def chat_completion(self, request: Request):
5656
self._stream_response(), media_type="text/event-stream"
5757
)
5858

59-
# For non-streaming, collect all chunks
60-
response_text = ""
61-
for chunk in self.interpreter.respond():
62-
if chunk.get("type") == "chunk":
63-
response_text += chunk["chunk"]
64-
65-
return {
66-
"id": "chatcmpl-" + str(time.time()),
67-
"object": "chat.completion",
68-
"created": int(time.time()),
69-
"model": req.model or self.interpreter.model,
70-
"choices": [
71-
{
72-
"index": 0,
73-
"message": {"role": "assistant", "content": response_text},
74-
"finish_reason": "stop",
75-
}
76-
],
77-
}
59+
raise NotImplementedError("Non-streaming is not supported yet")
7860

7961
async def _stream_response(self):
8062
"""Stream the response in OpenAI-compatible format"""

interpreter_1/tools/bash.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import os
33
import shutil
4+
import traceback
45
from typing import ClassVar, Literal
56

67
import pyte
@@ -91,9 +92,17 @@ async def run(self, command: str):
9192
final_output = self._get_screen_text()
9293
return CLIResult(output=final_output if final_output else "<No output>")
9394

94-
except (KeyboardInterrupt, asyncio.CancelledError):
95+
except KeyboardInterrupt as e:
9596
self.stop()
9697
return CLIResult(output="Command cancelled by user.")
98+
except asyncio.CancelledError as e:
99+
self.stop()
100+
return CLIResult(output="Command cancelled by user.")
101+
except Exception as e:
102+
print("Unexpected error")
103+
traceback.print_exc()
104+
self.stop()
105+
return CLIResult(output=f"Command failed with error: {e}")
97106

98107

99108
class BashTool(BaseAnthropicTool):

0 commit comments

Comments
 (0)