Skip to content

Commit b5bc226

Browse files
committed
Robust server
1 parent 9bc8345 commit b5bc226

File tree

4 files changed

+177
-13
lines changed

4 files changed

+177
-13
lines changed

interpreter/core/async_core.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import shortuuid
1414
from pydantic import BaseModel
15+
from starlette.websockets import WebSocketState
1516

1617
from .core import OpenInterpreter
1718

@@ -387,12 +388,14 @@ async def home():
387388
async def websocket_endpoint(websocket: WebSocket):
388389
await websocket.accept()
389390

390-
try:
391+
try: # solving it ;)/ # killian super wrote this
391392

392393
async def receive_input():
393394
authenticated = False
394395
while True:
395396
try:
397+
if websocket.client_state != WebSocketState.CONNECTED:
398+
return
396399
data = await websocket.receive()
397400

398401
if not authenticated:
@@ -425,7 +428,7 @@ async def receive_input():
425428
data = data["bytes"]
426429
await async_interpreter.input(data)
427430
elif data.get("type") == "websocket.disconnect":
428-
print("Disconnecting.")
431+
print("Client wants to disconnect, that's fine..")
429432
return
430433
else:
431434
print("Invalid data:", data)
@@ -446,6 +449,8 @@ async def receive_input():
446449

447450
async def send_output():
448451
while True:
452+
if websocket.client_state != WebSocketState.CONNECTED:
453+
return
449454
try:
450455
# First, try to send any unsent messages
451456
while async_interpreter.unsent_messages:
@@ -488,9 +493,12 @@ async def send_message(output):
488493
):
489494
output["id"] = id
490495

491-
for attempt in range(100):
492-
if websocket.client_state == 3: # 3 represents 'CLOSED' state
496+
for attempt in range(20):
497+
# time.sleep(0.5)
498+
499+
if websocket.client_state != WebSocketState.CONNECTED:
493500
break
501+
494502
try:
495503
if isinstance(output, bytes):
496504
await websocket.send_bytes(output)
@@ -501,7 +509,7 @@ async def send_message(output):
501509

502510
if async_interpreter.require_acknowledge:
503511
acknowledged = False
504-
for _ in range(1000):
512+
for _ in range(100):
505513
if id in async_interpreter.acknowledged_outputs:
506514
async_interpreter.acknowledged_outputs.remove(id)
507515
acknowledged = True
@@ -523,10 +531,13 @@ async def send_message(output):
523531
await asyncio.sleep(0.05)
524532

525533
# If we've reached this point, we've failed to send after 100 attempts
526-
async_interpreter.unsent_messages.append(output)
527-
print(
528-
f"Added message to unsent_messages queue after failed attempts: {output}"
529-
)
534+
if output not in async_interpreter.unsent_messages:
535+
async_interpreter.unsent_messages.append(output)
536+
print(
537+
f"Added message to unsent_messages queue after failed attempts: {output}"
538+
)
539+
else:
540+
print("Why was this already in unsent_messages?", output)
530541

531542
await asyncio.gather(receive_input(), send_output())
532543

interpreter/core/respond.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def respond(interpreter):
181181
except:
182182
pass
183183

184-
if code.endswith("executeexecute"):
184+
if code.strip().endswith("executeexecute"):
185185
edited_code = code.replace("executeexecute", "")
186186
try:
187187
code_dict = json.loads(edited_code)

numbers.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.

tests/test_interpreter.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,163 @@
2424
from websocket import create_connection
2525

2626

27+
def run_auth_server():
28+
os.environ["INTERPRETER_REQUIRE_ACKNOWLEDGE"] = "True"
29+
os.environ["INTERPRETER_API_KEY"] = "testing"
30+
async_interpreter = AsyncInterpreter()
31+
async_interpreter.print = False
32+
async_interpreter.server.run()
33+
34+
35+
# @pytest.mark.skip(reason="Requires uvicorn, which we don't require by default")
36+
def test_authenticated_acknowledging_breaking_server():
37+
"""
38+
Test the server when we have authentication and acknowledging one.
39+
40+
I know this is bad, just trying to test quickly!
41+
"""
42+
43+
# Start the server in a new process
44+
45+
process = multiprocessing.Process(target=run_auth_server)
46+
process.start()
47+
48+
# Give the server a moment to start
49+
time.sleep(2)
50+
51+
import asyncio
52+
import json
53+
54+
import requests
55+
import websockets
56+
57+
async def test_fastapi_server():
58+
import asyncio
59+
60+
async with websockets.connect("ws://localhost:8000/") as websocket:
61+
# Connect to the websocket
62+
print("Connected to WebSocket")
63+
64+
# Sending message via WebSocket
65+
await websocket.send(json.dumps({"auth": "testing"}))
66+
67+
# Sending POST request
68+
post_url = "http://localhost:8000/settings"
69+
settings = {
70+
"llm": {
71+
"model": "gpt-4o",
72+
"execution_instructions": "",
73+
"supports_functions": False,
74+
},
75+
"system_message": "You are a poem writing bot. Do not do anything but respond with a poem.",
76+
"auto_run": True,
77+
}
78+
response = requests.post(
79+
post_url, json=settings, headers={"X-API-KEY": "testing"}
80+
)
81+
print("POST request sent, response:", response.json())
82+
83+
# Sending messages via WebSocket
84+
await websocket.send(
85+
json.dumps({"role": "user", "type": "message", "start": True})
86+
)
87+
await websocket.send(
88+
json.dumps(
89+
{
90+
"role": "user",
91+
"type": "message",
92+
"content": "Write a short poem about Seattle.",
93+
}
94+
)
95+
)
96+
await websocket.send(
97+
json.dumps({"role": "user", "type": "message", "end": True})
98+
)
99+
print("WebSocket chunks sent")
100+
101+
max_chunks = 5
102+
103+
poem = ""
104+
while True:
105+
max_chunks -= 1
106+
if max_chunks == 0:
107+
break
108+
message = await websocket.recv()
109+
message_data = json.loads(message)
110+
if "id" in message_data:
111+
await websocket.send(json.dumps({"ack": message_data["id"]}))
112+
if "error" in message_data:
113+
raise Exception(str(message_data))
114+
print("Received from WebSocket:", message_data)
115+
if type(message_data.get("content")) == str:
116+
poem += message_data.get("content")
117+
print(message_data.get("content"), end="", flush=True)
118+
if message_data == {
119+
"role": "server",
120+
"type": "status",
121+
"content": "complete",
122+
}:
123+
raise (
124+
Exception(
125+
"It shouldn't have finished this soon, accumulated_content is: "
126+
+ accumulated_content
127+
)
128+
)
129+
130+
await websocket.close()
131+
print("Disconnected from WebSocket")
132+
133+
time.sleep(3)
134+
135+
# Now let's hilariously keep going
136+
print("RESUMING")
137+
138+
async with websockets.connect("ws://localhost:8000/") as websocket:
139+
# Connect to the websocket
140+
print("Connected to WebSocket")
141+
142+
# Sending message via WebSocket
143+
await websocket.send(json.dumps({"auth": "testing"}))
144+
145+
while True:
146+
message = await websocket.recv()
147+
message_data = json.loads(message)
148+
if "id" in message_data:
149+
await websocket.send(json.dumps({"ack": message_data["id"]}))
150+
if "error" in message_data:
151+
raise Exception(str(message_data))
152+
print("Received from WebSocket:", message_data)
153+
message_data.pop("id", "")
154+
if message_data == {
155+
"role": "server",
156+
"type": "status",
157+
"content": "complete",
158+
}:
159+
break
160+
if type(message_data.get("content")) == str:
161+
poem += message_data.get("content")
162+
print(message_data.get("content"), end="", flush=True)
163+
164+
time.sleep(1)
165+
print("Is this a normal poem?")
166+
print(poem)
167+
time.sleep(1)
168+
169+
# Get the current event loop and run the test function
170+
loop = asyncio.get_event_loop()
171+
try:
172+
loop.run_until_complete(test_fastapi_server())
173+
finally:
174+
# Kill server process
175+
process.terminate()
176+
os.kill(process.pid, signal.SIGKILL) # Send SIGKILL signal
177+
process.join()
178+
179+
27180
def run_server():
181+
os.environ["INTERPRETER_REQUIRE_ACKNOWLEDGE"] = "False"
182+
if "INTERPRETER_API_KEY" in os.environ:
183+
del os.environ["INTERPRETER_API_KEY"]
28184
async_interpreter = AsyncInterpreter()
29185
async_interpreter.print = False
30186
async_interpreter.server.run()

0 commit comments

Comments
 (0)