Skip to content

Commit c882cb8

Browse files
committed
feat: implement optimistic and pessimistic locking for sandbox metadata
- Add update_version field to SandboxInfo/SandboxRecord and lock_sandbox_key() helper - Add LockResult enum and lock operations to RedisProvider (create_and_acquire, acquire, optimistic_update, release) - Add version-guarded update() to SandboxTable to skip stale writes - Implement SandboxRepository lock context managers (create_and_acquire_lock, acquire_lock) and version-aware CRUD - Wrap SandboxManager.start_async/stop with pessimistic lock context managers; remove _check_sandbox_exists_in_redis - Add tests for optimistic update behaviour in SandboxRepository - Add update_version to SandboxInfoField Literal type
1 parent 32209d6 commit c882cb8

File tree

11 files changed

+684
-62
lines changed

11 files changed

+684
-62
lines changed

rock/actions/sandbox/_generated_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@
2626
"create_time",
2727
"start_time",
2828
"stop_time",
29+
"update_version",
2930
]

rock/actions/sandbox/sandbox_info.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ class SandboxInfo(TypedDict, total=False):
2424
create_time: str
2525
start_time: str
2626
stop_time: str
27+
update_version: int

rock/admin/core/redis_key.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
ALIVE_PREFIX = "alive:"
22
TIMEOUT_PREFIX = "timeout:"
3+
LOCK_PREFIX = "lock:"
34

45

56
def alive_sandbox_key(sandbox_id: str) -> str:
@@ -8,3 +9,7 @@ def alive_sandbox_key(sandbox_id: str) -> str:
89

910
def timeout_sandbox_key(sandbox_id: str) -> str:
1011
return f"{TIMEOUT_PREFIX}{sandbox_id}"
12+
13+
14+
def lock_sandbox_key(sandbox_id: str) -> str:
15+
return f"{LOCK_PREFIX}{sandbox_id}"

rock/admin/core/sandbox_table.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,33 @@ async def get(self, sandbox_id: str) -> SandboxInfo | None:
5454
return _record_to_sandbox_info(record)
5555

5656
async def update(self, sandbox_id: str, data: SandboxInfo) -> None:
57-
"""Partial update of an existing sandbox record."""
57+
"""Partial update of an existing sandbox record.
58+
59+
When *data* contains ``update_version``, the write is skipped if the DB
60+
record already carries an equal-or-higher version (prevents stale
61+
fire-and-forget writes from overwriting newer state).
62+
"""
5863
filtered = self._filter_data(data)
5964
if not filtered:
6065
return
61-
66+
new_version: int | None = filtered.get("update_version")
6267
async with self._db.session() as session:
6368
record = await session.get(SandboxRecord, sandbox_id)
6469
if record is None:
6570
logger.warning("update: sandbox_id=%s not found", sandbox_id)
6671
return
72+
if (
73+
new_version is not None
74+
and record.update_version is not None
75+
and new_version <= record.update_version
76+
):
77+
logger.debug(
78+
"update: skip stale write sandbox_id=%s new_version=%s current=%s",
79+
sandbox_id,
80+
new_version,
81+
record.update_version,
82+
)
83+
return
6784
for key, value in filtered.items():
6885
setattr(record, key, value)
6986
await session.commit()

rock/admin/core/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from typing import Any, ClassVar
1010

11-
from sqlalchemy import Boolean, Column, Float, Index, String
11+
from sqlalchemy import Boolean, Column, Float, Index, Integer, String
1212
from sqlalchemy.dialects.postgresql import JSONB
1313
from sqlalchemy.orm import DeclarativeBase
1414
from sqlalchemy.types import JSON
@@ -48,6 +48,7 @@ class SandboxRecord(Base):
4848
create_user_gray_flag = Column(Boolean, nullable=True)
4949
phases = Column(_JSONB_VARIANT, nullable=True)
5050
port_mapping = Column(_JSONB_VARIANT, nullable=True)
51+
update_version = Column(Integer, nullable=True)
5152

5253
__table_args__ = (
5354
Index("ix_sandbox_record_user_id", "user_id"),

rock/sandbox/sandbox_manager.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@
3939
from rock.sdk.common.exceptions import BadRequestRockError, InternalServerRockError
4040
from rock.utils.crypto_utils import AESEncryption
4141
from rock.utils.format import convert_to_gb, parse_size_to_bytes
42-
from rock.utils.providers.redis_provider import RedisProvider
43-
from rock.utils.service import build_sandbox_from_redis
4442
from rock.utils.system import get_iso8601_timestamp
4543

4644
logger = init_logger(__name__)
@@ -79,12 +77,6 @@ async def refresh_aes_key(self):
7977
logger.error(f"update aes key failed, error: {e}")
8078
raise InternalServerRockError(f"update aes key failed, {str(e)}")
8179

82-
async def _check_sandbox_exists_in_redis(self, config: DeploymentConfig):
83-
if isinstance(config, DockerDeploymentConfig) and config.container_name:
84-
sandbox_id = config.container_name
85-
if self._meta_repo and await self._meta_repo.exists(sandbox_id):
86-
raise BadRequestRockError(f"Sandbox {sandbox_id} already exists")
87-
8880
def _setup_sandbox_actor_metadata(self, sandbox_actor: SandboxActor, user_info: UserInfo) -> None:
8981
user_id = user_info.get("user_id", "default")
9082
experiment_id = user_info.get("experiment_id", "default")
@@ -112,7 +104,6 @@ async def _build_sandbox_info_metadata(
112104
async def start_async(
113105
self, config: DeploymentConfig, user_info: UserInfo = {}, cluster_info: ClusterInfo = {}
114106
) -> SandboxStartResponse:
115-
await self._check_sandbox_exists_in_redis(config)
116107
self.validate_sandbox_spec(self.rock_config.runtime, config)
117108
docker_deployment_config: DockerDeploymentConfig = await self.deployment_manager.init_config(config)
118109

@@ -125,15 +116,24 @@ async def start_async(
125116
)
126117
docker_deployment_config.cpus = self.rock_config.runtime.standard_spec.cpus
127118
docker_deployment_config.memory = self.rock_config.runtime.standard_spec.memory
128-
sandbox_info: SandboxInfo = await self._operator.submit(docker_deployment_config, user_info)
119+
129120
stop_time = str(int(time.time()) + docker_deployment_config.auto_clear_time * 60)
130121
auto_clear_time_dict = {
131122
env_vars.ROCK_SANDBOX_AUTO_CLEAR_TIME_KEY: str(docker_deployment_config.auto_clear_time),
132123
env_vars.ROCK_SANDBOX_EXPIRE_TIME_KEY: stop_time,
133124
}
134-
await self._build_sandbox_info_metadata(sandbox_info, user_info, cluster_info)
125+
135126
if self._meta_repo:
136-
await self._meta_repo.create(sandbox_id, sandbox_info, timeout_info=auto_clear_time_dict)
127+
async with self._meta_repo.create_and_acquire_lock(sandbox_id) as version:
128+
sandbox_info: SandboxInfo = await self._operator.submit(docker_deployment_config, user_info)
129+
await self._build_sandbox_info_metadata(sandbox_info, user_info, cluster_info)
130+
await self._meta_repo.create(
131+
sandbox_id, sandbox_info, timeout_info=auto_clear_time_dict, version=version
132+
)
133+
else:
134+
sandbox_info = await self._operator.submit(docker_deployment_config, user_info)
135+
await self._build_sandbox_info_metadata(sandbox_info, user_info, cluster_info)
136+
137137
return SandboxStartResponse(
138138
sandbox_id=sandbox_id,
139139
host_name=sandbox_info.get("host_name"),
@@ -170,27 +170,37 @@ async def start(self, config: DeploymentConfig) -> SandboxStartResponse:
170170
@monitor_sandbox_operation()
171171
async def stop(self, sandbox_id):
172172
logger.info(f"stop sandbox {sandbox_id}")
173-
sandbox_info: SandboxInfo | None = await self._meta_repo.get(sandbox_id) if self._meta_repo else None
174-
if sandbox_info is None:
175-
sandbox_info = {}
176-
sandbox_info["state"] = State.STOPPED
177-
if sandbox_info.get("start_time"):
178-
sandbox_info["stop_time"] = get_iso8601_timestamp()
179-
log_billing_info(sandbox_info=sandbox_info)
180-
try:
181-
await self._operator.stop(sandbox_id)
182-
except ValueError as e:
183-
logger.error(f"ray get actor, actor {sandbox_id} not exist", exc_info=e)
184-
if self._meta_repo:
185-
await self._meta_repo.archive(sandbox_id, sandbox_info)
173+
174+
if not self._meta_repo:
175+
try:
176+
await self._operator.stop(sandbox_id)
177+
except ValueError as e:
178+
logger.error(f"ray get actor, actor {sandbox_id} not exist", exc_info=e)
179+
try:
180+
self._sandbox_meta.pop(sandbox_id)
181+
except KeyError:
182+
logger.debug(f"{sandbox_id} key not found")
186183
return
187-
try:
188-
self._sandbox_meta.pop(sandbox_id)
189-
except KeyError:
190-
logger.debug(f"{sandbox_id} key not found")
191-
logger.info(f"sandbox {sandbox_id} stopped")
192-
if self._meta_repo:
193-
await self._meta_repo.archive(sandbox_id, sandbox_info)
184+
185+
async with self._meta_repo.acquire_lock(sandbox_id) as version:
186+
sandbox_info: SandboxInfo = await self._meta_repo.get(sandbox_id) or {}
187+
sandbox_info["state"] = State.STOPPED
188+
if sandbox_info.get("start_time"):
189+
sandbox_info["stop_time"] = get_iso8601_timestamp()
190+
log_billing_info(sandbox_info=sandbox_info)
191+
try:
192+
await self._operator.stop(sandbox_id)
193+
except ValueError as e:
194+
logger.error(f"ray get actor, actor {sandbox_id} not exist", exc_info=e)
195+
await self._meta_repo.archive(sandbox_id, sandbox_info, version)
196+
return
197+
try:
198+
self._sandbox_meta.pop(sandbox_id)
199+
except KeyError:
200+
logger.debug(f"{sandbox_id} key not found")
201+
logger.info(f"sandbox {sandbox_id} stopped")
202+
await self._meta_repo.archive(sandbox_id, sandbox_info, version)
203+
# lock_release in context manager silently no-ops: archive already deleted the lock key
194204

195205
async def get_mount(self, sandbox_id):
196206
async with self._ray_service.get_ray_rwlock().read_lock():

rock/sandbox/sandbox_repository.py

Lines changed: 103 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,21 @@
77
from __future__ import annotations
88

99
import asyncio
10+
import socket
1011
import time
1112
from collections.abc import AsyncIterator, Awaitable, Callable
13+
from contextlib import asynccontextmanager
1214
from typing import Any
15+
from uuid import uuid4
1316

1417
from rock.actions.sandbox._generated_types import SandboxInfoField
1518
from rock.actions.sandbox.response import State
1619
from rock.actions.sandbox.sandbox_info import SandboxInfo
17-
from rock.admin.core.redis_key import ALIVE_PREFIX, alive_sandbox_key, timeout_sandbox_key
20+
from rock.admin.core.redis_key import ALIVE_PREFIX, alive_sandbox_key, lock_sandbox_key, timeout_sandbox_key
1821
from rock.admin.core.sandbox_table import SandboxTable
1922
from rock.logger import init_logger
20-
from rock.utils.providers.redis_provider import RedisProvider
23+
from rock.sdk.common.exceptions import BadRequestRockError
24+
from rock.utils.providers.redis_provider import LockResult, RedisProvider
2125

2226
logger = init_logger(__name__)
2327

@@ -53,56 +57,78 @@ async def create(
5357
sandbox_id: str,
5458
sandbox_info: SandboxInfo,
5559
timeout_info: dict[str, str] | None = None,
60+
version: int = 1,
5661
) -> None:
5762
"""Write sandbox info to the Redis alive key and fire-and-forget a DB insert.
5863
5964
Parameters
6065
----------
6166
timeout_info:
6267
If provided, also write the timeout key (``auto_clear_time`` / ``expire_time``).
68+
version:
69+
The version number from ``create_and_acquire_lock``. Stored as
70+
``update_version`` in both the alive key and the DB record.
6371
"""
6472
if self._redis:
65-
await self._redis.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info)
73+
info_with_version: SandboxInfo = {**sandbox_info, "update_version": version}
74+
await self._redis.json_set(alive_sandbox_key(sandbox_id), "$", info_with_version)
6675
if timeout_info is not None:
6776
await self._redis.json_set(timeout_sandbox_key(sandbox_id), "$", timeout_info)
77+
else:
78+
info_with_version = sandbox_info
6879

69-
self._fire_db_insert(sandbox_id, sandbox_info)
80+
self._fire_db_insert(sandbox_id, {**info_with_version, "update_version": version})
7081

7182
async def update(self, sandbox_id: str, sandbox_info: SandboxInfo) -> None:
72-
"""Merge *sandbox_info* into the existing Redis alive key and fire-and-forget a DB update."""
83+
"""Merge *sandbox_info* into the alive key via optimistic lock; fire DB only on success.
84+
85+
Uses an atomic Lua increment on the lock key:
86+
- ``KEY_ABSENT`` — sandbox was archived; skip.
87+
- ``LOCK_HELD`` — a pessimistic operation is in progress; skip.
88+
- ``OK`` — this writer owns version slot *old_version*; write through.
89+
90+
When no Redis is configured, falls through to DB update unconditionally.
91+
"""
7392
if self._redis:
7493
current = await self._redis.json_get(alive_sandbox_key(sandbox_id), "$")
75-
merged: dict[str, Any] = {**(current[0] if current else {}), **sandbox_info}
76-
await self._redis.json_set(alive_sandbox_key(sandbox_id), "$", merged)
94+
current_doc: dict[str, Any] = current[0] if current else {}
7795

78-
self._fire_db_update(sandbox_id, sandbox_info)
96+
result, old_version = await self._redis.lock_optimistic_update(lock_sandbox_key(sandbox_id))
97+
if result != LockResult.OK:
98+
return # KEY_ABSENT or LOCK_HELD — skip
99+
100+
merged: dict[str, Any] = {**current_doc, **sandbox_info, "update_version": old_version + 1}
101+
await self._redis.json_set(alive_sandbox_key(sandbox_id), "$", merged)
102+
self._fire_db_update(sandbox_id, merged)
103+
else:
104+
self._fire_db_update(sandbox_id, sandbox_info)
79105

80106
async def delete(self, sandbox_id: str) -> None:
81-
"""Delete Redis alive + timeout keys and fire-and-forget a DB delete."""
107+
"""Delete Redis alive + timeout + lock keys and fire-and-forget a DB delete."""
82108
if self._redis:
83109
await self._redis.json_delete(alive_sandbox_key(sandbox_id))
84110
await self._redis.json_delete(timeout_sandbox_key(sandbox_id))
111+
await self._redis.json_delete(lock_sandbox_key(sandbox_id))
85112

86113
self._fire_db_delete(sandbox_id)
87114

88-
async def archive(self, sandbox_id: str, final_info: SandboxInfo) -> None:
89-
"""Persist final state to DB, then remove sandbox from Redis.
90-
91-
Unlike ``delete``, the DB record is preserved and updated with
92-
``final_info`` (e.g. ``stop_time``, ``state``). Use this when a
93-
sandbox has finished its lifecycle and the final state should be
94-
queryable from the DB.
115+
async def archive(
116+
self, sandbox_id: str, final_info: SandboxInfo, version: int | None = None
117+
) -> None:
118+
"""Persist final state to DB, then remove sandbox from Redis including lock key.
95119
96-
The DB write is awaited before the Redis keys are deleted so that
97-
the final state is always durably stored before the alive key
98-
disappears. If the DB write fails the exception is swallowed and
99-
logged, but Redis cleanup still proceeds.
120+
The DB write is awaited before Redis cleanup. *version* (from
121+
``acquire_lock``) is merged into *final_info* so the DB record carries
122+
the correct version.
100123
"""
124+
if version is not None:
125+
final_info = {**final_info, "update_version": version}
101126
await self._await_db_update(sandbox_id, final_info)
102127

103128
if self._redis:
104129
await self._redis.json_delete(alive_sandbox_key(sandbox_id))
105130
await self._redis.json_delete(timeout_sandbox_key(sandbox_id))
131+
await self._redis.json_delete(lock_sandbox_key(sandbox_id))
106132

107133
async def get(self, sandbox_id: str) -> SandboxInfo | None:
108134
"""Read sandbox info from the Redis alive key."""
@@ -241,6 +267,63 @@ async def is_expired(self, sandbox_id: str) -> bool:
241267
expire_time = int(timeout_info.get("expire_time", 0))
242268
return int(time.time()) > expire_time
243269

270+
# ------------------------------------------------------------------
271+
# Lock context managers (pessimistic locking for start / stop)
272+
# ------------------------------------------------------------------
273+
274+
@asynccontextmanager
275+
async def create_and_acquire_lock(self, sandbox_id: str) -> AsyncIterator[int]:
276+
"""NX-create and immediately acquire the sandbox lock (for ``start``).
277+
278+
Yields the initial version number (1). Raises ``BadRequestRockError``
279+
when the sandbox already exists. Always releases the lock on exit even
280+
if the body raises.
281+
282+
When no Redis is configured, yields 1 without any locking.
283+
"""
284+
if not self._redis:
285+
yield 1
286+
return
287+
288+
uuid_str = str(uuid4())
289+
holder = f"{uuid_str}@{socket.gethostname()}"
290+
result, version = await self._redis.lock_create_and_acquire(
291+
lock_sandbox_key(sandbox_id), holder, 1
292+
)
293+
if result == LockResult.ALREADY_EXISTS:
294+
raise BadRequestRockError(f"Sandbox {sandbox_id} already exists")
295+
try:
296+
yield version
297+
finally:
298+
await self._redis.lock_release(lock_sandbox_key(sandbox_id), uuid_str)
299+
300+
@asynccontextmanager
301+
async def acquire_lock(self, sandbox_id: str) -> AsyncIterator[int]:
302+
"""Acquire the sandbox lock on an existing key (for ``stop``).
303+
304+
Yields the new version number (incremented by acquire). Raises
305+
``BadRequestRockError`` when the sandbox does not exist or another
306+
operation holds the lock. Always releases on exit.
307+
308+
When no Redis is configured, yields 1 without any locking.
309+
"""
310+
if not self._redis:
311+
yield 1
312+
return
313+
314+
uuid_str = str(uuid4())
315+
holder = f"{uuid_str}@{socket.gethostname()}"
316+
result, version = await self._redis.lock_acquire(lock_sandbox_key(sandbox_id), holder)
317+
if result == LockResult.KEY_ABSENT:
318+
raise BadRequestRockError(f"Sandbox {sandbox_id} not found")
319+
if result == LockResult.LOCK_HELD:
320+
raise BadRequestRockError(f"Sandbox {sandbox_id} is busy")
321+
try:
322+
yield version
323+
finally:
324+
# lock key is deleted by archive(); lock_release silently no-ops when key is absent
325+
await self._redis.lock_release(lock_sandbox_key(sandbox_id), uuid_str)
326+
244327
# ------------------------------------------------------------------
245328
# Internal helpers
246329
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)