Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ node_modules
logs
docs/superpowers/

.env
.env
*.db
ddl/
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ admin = [
"alibabacloud_cr20181201==2.0.5",
"sqlmodel",
"aiosqlite",
"asyncpg",
"boto3",
"ray[default]==2.43.0",
"pip",
Expand Down Expand Up @@ -171,5 +172,6 @@ markers = [
"need_ray: need ray start",
"need_docker: need docker daemon running",
"need_admin: need admin start",
"need_admin_and_network: need install from network"
"need_admin_and_network: need install from network",
"need_database: need database Docker containers (PostgreSQL, Redis)"
]
3 changes: 3 additions & 0 deletions requirements_admin.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ arckit==0.1.0
async-timeout==5.0.1 ; python_full_version < '3.11.3'
# via
# aiohttp
# asyncpg
# redis
asyncpg==0.31.0
# via rl-rock
attrs==25.4.0
# via
# aiohttp
Expand Down
30 changes: 30 additions & 0 deletions rock/actions/sandbox/_generated_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Auto-generated literal aliases derived from ``SandboxInfo``.

Do not edit this file manually. Regenerate via:
uv run python tools/generate_sandbox_info_types.py
"""

from typing import Literal

SandboxInfoField = Literal[
"host_ip",
"host_name",
"image",
"user_id",
"experiment_id",
"namespace",
"cluster_name",
"sandbox_id",
"auth_token",
"rock_authorization_encrypted",
"phases",
"state",
"port_mapping",
"create_user_gray_flag",
"cpus",
"memory",
"create_time",
"start_time",
"stop_time",
"update_version",
]
1 change: 1 addition & 0 deletions rock/actions/sandbox/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class SandboxResponse(BaseModel):
class State(str, Enum):
    """Sandbox state values reported in API responses.

    Inherits ``str`` so members compare equal to (and serialize as) their
    plain string values.
    """

    PENDING = "pending"
    RUNNING = "running"
    STOPPED = "stopped"


class IsAliveResponse(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions rock/actions/sandbox/sandbox_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ class SandboxInfo(TypedDict, total=False):
create_time: str
start_time: str
stop_time: str
update_version: int
65 changes: 55 additions & 10 deletions rock/admin/core/db_provider.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,66 @@
"""Generic async SQLAlchemy engine/session provider."""

from __future__ import annotations

from typing import TYPE_CHECKING

from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from rock.admin.core.schema import DBModelBase
from rock.admin.core.schema import Base
from rock.logger import init_logger

if TYPE_CHECKING:
from rock.config import DatabaseConfig

logger = init_logger(__name__)


class DatabaseProvider:
    """Async SQLAlchemy engine and session factory.

    Supports SQLite (via ``aiosqlite``) and PostgreSQL (via ``asyncpg``).
    Call :meth:`init_pool` before requesting sessions and :meth:`close_pool`
    on shutdown.
    """

    def __init__(self, db_config: DatabaseConfig) -> None:
        # Normalise the configured URL to its async-driver form up front so a
        # malformed URL fails at construction time, not at first connect.
        self._url = self._convert_url(db_config.url)
        self._engine = None
        self._session_factory = None

    async def init_pool(self) -> None:
        """Create the async engine, session factory, and ensure tables exist."""
        logger.info("Initializing database connection pool ...")
        self._engine = create_async_engine(self._url, echo=False)
        # expire_on_commit=False keeps ORM objects readable after commit,
        # which matters for code that inspects records post-write.
        self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.info("Database connection pool initialised; tables created.")

    async def close_pool(self) -> None:
        """Dispose of the engine and release all connections.

        Also clears the engine/session-factory references so a later
        :meth:`session` call raises ``RuntimeError`` instead of handing out
        sessions bound to a disposed engine.
        """
        if self._engine is not None:
            logger.info("Closing database connection pool ...")
            await self._engine.dispose()
            self._engine = None
            self._session_factory = None
            logger.info("Database connection pool closed.")

    def session(self) -> AsyncSession:
        """Return a new async session (usable as an ``async with`` context manager).

        Raises:
            RuntimeError: if :meth:`init_pool` has not been called.
        """
        if self._session_factory is None:
            raise RuntimeError("DatabaseProvider not initialised. Call init_pool() first.")
        return self._session_factory()

    @staticmethod
    def _convert_url(url: str) -> str:
        """Convert synchronous database URLs to their async equivalents.

        URLs that already include a driver specifier (e.g. ``sqlite+aiosqlite://``,
        ``postgresql+asyncpg://``) are returned unchanged.
        """
        if url.startswith("sqlite:///"):
            return url.replace("sqlite:///", "sqlite+aiosqlite:///", 1)
        if url.startswith("postgresql://") or url.startswith("postgres://"):
            # "postgres://" is the Heroku-style shorthand; both map to asyncpg.
            prefix = "postgresql://" if url.startswith("postgresql://") else "postgres://"
            return "postgresql+asyncpg://" + url[len(prefix):]
        # Pass through URLs that already include a driver specifier.
        return url
5 changes: 5 additions & 0 deletions rock/admin/core/redis_key.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
ALIVE_PREFIX = "alive:"
TIMEOUT_PREFIX = "timeout:"
LOCK_PREFIX = "lock:"


def alive_sandbox_key(sandbox_id: str) -> str:
Expand All @@ -8,3 +9,7 @@ def alive_sandbox_key(sandbox_id: str) -> str:

def timeout_sandbox_key(sandbox_id: str) -> str:
    """Return the ``timeout:``-prefixed Redis key for *sandbox_id*."""
    return TIMEOUT_PREFIX + sandbox_id


def lock_sandbox_key(sandbox_id: str) -> str:
    """Return the ``lock:``-prefixed Redis key for *sandbox_id*."""
    return LOCK_PREFIX + sandbox_id
151 changes: 128 additions & 23 deletions rock/admin/core/sandbox_table.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,139 @@
from collections.abc import Sequence
"""SandboxTable: sandbox-specific CRUD and query operations over DatabaseProvider."""

from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast, get_type_hints

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession

from rock.admin.core.schema import SandboxRecord
from rock.actions.sandbox._generated_types import SandboxInfoField
from rock.admin.core.db_provider import DatabaseProvider
from rock.admin.core.schema import SandboxRecord # LIST_BY_ALLOWLIST lives here
from rock.logger import init_logger

if TYPE_CHECKING:
from rock.actions.sandbox.sandbox_info import SandboxInfo

logger = init_logger(__name__)

# ---------------------------------------------------------------------------
# SandboxTable
# ---------------------------------------------------------------------------


class SandboxTable:
    """Sandbox-specific database access layer backed by DatabaseProvider.

    All CRUD and query methods operate on the ``sandbox_record`` table.
    """

    def __init__(self, db_provider: DatabaseProvider) -> None:
        self._db = db_provider

    async def create(self, sandbox_id: str, data: SandboxInfo) -> None:
        """Insert a new sandbox record.

        Raises ``IntegrityError`` if ``sandbox_id`` already exists.
        """
        filtered = self._filter_data(data)

        # Ensure NOT NULL columns have values when creating a new record.
        for col, default in SandboxRecord._NOT_NULL_DEFAULTS.items():
            filtered.setdefault(col, default)

        record = SandboxRecord(sandbox_id=sandbox_id, **filtered)
        async with self._db.session() as session:
            session.add(record)
            await session.commit()

    async def get(self, sandbox_id: str) -> SandboxInfo | None:
        """Return a sandbox row as SandboxInfo, or ``None`` if not found."""
        async with self._db.session() as session:
            record = await session.get(SandboxRecord, sandbox_id)
            if record is None:
                return None
            return _record_to_sandbox_info(record)

    async def update(self, sandbox_id: str, data: SandboxInfo) -> None:
        """Partial update of an existing sandbox record.

        When *data* contains ``update_version``, the write is skipped if the DB
        record already carries an equal-or-higher version (prevents stale
        fire-and-forget writes from overwriting newer state).
        """
        filtered = self._filter_data(data)
        if not filtered:
            return
        new_version: int | None = filtered.get("update_version")
        async with self._db.session() as session:
            record = await session.get(SandboxRecord, sandbox_id)
            if record is None:
                logger.warning("update: sandbox_id=%s not found", sandbox_id)
                return
            if (
                new_version is not None
                and record.update_version is not None
                and new_version <= record.update_version
            ):
                logger.debug(
                    "update: skip stale write sandbox_id=%s new_version=%s current=%s",
                    sandbox_id,
                    new_version,
                    record.update_version,
                )
                return
            for key, value in filtered.items():
                setattr(record, key, value)
            await session.commit()

    async def delete(self, sandbox_id: str) -> None:
        """Hard-delete a sandbox record. Unknown ids are ignored."""
        async with self._db.session() as session:
            record = await session.get(SandboxRecord, sandbox_id)
            if record is not None:
                await session.delete(record)
            await session.commit()

    async def list_by(self, column: SandboxInfoField, value: str | int | float | bool) -> list[SandboxInfo]:
        """Equality query on a single column.

        Only columns in ``SandboxRecord.LIST_BY_ALLOWLIST`` are permitted.
        """
        col_attr = self._queryable_column(column)
        stmt = select(SandboxRecord).where(col_attr == value)
        async with self._db.session() as session:
            result = await session.execute(stmt)
            return [_record_to_sandbox_info(r) for r in result.scalars().all()]

    async def list_by_in(
        self, column: SandboxInfoField, values: list[str | int | float | bool]
    ) -> list[SandboxInfo]:
        """IN query on a single column.

        Only columns in ``SandboxRecord.LIST_BY_ALLOWLIST`` are permitted.
        An empty *values* list short-circuits to ``[]`` (after allowlist
        validation) since ``IN ()`` is not valid SQL on all dialects.
        """
        col_attr = self._queryable_column(column)
        if not values:
            return []

        stmt = select(SandboxRecord).where(col_attr.in_(values))
        async with self._db.session() as session:
            result = await session.execute(stmt)
            return [_record_to_sandbox_info(r) for r in result.scalars().all()]

    @staticmethod
    def _queryable_column(column: SandboxInfoField):
        """Validate *column* against the allowlist and return the ORM attribute.

        Raises:
            ValueError: if *column* is not in ``SandboxRecord.LIST_BY_ALLOWLIST``.
        """
        if column not in SandboxRecord.LIST_BY_ALLOWLIST:
            raise ValueError(f"Querying by column '{column}' is not allowed")
        return getattr(SandboxRecord, column)

    def _filter_data(self, data: SandboxInfo) -> dict[str, Any]:
        """Keep only keys that correspond to actual table columns, excluding ``sandbox_id``."""
        columns = SandboxRecord.column_names()
        return {k: v for k, v in data.items() if k in columns and k != "sandbox_id"}


@lru_cache(maxsize=1)
def _sandbox_info_allowed_keys() -> frozenset[str]:
    """Valid ``SandboxInfo`` field names, computed once and cached thereafter."""
    # Imported locally to avoid a circular import at module load time.
    from rock.actions.sandbox.sandbox_info import SandboxInfo as _SandboxInfo

    return frozenset(get_type_hints(_SandboxInfo))


def _record_to_sandbox_info(record: SandboxRecord) -> SandboxInfo:
    """Map an ORM row to ``SandboxInfo`` (runtime value is a plain ``dict``).

    Columns that are not declared ``SandboxInfo`` fields are dropped.
    """
    # Hoist the (cached) allowed-keys lookup out of the comprehension.
    allowed = _sandbox_info_allowed_keys()
    data = record.to_dict()
    return cast("SandboxInfo", {k: v for k, v in data.items() if k in allowed})
Loading
Loading