Skip to content

Commit 14ef23a

Browse files
authored
Server authentication
Server authentication
2 parents 99557bb + 4f66916 commit 14ef23a

File tree

2 files changed

+105
-31
lines changed

2 files changed

+105
-31
lines changed

interpreter/core/async_core.py

Lines changed: 102 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,18 @@
1717
try:
1818
import janus
1919
import uvicorn
20-
from fastapi import APIRouter, FastAPI, File, Form, UploadFile, WebSocket
21-
from fastapi.responses import PlainTextResponse, StreamingResponse
20+
from fastapi import (
21+
APIRouter,
22+
FastAPI,
23+
File,
24+
Form,
25+
HTTPException,
26+
Request,
27+
UploadFile,
28+
WebSocket,
29+
)
30+
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
31+
from starlette.status import HTTP_403_FORBIDDEN
2232
except:
2333
# Server dependencies are not required by the main package.
2434
pass
@@ -204,6 +214,24 @@ def accumulate(self, chunk):
204214
self.messages[-1]["content"] += chunk
205215

206216

217+
def authenticate_function(key):
218+
"""
219+
This function checks if the provided key is valid for authentication.
220+
221+
Returns True if the key is valid, False otherwise.
222+
"""
223+
# Fetch the API key from the environment variables. If it's not set, return True.
224+
api_key = os.getenv("INTERPRETER_API_KEY", None)
225+
226+
# If the API key is not set in the environment variables, return True.
227+
# Otherwise, check if the provided key matches the fetched API key.
228+
# Return True if they match, False otherwise.
229+
if api_key is None:
230+
return True
231+
else:
232+
return key == api_key
233+
234+
207235
def create_router(async_interpreter):
208236
router = APIRouter()
209237

@@ -226,6 +254,7 @@ async def home():
226254
<button>Send</button>
227255
</form>
228256
<button id="approveCodeButton">Approve Code</button>
257+
<button id="authButton">Send Auth</button>
229258
<div id="messages"></div>
230259
<script>
231260
var ws = new WebSocket("ws://"""
@@ -234,6 +263,7 @@ async def home():
234263
+ str(async_interpreter.server.port)
235264
+ """/");
236265
var lastMessageElement = null;
266+
237267
ws.onmessage = function(event) {
238268
239269
var eventData = JSON.parse(event.data);
@@ -326,8 +356,15 @@ async def home():
326356
};
327357
ws.send(JSON.stringify(endCommandBlock));
328358
}
359+
function authenticate() {
360+
var authBlock = {
361+
"auth": "dummy-api-key"
362+
};
363+
ws.send(JSON.stringify(authBlock));
364+
}
329365
330366
document.getElementById("approveCodeButton").addEventListener("click", approveCode);
367+
document.getElementById("authButton").addEventListener("click", authenticate);
331368
</script>
332369
</body>
333370
</html>
@@ -338,13 +375,30 @@ async def home():
338375
@router.websocket("/")
339376
async def websocket_endpoint(websocket: WebSocket):
340377
await websocket.accept()
378+
341379
try:
342380

343381
async def receive_input():
382+
authenticated = False
344383
while True:
345384
try:
346385
data = await websocket.receive()
347386

387+
if not authenticated:
388+
if "text" in data:
389+
data = json.loads(data["text"])
390+
if "auth" in data:
391+
if async_interpreter.server.authenticate(
392+
data["auth"]
393+
):
394+
authenticated = True
395+
await websocket.send_text(
396+
json.dumps({"auth": True})
397+
)
398+
if not authenticated:
399+
await websocket.send_text(json.dumps({"auth": False}))
400+
continue
401+
348402
if data.get("type") == "websocket.receive":
349403
if "text" in data:
350404
data = json.loads(data["text"])
@@ -474,19 +528,6 @@ async def post_input(payload: Dict[str, Any]):
474528
except Exception as e:
475529
return {"error": str(e)}, 500
476530

477-
@router.post("/run")
478-
async def run_code(payload: Dict[str, Any]):
479-
language, code = payload.get("language"), payload.get("code")
480-
if not (language and code):
481-
return {"error": "Both 'language' and 'code' are required."}, 400
482-
try:
483-
print(f"Running {language}:", code)
484-
output = async_interpreter.computer.run(language, code)
485-
print("Output:", output)
486-
return {"output": output}
487-
except Exception as e:
488-
return {"error": str(e)}, 500
489-
490531
@router.post("/settings")
491532
async def set_settings(payload: Dict[str, Any]):
492533
for key, value in payload.items():
@@ -520,23 +561,38 @@ async def get_setting(setting: str):
520561
else:
521562
return json.dumps({"error": "Setting not found"}), 404
522563

523-
@router.post("/upload")
524-
async def upload_file(file: UploadFile = File(...), path: str = Form(...)):
525-
try:
526-
with open(path, "wb") as output_file:
527-
shutil.copyfileobj(file.file, output_file)
528-
return {"status": "success"}
529-
except Exception as e:
530-
return {"error": str(e)}, 500
564+
if os.getenv("INTERPRETER_INSECURE_ROUTES", "").lower() == "true":
531565

532-
@router.get("/download/{filename}")
533-
async def download_file(filename: str):
534-
try:
535-
return StreamingResponse(
536-
open(filename, "rb"), media_type="application/octet-stream"
537-
)
538-
except Exception as e:
539-
return {"error": str(e)}, 500
566+
@router.post("/run")
567+
async def run_code(payload: Dict[str, Any]):
568+
language, code = payload.get("language"), payload.get("code")
569+
if not (language and code):
570+
return {"error": "Both 'language' and 'code' are required."}, 400
571+
try:
572+
print(f"Running {language}:", code)
573+
output = async_interpreter.computer.run(language, code)
574+
print("Output:", output)
575+
return {"output": output}
576+
except Exception as e:
577+
return {"error": str(e)}, 500
578+
579+
@router.post("/upload")
580+
async def upload_file(file: UploadFile = File(...), path: str = Form(...)):
581+
try:
582+
with open(path, "wb") as output_file:
583+
shutil.copyfileobj(file.file, output_file)
584+
return {"status": "success"}
585+
except Exception as e:
586+
return {"error": str(e)}, 500
587+
588+
@router.get("/download/{filename}")
589+
async def download_file(filename: str):
590+
try:
591+
return StreamingResponse(
592+
open(filename, "rb"), media_type="application/octet-stream"
593+
)
594+
except Exception as e:
595+
return {"error": str(e)}, 500
540596

541597
### OPENAI COMPATIBLE ENDPOINT
542598

@@ -648,6 +704,21 @@ class Server:
648704
def __init__(self, async_interpreter, host="127.0.0.1", port=8000):
649705
self.app = FastAPI()
650706
router = create_router(async_interpreter)
707+
self.authenticate = authenticate_function
708+
709+
# Add authentication middleware
710+
@self.app.middleware("http")
711+
async def validate_api_key(request: Request, call_next):
712+
api_key = request.headers.get("X-API-KEY")
713+
if self.authenticate(api_key):
714+
response = await call_next(request)
715+
return response
716+
else:
717+
return JSONResponse(
718+
status_code=HTTP_403_FORBIDDEN,
719+
content={"detail": "Authentication failed"},
720+
)
721+
651722
self.app.include_router(router)
652723
self.config = uvicorn.Config(app=self.app, host=host, port=port)
653724
self.uvicorn_server = uvicorn.Server(self.config)

tests/test_interpreter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ async def test_fastapi_server():
5353
# Connect to the websocket
5454
print("Connected to WebSocket")
5555

56+
# Sending message via WebSocket
57+
await websocket.send(json.dumps({"auth": "dummy-api-key"}))
58+
5659
# Sending POST request
5760
post_url = "http://localhost:8000/settings"
5861
settings = {

0 commit comments

Comments
 (0)