Draft
Changes from all commits
Commits
19 commits
119bc92
Refactor TinkerEngine to use backend architecture
OhadRubin Dec 17, 2025
77ea8b2
Enhance TinkerEngine and backend integration
OhadRubin Dec 17, 2025
2bab2cf
Add MaxTextBackend support to TinkerEngine
OhadRubin Dec 17, 2025
f25653a
Add maxtext dependency and improve engine configuration
OhadRubin Dec 17, 2025
228c5fb
Add convert_maxtext_lora_to_hf function
OhadRubin Dec 17, 2025
9177aed
Update skyrl-tx/tx/tinker/backends/maxtext.py
OhadRubin Dec 17, 2025
6842b1f
Update skyrl-tx/tx/tinker/api.py
OhadRubin Dec 17, 2025
7d75ce6
Update skyrl-tx/tx/tinker/backends/native.py
OhadRubin Dec 17, 2025
bdd8039
Update skyrl-tx/tx/tinker/backends/native.py
OhadRubin Dec 17, 2025
282d79a
Pin maxtext dependency to specific commit
OhadRubin Dec 17, 2025
e6a7013
Merge branch 'maxtext_backend' of https://github.com/OhadRubin/SkyRL …
OhadRubin Dec 17, 2025
b9510f4
Remove duplicate pad_batch, import from utils
OhadRubin Dec 17, 2025
b1e1508
Remove hardcoded path fallback in _get_maxtext_base_config_path
OhadRubin Dec 17, 2025
3bde089
Fix MaxText backend method signatures to match AbstractBackend
OhadRubin Dec 17, 2025
56fdd40
Enhance TinkerEngine and backend functionality
OhadRubin Dec 18, 2025
725f3cb
Implement LoRA weight reset functionality in MaxText backend
OhadRubin Dec 20, 2025
9984637
Add TTL storage structure and update NativeBackend for eager sharding
OhadRubin Dec 20, 2025
8cc0546
Add sampler checkpoint eviction logic and enhance MaxText backend
OhadRubin Dec 20, 2025
58d52d5
Refactor retry logic in ExternalInferenceClient for improved error handling
OhadRubin Dec 20, 2025
7 changes: 6 additions & 1 deletion skyrl-tx/pyproject.toml
@@ -7,7 +7,7 @@ name = "skyrl-tx"
dynamic = ["version"]
description = "Unified API for training and inference"
readme = "README.md"
requires-python = ">=3.11"
requires-python = "==3.12.*"
Contributor

medium

The requires-python has been changed to ==3.12.*, which is very restrictive. This will prevent users on future minor versions such as Python 3.13 from using the library. Unless there's a strong reason for this exact version, consider using a more flexible specifier like >=3.12.

Suggested change
requires-python = "==3.12.*"
requires-python = ">=3.12"

Contributor Author

maxtext is annoying and requires Python 3.12; actually, I haven't tested it with Python 3.13.

dependencies = [
"datasets>=4.0.0",
"flax>=0.12.0",
@@ -30,6 +30,10 @@ gpu = [
"jax[cuda12]>=0.7.2",
]

maxtext = [
"maxtext @ git+https://github.com/OhadRubin/maxtext@1edde4d1b1d562173d1753650b0234aa5c6a2fea",
]

tpu = [
"jax[tpu]>=0.7.2",
]
@@ -42,6 +46,7 @@ tinker = [
"aiosqlite",
"asyncpg",
"psycopg2-binary",
"tenacity",
]

aws = [
107 changes: 101 additions & 6 deletions skyrl-tx/tx/tinker/api.py
@@ -35,6 +35,9 @@
ID_PATTERN = r"^[a-zA-Z0-9_-]+$"
ID_MAX_LENGTH = 255

# Maximum number of sampler checkpoints to keep per model (oldest are evicted)
MAX_SAMPLER_CHECKPOINTS_PER_MODEL = 3


@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -55,7 +58,10 @@ async def lifespan(app: FastAPI):
logger.info("Using internal engine for inference")

# Build subprocess command with engine config parameters
cmd = ["uv", "run", "--extra", "tinker", "-m", "tx.tinker.engine"]
cmd = ["uv", "run", "--no-sync", "--extra", "tinker"]
if app.state.engine_config.maxtext_config_str:
cmd += ["--extra", "maxtext"]
cmd += ["-m", "tx.tinker.engine"]
cmd.extend(config_to_argv(app.state.engine_config))

background_engine = subprocess.Popen(cmd)
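
For reference, when a MaxText config is present the spawned engine command now looks roughly like the sketch below; the exact trailing flags come from config_to_argv(app.state.engine_config) and vary per run, so this listing is an illustration rather than part of the diff.

# Approximate value of `cmd` when maxtext_config_str is set, before the
# config_to_argv(...) flags are appended.
cmd = [
    "uv", "run", "--no-sync",  # --no-sync skips syncing the environment on each engine start
    "--extra", "tinker",
    "--extra", "maxtext",      # only added when a MaxText config is present
    "-m", "tx.tinker.engine",
]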
@@ -130,16 +136,100 @@ async def create_checkpoint(
try:
await session.flush()
except IntegrityError:
# Determine which constraint failed by checking if the model exists
await session.rollback()
# Check if the model exists
statement = select(ModelDB).where(ModelDB.model_id == model_id)
result = await session.exec(statement)

if not result.first():
raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")
else:
raise HTTPException(
status_code=409, detail=f"Checkpoint '{checkpoint_id}' already exists for model '{model_id}'"

# Delete existing checkpoint and create new one
delete_stmt = select(CheckpointDB).where(
CheckpointDB.model_id == model_id,
CheckpointDB.checkpoint_id == checkpoint_id,
CheckpointDB.checkpoint_type == checkpoint_type,
)
existing = (await session.exec(delete_stmt)).first()
if existing:
await session.delete(existing)
await session.flush()

# Re-add the new checkpoint
checkpoint_db = CheckpointDB(
model_id=model_id,
checkpoint_id=checkpoint_id,
checkpoint_type=checkpoint_type,
status=CheckpointStatus.PENDING,
)
session.add(checkpoint_db)
await session.flush()


async def evict_old_sampler_checkpoints(
request: Request,
session: AsyncSession,
model_id: str,
max_count: int = MAX_SAMPLER_CHECKPOINTS_PER_MODEL,
):
"""Delete oldest sampler checkpoints if count exceeds max_count.

Called before creating a new sampler checkpoint to make room.
Deletes the database entry, the checkpoint archive, and the extracted lora directory (if it exists).

Args:
request: FastAPI request (for accessing engine config)
session: Database session
model_id: The model whose checkpoints to manage
max_count: Maximum number of sampler checkpoints to keep (default: 3)
"""
import shutil

engine_config = request.app.state.engine_config

# Get all sampler checkpoints for this model, ordered by creation time (oldest first)
statement = (
select(CheckpointDB)
.where(CheckpointDB.model_id == model_id)
.where(CheckpointDB.checkpoint_type == types.CheckpointType.SAMPLER)
.order_by(CheckpointDB.created_at.asc())
)
result = await session.exec(statement)
checkpoints = result.all()

# If we have max_count or more, delete the oldest ones to make room for the new one
if len(checkpoints) >= max_count:
# Delete oldest checkpoints, keeping only (max_count - 1) to make room for new one
to_delete = checkpoints[: len(checkpoints) - max_count + 1]
for checkpoint in to_delete:
checkpoint_id = checkpoint.checkpoint_id

# Delete checkpoint archive from disk
checkpoint_path = (
engine_config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz"
)
try:
if checkpoint_path.exists():
checkpoint_path.unlink()
logger.info(f"Deleted sampler checkpoint file: {checkpoint_path}")
except Exception as e:
logger.warning(f"Failed to delete checkpoint file {checkpoint_path}: {e}")

# Delete extracted lora directory (used by external inference / vLLM)
if engine_config.external_inference_lora_base:
lora_dir = engine_config.external_inference_lora_base / f"{model_id}_{checkpoint_id}"
try:
if lora_dir.exists():
shutil.rmtree(lora_dir)
logger.info(f"Deleted extracted lora directory: {lora_dir}")
except Exception as e:
logger.warning(f"Failed to delete lora directory {lora_dir}: {e}")

# Delete from database
await session.delete(checkpoint)
logger.info(f"Evicted sampler checkpoint: {model_id}/{checkpoint_id}")

await session.flush()


class LoRAConfig(BaseModel):
@@ -683,8 +773,13 @@ async def save_weights(request: SaveWeightsRequest, session: AsyncSession = Depends(get_session)):


@app.post("/api/v1/save_weights_for_sampler", response_model=FutureResponse)
async def save_weights_for_sampler(request: SaveWeightsForSamplerRequest, session: AsyncSession = Depends(get_session)):
async def save_weights_for_sampler(
request: SaveWeightsForSamplerRequest, req: Request, session: AsyncSession = Depends(get_session)
):
"""Saves weights in a format compatible with sampling/inference servers."""
# Evict old sampler checkpoints to keep only the last K
await evict_old_sampler_checkpoints(req, session, request.model_id)

# Create pending checkpoint entry (validates model exists)
await create_checkpoint(
session=session,
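
A quick sanity check on the eviction arithmetic in evict_old_sampler_checkpoints: with the default max_count = 3, the slice checkpoints[: len(checkpoints) - max_count + 1] removes just enough of the oldest rows to leave max_count - 1 behind, freeing exactly one slot for the checkpoint about to be created. A standalone sketch, with plain strings standing in for CheckpointDB rows:

# Eviction slice from evict_old_sampler_checkpoints, isolated for illustration.
# Strings stand in for CheckpointDB rows, ordered oldest -> newest.
checkpoints = ["ckpt-001", "ckpt-002", "ckpt-003", "ckpt-004"]
max_count = 3

if len(checkpoints) >= max_count:
    # Keep only (max_count - 1) rows so the incoming checkpoint fits under the cap.
    to_delete = checkpoints[: len(checkpoints) - max_count + 1]
    print(to_delete)  # ['ckpt-001', 'ckpt-002'] -- the two oldest are evicted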
7 changes: 7 additions & 0 deletions skyrl-tx/tx/tinker/backends/__init__.py
@@ -0,0 +1,7 @@
"""Tinker engine backends."""

from tx.tinker.backends.backend import AbstractBackend
from tx.tinker.backends.native import NativeBackend
from tx.tinker.backends.maxtext import MaxTextBackend, parse_maxtext_config

__all__ = ["AbstractBackend", "NativeBackend", "MaxTextBackend", "parse_maxtext_config"]
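
The new package turns the backend surface into a single import point. A minimal smoke test of the re-exports is sketched below; the subclass assertions are an assumption based on the commit "Fix MaxText backend method signatures to match AbstractBackend", not something this hunk shows.

# Only the import surface below is taken from this PR; importing is itself
# the smoke test, and the issubclass checks are assumed relationships.
from tx.tinker.backends import (
    AbstractBackend,
    MaxTextBackend,
    NativeBackend,
    parse_maxtext_config,
)

assert issubclass(NativeBackend, AbstractBackend)
assert issubclass(MaxTextBackend, AbstractBackend)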