Skip to content

Commit d3cab53

Browse files
committed
Save default websocket for the language
1 parent 4e9de80 commit d3cab53

File tree

3 files changed

+49
-11
lines changed

3 files changed

+49
-11
lines changed

template/server/api/models/context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ class Context(BaseModel):
66
id: StrictStr = Field(description="Context ID")
77
language: StrictStr = Field(description="Language of the context")
88
cwd: StrictStr = Field(description="Current working directory of the context")
9+
10+
def __hash__(self):
11+
return hash(self.id)

template/server/main.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import uuid
44
import httpx
55

6-
from typing import Dict, Union, Literal, List
6+
from typing import Dict, Union, Literal, Set
77

88
from contextlib import asynccontextmanager
99
from fastapi import FastAPI
@@ -16,7 +16,7 @@
1616
from contexts import create_context, normalize_language
1717
from messaging import ContextWebSocket
1818
from stream import StreamingListJsonResponse
19-
19+
from utils.locks import LockedMap
2020

2121
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
2222
logger = logging.Logger(__name__)
@@ -25,7 +25,7 @@
2525

2626

2727
websockets: Dict[Union[str, Literal["default"]], ContextWebSocket] = {}
28-
default_websockets: Dict[str, str] = {}
28+
default_websockets = LockedMap()
2929
global client
3030

3131

@@ -84,11 +84,16 @@ async def post_execute(request: ExecutionRequest):
8484
context_id = None
8585
if request.language:
8686
language = normalize_language(request.language)
87-
context_id = default_websockets.get(language)
8887

89-
if not context_id:
90-
context = await create_context(client, websockets, language, "/home/user")
91-
context_id = context.id
88+
async with await default_websockets.get_lock(language):
89+
context_id = default_websockets.get(language)
90+
91+
if not context_id:
92+
context = await create_context(
93+
client, websockets, language, "/home/user"
94+
)
95+
context_id = context.id
96+
default_websockets[language] = context_id
9297

9398
elif request.context_id:
9499
context_id = request.context_id
@@ -120,19 +125,19 @@ async def post_contexts(request: CreateContext) -> Context:
120125

121126

122127
@app.get("/contexts")
123-
async def get_contexts() -> List[Context]:
128+
async def get_contexts() -> Set[Context]:
124129
logger.info(f"Listing contexts")
125130

126-
context_ids = list(websockets.keys())
131+
context_ids = websockets.keys()
127132

128-
return [
133+
return set(
129134
Context(
130135
id=websockets[context_id].context_id,
131136
language=websockets[context_id].language,
132137
cwd=websockets[context_id].cwd,
133138
)
134139
for context_id in context_ids
135-
]
140+
)
136141

137142

138143
@app.post("/contexts/{context_id}/restart")

template/server/utils/locks.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import asyncio
2+
3+
4+
class LockedMap:
5+
def __init__(self):
6+
self.map_lock = asyncio.Lock()
7+
self.map = {}
8+
self.locks = {}
9+
10+
def get(self, key):
11+
return self.map.get(key)
12+
13+
def set(self, key, value):
14+
self.map[key] = value
15+
16+
async def get_lock(self, key):
17+
await self.map_lock.acquire()
18+
if key not in self.locks:
19+
self.locks[key] = asyncio.Lock()
20+
21+
lock = self.locks[key]
22+
print(f"Lock acquired for {key}")
23+
self.map_lock.release()
24+
return lock
25+
26+
def __getitem__(self, key):
27+
return self.get(key)
28+
29+
def __setitem__(self, key, value):
30+
self.set(key, value)

0 commit comments

Comments
 (0)