Skip to content

Commit c7dd11f

Browse files
committed
Websocket server test
1 parent 9b38776 commit c7dd11f

File tree

2 files changed

+164
-27
lines changed

2 files changed

+164
-27
lines changed

interpreter/core/server.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
import traceback
2424
from typing import Any, Dict, List
2525

26-
import uvicorn
2726
from fastapi import FastAPI, Header, WebSocket
2827
from fastapi.middleware.cors import CORSMiddleware
2928
from openai import OpenAI
3029
from pydantic import BaseModel
30+
from uvicorn import Config, Server
3131

3232
# import argparse
3333
# from profiles.default import interpreter
@@ -63,8 +63,6 @@ def __init__(self, interpreter):
6363
# engine = OpenAIEngine()
6464
# self.tts = TextToAudioStream(engine)
6565

66-
self.active_chat_messages = []
67-
6866
# Clock
6967
# clock()
7068

@@ -82,7 +80,9 @@ def __init__(self, interpreter):
8280
False # Tracks whether interpreter is trying to use the keyboard
8381
)
8482

85-
self.loop = asyncio.get_event_loop()
83+
# print("oksskk")
84+
# self.loop = asyncio.get_event_loop()
85+
# print("okkk")
8686

8787
async def _add_to_queue(self, queue, item):
8888
print(f"Adding item to output", item)
@@ -134,7 +134,6 @@ async def run(self):
134134
Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue.
135135
"""
136136
print("heyyyy")
137-
self.interpreter.messages = self.active_chat_messages
138137
# interpreter.messages = self.active_chat_messages
139138
# self.beeper.start()
140139

@@ -147,10 +146,8 @@ async def run(self):
147146

148147
def generate(message):
149148
last_lmc_start_flag = self._last_lmc_start_flag
150-
self.interpreter.messages = self.active_chat_messages
151149
# interpreter.messages = self.active_chat_messages
152150
print("🍀🍀🍀🍀GENERATING, using these messages: ", self.interpreter.messages)
153-
print("🍀 🍀 🍀 🍀 active_chat_messages: ", self.active_chat_messages)
154151
print("passing this in:", message)
155152
for chunk in self.interpreter.chat(message, display=False, stream=True):
156153
print("FROM INTERPRETER. CHUNK:", chunk)
@@ -165,7 +162,10 @@ def generate(message):
165162

166163
# Handle message blocks
167164
if chunk.get("type") == "message":
168-
self.add_to_output_queue_sync(chunk) # To send text, not just audio
165+
self.add_to_output_queue_sync(
166+
chunk.copy()
167+
) # To send text, not just audio
168+
# ^^^^^^^ MUST be a copy, otherwise the first chunk will get modified by OI >>while<< it's in the queue. Insane
169169
if content:
170170
# self.beeper.stop()
171171

@@ -216,8 +216,7 @@ async def output(self):
216216

217217

218218
def server(interpreter):
219-
interpreter.llm.model = "gpt-4"
220-
interpreter = AsyncInterpreter(interpreter)
219+
async_interpreter = AsyncInterpreter(interpreter)
221220

222221
app = FastAPI()
223222
app.add_middleware(
@@ -228,18 +227,12 @@ def server(interpreter):
228227
allow_headers=["*"], # Allow all headers
229228
)
230229

231-
@app.post("/load")
232-
async def load(messages: List[Dict[str, Any]], settings: Settings):
233-
# Load messages
234-
interpreter.interpreter.messages = messages
235-
print("🪼🪼🪼🪼🪼🪼 Messages loaded: ", interpreter.interpreter.messages)
236-
237-
# Load Settings
238-
interpreter.interpreter.llm.model = settings.model
239-
interpreter.interpreter.llm.custom_instructions = settings.custom_instructions
240-
interpreter.interpreter.auto_run = settings.auto_run
241-
242-
interpreter.interpreter.llm.api_key = "<openai_key>"
230+
@app.post("/settings")
231+
async def settings(payload: Dict[str, Any]):
232+
for key, value in payload.items():
233+
print("Updating interpreter settings with the following:")
234+
print(key, value)
235+
setattr(async_interpreter.interpreter, key, value)
243236

244237
return {"status": "success"}
245238

@@ -253,13 +246,16 @@ async def receive_input():
253246
data = await websocket.receive()
254247
print(data)
255248
if isinstance(data, bytes):
256-
await interpreter.input(data)
257-
else:
258-
await interpreter.input(data["text"])
249+
await async_interpreter.input(data)
250+
elif "text" in data:
251+
await async_interpreter.input(data["text"])
252+
elif data == {"type": "websocket.disconnect", "code": 1000}:
253+
print("Websocket disconnected with code 1000.")
254+
break
259255

260256
async def send_output():
261257
while True:
262-
output = await interpreter.output()
258+
output = await async_interpreter.output()
263259
if isinstance(output, bytes):
264260
# await websocket.send_bytes(output)
265261
# we dont send out bytes rn, no TTS
@@ -306,4 +302,6 @@ async def rename_chat(body_content: Rename, x_api_key: str = Header(None)):
306302
traceback.print_exc()
307303
return {"error": str(e)}
308304

309-
uvicorn.run(app, host="0.0.0.0", port=8000)
305+
config = Config(app, host="0.0.0.0", port=8000)
306+
interpreter.uvicorn_server = Server(config)
307+
interpreter.uvicorn_server.run()

tests/test_interpreter.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,145 @@
2222
from websocket import create_connection
2323

2424

25+
def test_server():
    """Integration test for the interpreter's FastAPI/uvicorn server.

    Boots the server in a background thread, then acts as a client: twice in
    a row it loads a conversation history containing a secret word via
    POST /settings, asks for the word over the /ws websocket, and asserts the
    streamed reply contains it. Two rounds with different words prove that a
    settings reload replaces the conversation rather than appending to it.

    NOTE(review): relies on module-level `interpreter`, `threading`, and
    `time` from this test file, plus the third-party `requests` and
    `websockets` packages — confirm they are in the test dependencies.
    """
    import asyncio
    import json

    import requests
    import websockets

    # uvicorn's run() blocks, so host it off the main thread.
    server_thread = threading.Thread(target=interpreter.server)
    server_thread.start()

    # Give the server a moment to start accepting connections.
    # TODO(review): poll the port with a retry loop instead of a fixed sleep.
    time.sleep(8)

    async def ask_secret_word(websocket, secret_word):
        """Load a history that states `secret_word`, ask for it, and return
        the text content accumulated before the server's DONE sentinel.

        Resets the accumulator per call so each round's assertion can only be
        satisfied by that round's own output (the original version carried
        content over between rounds, weakening the second assertion).
        """
        settings = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {
                    "role": "user",
                    "type": "message",
                    "content": f"The secret word is '{secret_word}'.",
                },
                {"role": "assistant", "type": "message", "content": "Understood."},
            ],
            "custom_instructions": "",
            "auto_run": True,
        }
        response = requests.post("http://localhost:8000/settings", json=settings)
        print("POST request sent, response:", response.json())

        # LMC-style chunked user message: start flag, content, end flag.
        await websocket.send(
            json.dumps({"role": "user", "type": "message", "start": True})
        )
        await websocket.send(
            json.dumps(
                {
                    "role": "user",
                    "type": "message",
                    "content": "What's the secret word?",
                }
            )
        )
        await websocket.send(
            json.dumps({"role": "user", "type": "message", "end": True})
        )
        print("WebSocket chunks sent")

        # Accumulate streamed content until the completion sentinel arrives.
        accumulated_content = ""
        while True:
            message_data = json.loads(await websocket.recv())
            print("Received from WebSocket:", message_data)
            if message_data.get("content"):
                accumulated_content += message_data.get("content")
            if message_data == {
                "role": "server",
                "type": "completion",
                "content": "DONE",
            }:
                print("Received expected message from server")
                break
        return accumulated_content

    async def test_fastapi_server():
        async with websockets.connect("ws://localhost:8000/ws") as websocket:
            print("Connected to WebSocket")
            assert "crunk" in await ask_secret_word(websocket, "crunk")
            assert "barlony" in await ask_secret_word(websocket, "barlony")

    # asyncio.get_event_loop() outside a running loop is deprecated;
    # asyncio.run() creates and tears down a fresh loop for the client.
    asyncio.run(test_fastapi_server())

    # Ask uvicorn to shut down gracefully, then wait for the thread.
    interpreter.uvicorn_server.should_exit = True
    server_thread.join(timeout=1)
162+
163+
25164
@pytest.mark.skip(reason="Requires open-interpreter[local]")
26165
def test_localos():
27166
interpreter.computer.emit_images = False

0 commit comments

Comments
 (0)