|
7 | 7 | from __future__ import annotations |
8 | 8 |
|
9 | 9 | import asyncio |
| 10 | +import socket |
10 | 11 | import time |
11 | 12 | from collections.abc import AsyncIterator, Awaitable, Callable |
| 13 | +from contextlib import asynccontextmanager |
12 | 14 | from typing import Any |
| 15 | +from uuid import uuid4 |
13 | 16 |
|
14 | 17 | from rock.actions.sandbox._generated_types import SandboxInfoField |
15 | 18 | from rock.actions.sandbox.response import State |
16 | 19 | from rock.actions.sandbox.sandbox_info import SandboxInfo |
17 | | -from rock.admin.core.redis_key import ALIVE_PREFIX, alive_sandbox_key, timeout_sandbox_key |
| 20 | +from rock.admin.core.redis_key import ALIVE_PREFIX, alive_sandbox_key, lock_sandbox_key, timeout_sandbox_key |
18 | 21 | from rock.admin.core.sandbox_table import SandboxTable |
19 | 22 | from rock.logger import init_logger |
20 | | -from rock.utils.providers.redis_provider import RedisProvider |
| 23 | +from rock.sdk.common.exceptions import BadRequestRockError |
| 24 | +from rock.utils.providers.redis_provider import LockResult, RedisProvider |
21 | 25 |
|
22 | 26 | logger = init_logger(__name__) |
23 | 27 |
|
@@ -53,56 +57,78 @@ async def create( |
53 | 57 | sandbox_id: str, |
54 | 58 | sandbox_info: SandboxInfo, |
55 | 59 | timeout_info: dict[str, str] | None = None, |
| 60 | + version: int = 1, |
56 | 61 | ) -> None: |
57 | 62 | """Write sandbox info to the Redis alive key and fire-and-forget a DB insert. |
58 | 63 |
|
59 | 64 | Parameters |
60 | 65 | ---------- |
61 | 66 | timeout_info: |
62 | 67 | If provided, also write the timeout key (``auto_clear_time`` / ``expire_time``). |
| 68 | + version: |
| 69 | + The version number from ``create_and_acquire_lock``. Stored as |
| 70 | + ``update_version`` in both the alive key and the DB record. |
63 | 71 | """ |
64 | 72 | if self._redis: |
65 | | - await self._redis.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) |
| 73 | + info_with_version: SandboxInfo = {**sandbox_info, "update_version": version} |
| 74 | + await self._redis.json_set(alive_sandbox_key(sandbox_id), "$", info_with_version) |
66 | 75 | if timeout_info is not None: |
67 | 76 | await self._redis.json_set(timeout_sandbox_key(sandbox_id), "$", timeout_info) |
| 77 | + else: |
| 78 | + info_with_version = sandbox_info |
68 | 79 |
|
69 | | - self._fire_db_insert(sandbox_id, sandbox_info) |
| 80 | + self._fire_db_insert(sandbox_id, {**info_with_version, "update_version": version}) |
70 | 81 |
|
71 | 82 | async def update(self, sandbox_id: str, sandbox_info: SandboxInfo) -> None: |
72 | | - """Merge *sandbox_info* into the existing Redis alive key and fire-and-forget a DB update.""" |
| 83 | + """Merge *sandbox_info* into the alive key via optimistic lock; fire DB only on success. |
| 84 | +
|
| 85 | + Uses an atomic Lua increment on the lock key: |
| 86 | + - ``KEY_ABSENT`` — sandbox was archived; skip. |
| 87 | + - ``LOCK_HELD`` — a pessimistic operation is in progress; skip. |
| 88 | + - ``OK`` — this writer owns version slot *old_version*; write through. |
| 89 | +
|
| 90 | + When no Redis is configured, falls through to DB update unconditionally. |
| 91 | + """ |
73 | 92 | if self._redis: |
74 | 93 | current = await self._redis.json_get(alive_sandbox_key(sandbox_id), "$") |
75 | | - merged: dict[str, Any] = {**(current[0] if current else {}), **sandbox_info} |
76 | | - await self._redis.json_set(alive_sandbox_key(sandbox_id), "$", merged) |
| 94 | + current_doc: dict[str, Any] = current[0] if current else {} |
77 | 95 |
|
78 | | - self._fire_db_update(sandbox_id, sandbox_info) |
| 96 | + result, old_version = await self._redis.lock_optimistic_update(lock_sandbox_key(sandbox_id)) |
| 97 | + if result != LockResult.OK: |
| 98 | + return # KEY_ABSENT or LOCK_HELD — skip |
| 99 | + |
| 100 | + merged: dict[str, Any] = {**current_doc, **sandbox_info, "update_version": old_version + 1} |
| 101 | + await self._redis.json_set(alive_sandbox_key(sandbox_id), "$", merged) |
| 102 | + self._fire_db_update(sandbox_id, merged) |
| 103 | + else: |
| 104 | + self._fire_db_update(sandbox_id, sandbox_info) |
79 | 105 |
|
80 | 106 | async def delete(self, sandbox_id: str) -> None: |
81 | | - """Delete Redis alive + timeout keys and fire-and-forget a DB delete.""" |
| 107 | + """Delete Redis alive + timeout + lock keys and fire-and-forget a DB delete.""" |
82 | 108 | if self._redis: |
83 | 109 | await self._redis.json_delete(alive_sandbox_key(sandbox_id)) |
84 | 110 | await self._redis.json_delete(timeout_sandbox_key(sandbox_id)) |
| 111 | + await self._redis.json_delete(lock_sandbox_key(sandbox_id)) |
85 | 112 |
|
86 | 113 | self._fire_db_delete(sandbox_id) |
87 | 114 |
|
88 | | - async def archive(self, sandbox_id: str, final_info: SandboxInfo) -> None: |
89 | | - """Persist final state to DB, then remove sandbox from Redis. |
90 | | -
|
91 | | - Unlike ``delete``, the DB record is preserved and updated with |
92 | | - ``final_info`` (e.g. ``stop_time``, ``state``). Use this when a |
93 | | - sandbox has finished its lifecycle and the final state should be |
94 | | - queryable from the DB. |
| 115 | + async def archive( |
| 116 | + self, sandbox_id: str, final_info: SandboxInfo, version: int | None = None |
| 117 | + ) -> None: |
| 118 | + """Persist final state to DB, then remove sandbox from Redis including lock key. |
95 | 119 |
|
96 | | - The DB write is awaited before the Redis keys are deleted so that |
97 | | - the final state is always durably stored before the alive key |
98 | | - disappears. If the DB write fails the exception is swallowed and |
99 | | - logged, but Redis cleanup still proceeds. |
| 120 | + The DB write is awaited before Redis cleanup. *version* (from |
| 121 | + ``acquire_lock``) is merged into *final_info* so the DB record carries |
| 122 | + the correct version. |
100 | 123 | """ |
| 124 | + if version is not None: |
| 125 | + final_info = {**final_info, "update_version": version} |
101 | 126 | await self._await_db_update(sandbox_id, final_info) |
102 | 127 |
|
103 | 128 | if self._redis: |
104 | 129 | await self._redis.json_delete(alive_sandbox_key(sandbox_id)) |
105 | 130 | await self._redis.json_delete(timeout_sandbox_key(sandbox_id)) |
| 131 | + await self._redis.json_delete(lock_sandbox_key(sandbox_id)) |
106 | 132 |
|
107 | 133 | async def get(self, sandbox_id: str) -> SandboxInfo | None: |
108 | 134 | """Read sandbox info from the Redis alive key.""" |
@@ -241,6 +267,63 @@ async def is_expired(self, sandbox_id: str) -> bool: |
241 | 267 | expire_time = int(timeout_info.get("expire_time", 0)) |
242 | 268 | return int(time.time()) > expire_time |
243 | 269 |
|
| 270 | + # ------------------------------------------------------------------ |
| 271 | + # Lock context managers (pessimistic locking for start / stop) |
| 272 | + # ------------------------------------------------------------------ |
| 273 | + |
| 274 | + @asynccontextmanager |
| 275 | + async def create_and_acquire_lock(self, sandbox_id: str) -> AsyncIterator[int]: |
| 276 | + """NX-create and immediately acquire the sandbox lock (for ``start``). |
| 277 | +
|
| 278 | + Yields the initial version number (1). Raises ``BadRequestRockError`` |
| 279 | + when the sandbox already exists. Always releases the lock on exit even |
| 280 | + if the body raises. |
| 281 | +
|
| 282 | + When no Redis is configured, yields 1 without any locking. |
| 283 | + """ |
| 284 | + if not self._redis: |
| 285 | + yield 1 |
| 286 | + return |
| 287 | + |
| 288 | + uuid_str = str(uuid4()) |
| 289 | + holder = f"{uuid_str}@{socket.gethostname()}" |
| 290 | + result, version = await self._redis.lock_create_and_acquire( |
| 291 | + lock_sandbox_key(sandbox_id), holder, 1 |
| 292 | + ) |
| 293 | + if result == LockResult.ALREADY_EXISTS: |
| 294 | + raise BadRequestRockError(f"Sandbox {sandbox_id} already exists") |
| 295 | + try: |
| 296 | + yield version |
| 297 | + finally: |
| 298 | + await self._redis.lock_release(lock_sandbox_key(sandbox_id), uuid_str) |
| 299 | + |
| 300 | + @asynccontextmanager |
| 301 | + async def acquire_lock(self, sandbox_id: str) -> AsyncIterator[int]: |
| 302 | + """Acquire the sandbox lock on an existing key (for ``stop``). |
| 303 | +
|
| 304 | + Yields the new version number (incremented by acquire). Raises |
| 305 | + ``BadRequestRockError`` when the sandbox does not exist or another |
| 306 | + operation holds the lock. Always releases on exit. |
| 307 | +
|
| 308 | + When no Redis is configured, yields 1 without any locking. |
| 309 | + """ |
| 310 | + if not self._redis: |
| 311 | + yield 1 |
| 312 | + return |
| 313 | + |
| 314 | + uuid_str = str(uuid4()) |
| 315 | + holder = f"{uuid_str}@{socket.gethostname()}" |
| 316 | + result, version = await self._redis.lock_acquire(lock_sandbox_key(sandbox_id), holder) |
| 317 | + if result == LockResult.KEY_ABSENT: |
| 318 | + raise BadRequestRockError(f"Sandbox {sandbox_id} not found") |
| 319 | + if result == LockResult.LOCK_HELD: |
| 320 | + raise BadRequestRockError(f"Sandbox {sandbox_id} is busy") |
| 321 | + try: |
| 322 | + yield version |
| 323 | + finally: |
| 324 | + # lock key is deleted by archive(); lock_release silently no-ops when key is absent |
| 325 | + await self._redis.lock_release(lock_sandbox_key(sandbox_id), uuid_str) |
| 326 | + |
244 | 327 | # ------------------------------------------------------------------ |
245 | 328 | # Internal helpers |
246 | 329 | # ------------------------------------------------------------------ |
|
0 commit comments