17
17
try :
18
18
import janus
19
19
import uvicorn
20
- from fastapi import APIRouter , FastAPI , File , Form , UploadFile , WebSocket
21
- from fastapi .responses import PlainTextResponse , StreamingResponse
20
+ from fastapi import (
21
+ APIRouter ,
22
+ FastAPI ,
23
+ File ,
24
+ Form ,
25
+ HTTPException ,
26
+ Request ,
27
+ UploadFile ,
28
+ WebSocket ,
29
+ )
30
+ from fastapi .responses import JSONResponse , PlainTextResponse , StreamingResponse
31
+ from starlette .status import HTTP_403_FORBIDDEN
22
32
except :
23
33
# Server dependencies are not required by the main package.
24
34
pass
@@ -204,6 +214,24 @@ def accumulate(self, chunk):
204
214
self .messages [- 1 ]["content" ] += chunk
205
215
206
216
217
+ def authenticate_function (key ):
218
+ """
219
+ This function checks if the provided key is valid for authentication.
220
+
221
+ Returns True if the key is valid, False otherwise.
222
+ """
223
+ # Fetch the API key from the environment variables. If it's not set, return True.
224
+ api_key = os .getenv ("INTERPRETER_API_KEY" , None )
225
+
226
+ # If the API key is not set in the environment variables, return True.
227
+ # Otherwise, check if the provided key matches the fetched API key.
228
+ # Return True if they match, False otherwise.
229
+ if api_key is None :
230
+ return True
231
+ else :
232
+ return key == api_key
233
+
234
+
207
235
def create_router (async_interpreter ):
208
236
router = APIRouter ()
209
237
@@ -226,6 +254,7 @@ async def home():
226
254
<button>Send</button>
227
255
</form>
228
256
<button id="approveCodeButton">Approve Code</button>
257
+ <button id="authButton">Send Auth</button>
229
258
<div id="messages"></div>
230
259
<script>
231
260
var ws = new WebSocket("ws://"""
@@ -234,6 +263,7 @@ async def home():
234
263
+ str (async_interpreter .server .port )
235
264
+ """/");
236
265
var lastMessageElement = null;
266
+
237
267
ws.onmessage = function(event) {
238
268
239
269
var eventData = JSON.parse(event.data);
@@ -326,8 +356,15 @@ async def home():
326
356
};
327
357
ws.send(JSON.stringify(endCommandBlock));
328
358
}
359
+ function authenticate() {
360
+ var authBlock = {
361
+ "auth": "dummy-api-key"
362
+ };
363
+ ws.send(JSON.stringify(authBlock));
364
+ }
329
365
330
366
document.getElementById("approveCodeButton").addEventListener("click", approveCode);
367
+ document.getElementById("authButton").addEventListener("click", authenticate);
331
368
</script>
332
369
</body>
333
370
</html>
@@ -338,13 +375,30 @@ async def home():
338
375
@router .websocket ("/" )
339
376
async def websocket_endpoint (websocket : WebSocket ):
340
377
await websocket .accept ()
378
+
341
379
try :
342
380
343
381
async def receive_input ():
382
+ authenticated = False
344
383
while True :
345
384
try :
346
385
data = await websocket .receive ()
347
386
387
+ if not authenticated :
388
+ if "text" in data :
389
+ data = json .loads (data ["text" ])
390
+ if "auth" in data :
391
+ if async_interpreter .server .authenticate (
392
+ data ["auth" ]
393
+ ):
394
+ authenticated = True
395
+ await websocket .send_text (
396
+ json .dumps ({"auth" : True })
397
+ )
398
+ if not authenticated :
399
+ await websocket .send_text (json .dumps ({"auth" : False }))
400
+ continue
401
+
348
402
if data .get ("type" ) == "websocket.receive" :
349
403
if "text" in data :
350
404
data = json .loads (data ["text" ])
@@ -474,19 +528,6 @@ async def post_input(payload: Dict[str, Any]):
474
528
except Exception as e :
475
529
return {"error" : str (e )}, 500
476
530
477
- @router .post ("/run" )
478
- async def run_code (payload : Dict [str , Any ]):
479
- language , code = payload .get ("language" ), payload .get ("code" )
480
- if not (language and code ):
481
- return {"error" : "Both 'language' and 'code' are required." }, 400
482
- try :
483
- print (f"Running { language } :" , code )
484
- output = async_interpreter .computer .run (language , code )
485
- print ("Output:" , output )
486
- return {"output" : output }
487
- except Exception as e :
488
- return {"error" : str (e )}, 500
489
-
490
531
@router .post ("/settings" )
491
532
async def set_settings (payload : Dict [str , Any ]):
492
533
for key , value in payload .items ():
@@ -520,23 +561,38 @@ async def get_setting(setting: str):
520
561
else :
521
562
return json .dumps ({"error" : "Setting not found" }), 404
522
563
523
- @router .post ("/upload" )
524
- async def upload_file (file : UploadFile = File (...), path : str = Form (...)):
525
- try :
526
- with open (path , "wb" ) as output_file :
527
- shutil .copyfileobj (file .file , output_file )
528
- return {"status" : "success" }
529
- except Exception as e :
530
- return {"error" : str (e )}, 500
564
+ if os .getenv ("INTERPRETER_INSECURE_ROUTES" , "" ).lower () == "true" :
531
565
532
- @router .get ("/download/{filename}" )
533
- async def download_file (filename : str ):
534
- try :
535
- return StreamingResponse (
536
- open (filename , "rb" ), media_type = "application/octet-stream"
537
- )
538
- except Exception as e :
539
- return {"error" : str (e )}, 500
566
+ @router .post ("/run" )
567
+ async def run_code (payload : Dict [str , Any ]):
568
+ language , code = payload .get ("language" ), payload .get ("code" )
569
+ if not (language and code ):
570
+ return {"error" : "Both 'language' and 'code' are required." }, 400
571
+ try :
572
+ print (f"Running { language } :" , code )
573
+ output = async_interpreter .computer .run (language , code )
574
+ print ("Output:" , output )
575
+ return {"output" : output }
576
+ except Exception as e :
577
+ return {"error" : str (e )}, 500
578
+
579
+ @router .post ("/upload" )
580
+ async def upload_file (file : UploadFile = File (...), path : str = Form (...)):
581
+ try :
582
+ with open (path , "wb" ) as output_file :
583
+ shutil .copyfileobj (file .file , output_file )
584
+ return {"status" : "success" }
585
+ except Exception as e :
586
+ return {"error" : str (e )}, 500
587
+
588
+ @router .get ("/download/{filename}" )
589
+ async def download_file (filename : str ):
590
+ try :
591
+ return StreamingResponse (
592
+ open (filename , "rb" ), media_type = "application/octet-stream"
593
+ )
594
+ except Exception as e :
595
+ return {"error" : str (e )}, 500
540
596
541
597
### OPENAI COMPATIBLE ENDPOINT
542
598
@@ -648,6 +704,21 @@ class Server:
648
704
def __init__ (self , async_interpreter , host = "127.0.0.1" , port = 8000 ):
649
705
self .app = FastAPI ()
650
706
router = create_router (async_interpreter )
707
+ self .authenticate = authenticate_function
708
+
709
+ # Add authentication middleware
710
+ @self .app .middleware ("http" )
711
+ async def validate_api_key (request : Request , call_next ):
712
+ api_key = request .headers .get ("X-API-KEY" )
713
+ if self .authenticate (api_key ):
714
+ response = await call_next (request )
715
+ return response
716
+ else :
717
+ return JSONResponse (
718
+ status_code = HTTP_403_FORBIDDEN ,
719
+ content = {"detail" : "Authentication failed" },
720
+ )
721
+
651
722
self .app .include_router (router )
652
723
self .config = uvicorn .Config (app = self .app , host = host , port = port )
653
724
self .uvicorn_server = uvicorn .Server (self .config )
0 commit comments