|
3 | 3 | import uuid |
4 | 4 | import httpx |
5 | 5 |
|
6 | | -from typing import Dict, Union, Literal, List |
| 6 | +from typing import Dict, Union, Literal, Set |
7 | 7 |
|
8 | 8 | from contextlib import asynccontextmanager |
9 | 9 | from fastapi import FastAPI |
|
16 | 16 | from contexts import create_context, normalize_language |
17 | 17 | from messaging import ContextWebSocket |
18 | 18 | from stream import StreamingListJsonResponse |
19 | | - |
| 19 | +from utils.locks import LockedMap |
20 | 20 |
|
21 | 21 | logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) |
22 | 22 | logger = logging.Logger(__name__) |
|
25 | 25 |
|
26 | 26 |
|
27 | 27 | websockets: Dict[Union[str, Literal["default"]], ContextWebSocket] = {} |
28 | | -default_websockets: Dict[str, str] = {} |
| 28 | +default_websockets = LockedMap() |
29 | 29 | global client |
30 | 30 |
|
31 | 31 |
|
@@ -84,11 +84,16 @@ async def post_execute(request: ExecutionRequest): |
84 | 84 | context_id = None |
85 | 85 | if request.language: |
86 | 86 | language = normalize_language(request.language) |
87 | | - context_id = default_websockets.get(language) |
88 | 87 |
|
89 | | - if not context_id: |
90 | | - context = await create_context(client, websockets, language, "/home/user") |
91 | | - context_id = context.id |
| 88 | + async with await default_websockets.get_lock(language): |
| 89 | + context_id = default_websockets.get(language) |
| 90 | + |
| 91 | + if not context_id: |
| 92 | + context = await create_context( |
| 93 | + client, websockets, language, "/home/user" |
| 94 | + ) |
| 95 | + context_id = context.id |
| 96 | + default_websockets[language] = context_id |
92 | 97 |
|
93 | 98 | elif request.context_id: |
94 | 99 | context_id = request.context_id |
@@ -120,19 +125,19 @@ async def post_contexts(request: CreateContext) -> Context: |
120 | 125 |
|
121 | 126 |
|
122 | 127 | @app.get("/contexts") |
123 | | -async def get_contexts() -> List[Context]: |
| 128 | +async def get_contexts() -> Set[Context]: |
124 | 129 | logger.info(f"Listing contexts") |
125 | 130 |
|
126 | | - context_ids = list(websockets.keys()) |
| 131 | + context_ids = websockets.keys() |
127 | 132 |
|
128 | | - return [ |
| 133 | + return set( |
129 | 134 | Context( |
130 | 135 | id=websockets[context_id].context_id, |
131 | 136 | language=websockets[context_id].language, |
132 | 137 | cwd=websockets[context_id].cwd, |
133 | 138 | ) |
134 | 139 | for context_id in context_ids |
135 | | - ] |
| 140 | + ) |
136 | 141 |
|
137 | 142 |
|
138 | 143 | @app.post("/contexts/{context_id}/restart") |
|
0 commit comments