Skip to content

Commit 398ae53

Browse files
committed
Support file uploading
1 parent c3c61b0 commit 398ae53

File tree

2 files changed

+161
-11
lines changed

2 files changed

+161
-11
lines changed

interpreter/core/async_core.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import os
4+
import shutil
45
import socket
56
import threading
67
import time
@@ -16,7 +17,7 @@
1617
try:
1718
import janus
1819
import uvicorn
19-
from fastapi import APIRouter, FastAPI, WebSocket
20+
from fastapi import APIRouter, FastAPI, File, Form, UploadFile, WebSocket
2021
from fastapi.responses import PlainTextResponse, StreamingResponse
2122
except:
2223
# Server dependencies are not required by the main package.
@@ -60,7 +61,7 @@ async def input(self, chunk):
6061
run_code = None # Will later default to auto_run unless the user makes a command here
6162

6263
# But first, process any commands.
63-
if self.messages[-1]["type"] == "command":
64+
if self.messages[-1].get("type") == "command":
6465
command = self.messages[-1]["content"]
6566
self.messages = self.messages[:-1]
6667

@@ -145,20 +146,61 @@ def accumulate(self, chunk):
145146
# We don't do anything with these.
146147
pass
147148

148-
elif (
149-
"start" in chunk
150-
or chunk["type"] != self.messages[-1]["type"]
151-
or chunk.get("format") != self.messages[-1].get("format")
149+
elif "content" in chunk and not (
150+
len(self.messages) > 0
151+
and (
152+
(
153+
"type" in self.messages[-1]
154+
and chunk.get("type") != self.messages[-1].get("type")
155+
)
156+
or (
157+
"format" in self.messages[-1]
158+
and chunk.get("format") != self.messages[-1].get("format")
159+
)
160+
)
161+
):
162+
if len(self.messages) == 0:
163+
raise Exception(
164+
"You must send a 'start: True' chunk first to create this message."
165+
)
166+
# Append to an existing message
167+
if (
168+
"type" not in self.messages[-1]
169+
): # It was created with a type-less start message
170+
self.messages[-1]["type"] = chunk["type"]
171+
if (
172+
chunk.get("format") and "format" not in self.messages[-1]
173+
): # It was created with a type-less start message
174+
self.messages[-1]["format"] = chunk["format"]
175+
if "content" not in self.messages[-1]:
176+
self.messages[-1]["content"] = chunk["content"]
177+
else:
178+
self.messages[-1]["content"] += chunk["content"]
179+
180+
# elif "content" in chunk and (len(self.messages) > 0 and self.messages[-1] == {'role': 'user', 'start': True}):
181+
# # Last message was {'role': 'user', 'start': True}. Just populate that with this chunk
182+
# self.messages[-1] = chunk.copy()
183+
184+
elif "start" in chunk or (
185+
len(self.messages) > 0
186+
and (
187+
chunk.get("type") != self.messages[-1].get("type")
188+
or chunk.get("format") != self.messages[-1].get("format")
189+
)
152190
):
191+
# Create a new message
153192
chunk_copy = (
154193
chunk.copy()
155194
) # So we don't modify the original chunk, which feels wrong.
156-
chunk_copy.pop("start")
157-
chunk_copy["content"] = ""
195+
if "start" in chunk_copy:
196+
chunk_copy.pop("start")
197+
if "content" not in chunk_copy:
198+
chunk_copy["content"] = ""
158199
self.messages.append(chunk_copy)
159200

160-
elif "content" in chunk:
161-
self.messages[-1]["content"] += chunk["content"]
201+
print("ADDED CHUNK:", chunk)
202+
print("MESSAGES IS NOW:", self.messages)
203+
# time.sleep(5)
162204

163205
elif type(chunk) == bytes:
164206
if self.messages[-1]["content"] == "": # We initialize as an empty string ^
@@ -482,6 +524,24 @@ async def get_setting(setting: str):
482524
else:
483525
return json.dumps({"error": "Setting not found"}), 404
484526

527+
@router.post("/upload")
528+
async def upload_file(file: UploadFile = File(...), path: str = Form(...)):
529+
try:
530+
with open(path, "wb") as output_file:
531+
shutil.copyfileobj(file.file, output_file)
532+
return {"status": "success"}
533+
except Exception as e:
534+
return {"error": str(e)}, 500
535+
536+
@router.get("/download/{filename}")
537+
async def download_file(filename: str):
538+
try:
539+
return StreamingResponse(
540+
open(filename, "rb"), media_type="application/octet-stream"
541+
)
542+
except Exception as e:
543+
return {"error": str(e)}, 500
544+
485545
### OPENAI COMPATIBLE ENDPOINT
486546

487547
class ChatMessage(BaseModel):

tests/test_interpreter.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,10 @@ async def test_fastapi_server():
8888
while True:
8989
message = await websocket.recv()
9090
message_data = json.loads(message)
91+
if "error" in message_data:
92+
raise Exception(message_data["content"])
9193
print("Received from WebSocket:", message_data)
92-
if message_data.get("content"):
94+
if type(message_data.get("content")) == str:
9395
accumulated_content += message_data.get("content")
9496
if message_data == {
9597
"role": "server",
@@ -142,6 +144,8 @@ async def test_fastapi_server():
142144
while True:
143145
message = await websocket.recv()
144146
message_data = json.loads(message)
147+
if "error" in message_data:
148+
raise Exception(message_data["content"])
145149
print("Received from WebSocket:", message_data)
146150
if message_data.get("content"):
147151
accumulated_content += message_data.get("content")
@@ -189,6 +193,8 @@ async def test_fastapi_server():
189193
while True:
190194
message = await websocket.recv()
191195
message_data = json.loads(message)
196+
if "error" in message_data:
197+
raise Exception(message_data["content"])
192198
print("Received from WebSocket:", message_data)
193199
if message_data.get("content"):
194200
accumulated_content += message_data.get("content")
@@ -237,6 +243,8 @@ async def test_fastapi_server():
237243
while True:
238244
message = await websocket.recv()
239245
message_data = json.loads(message)
246+
if "error" in message_data:
247+
raise Exception(message_data["content"])
240248
print("Received from WebSocket:", message_data)
241249
if message_data.get("content"):
242250
if type(message_data.get("content")) == str:
@@ -251,6 +259,88 @@ async def test_fastapi_server():
251259

252260
assert "18893094989" in accumulated_content.replace(",", "")
253261

262+
#### TEST FILE ####
263+
264+
# Send another POST request
265+
post_url = "http://localhost:8000/settings"
266+
settings = {"messages": [], "auto_run": True}
267+
response = requests.post(post_url, json=settings)
268+
print("POST request sent, response:", response.json())
269+
270+
# Sending messages via WebSocket
271+
await websocket.send(json.dumps({"role": "user", "start": True}))
272+
print("sent", json.dumps({"role": "user", "start": True}))
273+
await websocket.send(
274+
json.dumps(
275+
{
276+
"role": "user",
277+
"type": "message",
278+
"content": "Does this file exist?",
279+
}
280+
)
281+
)
282+
print(
283+
"sent",
284+
{
285+
"role": "user",
286+
"type": "message",
287+
"content": "Does this file exist?",
288+
},
289+
)
290+
await websocket.send(
291+
json.dumps(
292+
{
293+
"role": "user",
294+
"type": "file",
295+
"format": "path",
296+
"content": "/something.txt",
297+
}
298+
)
299+
)
300+
print(
301+
"sent",
302+
{
303+
"role": "user",
304+
"type": "file",
305+
"format": "path",
306+
"content": "/something.txt",
307+
},
308+
)
309+
await websocket.send(json.dumps({"role": "user", "end": True}))
310+
print("WebSocket chunks sent")
311+
312+
# Wait for response
313+
accumulated_content = ""
314+
while True:
315+
message = await websocket.recv()
316+
message_data = json.loads(message)
317+
if "error" in message_data:
318+
raise Exception(message_data["content"])
319+
print("Received from WebSocket:", message_data)
320+
if type(message_data.get("content")) == str:
321+
accumulated_content += message_data.get("content")
322+
if message_data == {
323+
"role": "server",
324+
"type": "status",
325+
"content": "complete",
326+
}:
327+
print("Received expected message from server")
328+
break
329+
330+
# Get messages
331+
get_url = "http://localhost:8000/settings/messages"
332+
response_json = requests.get(get_url).json()
333+
print("GET request sent, response:", response_json)
334+
if isinstance(response_json, str):
335+
response_json = json.loads(response_json)
336+
messages = response_json["messages"]
337+
338+
response = async_interpreter.computer.ai.chat(
339+
str(messages)
340+
+ "\n\nIn the conversation above, does the assistant think the file exists? Yes or no? Only reply with one word— 'yes' or 'no'."
341+
)
342+
assert response.strip(" \n.").lower() == "no"
343+
254344
# Sending POST request to /run endpoint with code to kill a thread in Python
255345
# actually wait i dont think this will work..? will just kill the python interpreter
256346
post_url = "http://localhost:8000/run"

0 commit comments

Comments
 (0)