12
12
13
13
import shortuuid
14
14
from pydantic import BaseModel
15
+ from starlette .websockets import WebSocketState
15
16
16
17
from .core import OpenInterpreter
17
18
@@ -387,12 +388,14 @@ async def home():
387
388
async def websocket_endpoint (websocket : WebSocket ):
388
389
await websocket .accept ()
389
390
390
- try :
391
+ try : # solving it ;)/ # killian super wrote this
391
392
392
393
async def receive_input ():
393
394
authenticated = False
394
395
while True :
395
396
try :
397
+ if websocket .client_state != WebSocketState .CONNECTED :
398
+ return
396
399
data = await websocket .receive ()
397
400
398
401
if not authenticated :
@@ -425,7 +428,7 @@ async def receive_input():
425
428
data = data ["bytes" ]
426
429
await async_interpreter .input (data )
427
430
elif data .get ("type" ) == "websocket.disconnect" :
428
- print ("Disconnecting ." )
431
+ print ("Client wants to disconnect, that's fine. ." )
429
432
return
430
433
else :
431
434
print ("Invalid data:" , data )
@@ -446,6 +449,8 @@ async def receive_input():
446
449
447
450
async def send_output ():
448
451
while True :
452
+ if websocket .client_state != WebSocketState .CONNECTED :
453
+ return
449
454
try :
450
455
# First, try to send any unsent messages
451
456
while async_interpreter .unsent_messages :
@@ -488,9 +493,12 @@ async def send_message(output):
488
493
):
489
494
output ["id" ] = id
490
495
491
- for attempt in range (100 ):
492
- if websocket .client_state == 3 : # 3 represents 'CLOSED' state
496
+ for attempt in range (20 ):
497
+ # time.sleep(0.5)
498
+
499
+ if websocket .client_state != WebSocketState .CONNECTED :
493
500
break
501
+
494
502
try :
495
503
if isinstance (output , bytes ):
496
504
await websocket .send_bytes (output )
@@ -501,7 +509,7 @@ async def send_message(output):
501
509
502
510
if async_interpreter .require_acknowledge :
503
511
acknowledged = False
504
- for _ in range (1000 ):
512
+ for _ in range (100 ):
505
513
if id in async_interpreter .acknowledged_outputs :
506
514
async_interpreter .acknowledged_outputs .remove (id )
507
515
acknowledged = True
@@ -523,10 +531,13 @@ async def send_message(output):
523
531
await asyncio .sleep (0.05 )
524
532
525
533
# If we've reached this point, we've failed to send after 100 attempts
526
- async_interpreter .unsent_messages .append (output )
527
- print (
528
- f"Added message to unsent_messages queue after failed attempts: { output } "
529
- )
534
+ if output not in async_interpreter .unsent_messages :
535
+ async_interpreter .unsent_messages .append (output )
536
+ print (
537
+ f"Added message to unsent_messages queue after failed attempts: { output } "
538
+ )
539
+ else :
540
+ print ("Why was this already in unsent_messages?" , output )
530
541
531
542
await asyncio .gather (receive_input (), send_output ())
532
543
@@ -731,6 +742,10 @@ def __init__(self, async_interpreter, host=None, port=None):
731
742
# Add authentication middleware
732
743
@self .app .middleware ("http" )
733
744
async def validate_api_key (request : Request , call_next ):
745
+ # Ignore authentication for the /heartbeat route
746
+ if request .url .path == "/heartbeat" :
747
+ return await call_next (request )
748
+
734
749
api_key = request .headers .get ("X-API-KEY" )
735
750
if self .authenticate (api_key ):
736
751
response = await call_next (request )
0 commit comments