diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index f6f136a46..23e9209e8 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -7,7 +7,7 @@ name = "skyrl-tx" dynamic = ["version"] description = "Unified API for training and inference" readme = "README.md" -requires-python = ">=3.11" +requires-python = "==3.12.*" dependencies = [ "datasets>=4.0.0", "flax>=0.12.0", @@ -30,6 +30,10 @@ gpu = [ "jax[cuda12]>=0.7.2", ] +maxtext = [ + "maxtext @ git+https://github.com/OhadRubin/maxtext@1edde4d1b1d562173d1753650b0234aa5c6a2fea", +] + tpu = [ "jax[tpu]>=0.7.2", ] @@ -42,6 +46,7 @@ tinker = [ "aiosqlite", "asyncpg", "psycopg2-binary", + "tenacity", ] aws = [ diff --git a/skyrl-tx/tx/tinker/api.py b/skyrl-tx/tx/tinker/api.py index 3f4c28f24..a8de2baed 100644 --- a/skyrl-tx/tx/tinker/api.py +++ b/skyrl-tx/tx/tinker/api.py @@ -35,6 +35,9 @@ ID_PATTERN = r"^[a-zA-Z0-9_-]+$" ID_MAX_LENGTH = 255 +# Maximum number of sampler checkpoints to keep per model (oldest are evicted) +MAX_SAMPLER_CHECKPOINTS_PER_MODEL = 3 + @asynccontextmanager async def lifespan(app: FastAPI): @@ -55,7 +58,10 @@ async def lifespan(app: FastAPI): logger.info("Using internal engine for inference") # Build subprocess command with engine config parameters - cmd = ["uv", "run", "--extra", "tinker", "-m", "tx.tinker.engine"] + cmd = ["uv", "run", "--no-sync", "--extra", "tinker"] + if app.state.engine_config.maxtext_config_str: + cmd += ["--extra", "maxtext"] + cmd += ["-m", "tx.tinker.engine"] cmd.extend(config_to_argv(app.state.engine_config)) background_engine = subprocess.Popen(cmd) @@ -130,16 +136,100 @@ async def create_checkpoint( try: await session.flush() except IntegrityError: - # Determine which constraint failed by checking if the model exists + await session.rollback() + # Check if the model exists statement = select(ModelDB).where(ModelDB.model_id == model_id) result = await session.exec(statement) if not result.first(): raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found") - else: - raise HTTPException( - status_code=409, detail=f"Checkpoint '{checkpoint_id}' already exists for model '{model_id}'" + + # Delete existing checkpoint and create new one + delete_stmt = select(CheckpointDB).where( + CheckpointDB.model_id == model_id, + CheckpointDB.checkpoint_id == checkpoint_id, + CheckpointDB.checkpoint_type == checkpoint_type, + ) + existing = (await session.exec(delete_stmt)).first() + if existing: + await session.delete(existing) + await session.flush() + + # Re-add the new checkpoint + checkpoint_db = CheckpointDB( + model_id=model_id, + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + status=CheckpointStatus.PENDING, + ) + session.add(checkpoint_db) + await session.flush() + + +async def evict_old_sampler_checkpoints( + request: Request, + session: AsyncSession, + model_id: str, + max_count: int = MAX_SAMPLER_CHECKPOINTS_PER_MODEL, +): + """Delete oldest sampler checkpoints if count exceeds max_count. + + Called before creating a new sampler checkpoint to make room. + Deletes the database entry, the checkpoint archive, and the extracted lora directory (if exists). 
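As the docstring above describes, the helper keeps room for exactly one incoming checkpoint by deleting the oldest entries. A standalone sketch of that selection arithmetic, with plain strings standing in for CheckpointDB rows (illustrative values only):

```python
# Illustrative only: which sampler checkpoints get evicted when max_count is 3.
checkpoints = ["ckpt-001", "ckpt-002", "ckpt-003"]  # ordered oldest -> newest
max_count = 3

if len(checkpoints) >= max_count:
    # Keep max_count - 1 of the newest entries so the incoming checkpoint fits under the limit.
    to_delete = checkpoints[: len(checkpoints) - max_count + 1]
else:
    to_delete = []

assert to_delete == ["ckpt-001"]
```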
+ + Args: + request: FastAPI request (for accessing engine config) + session: Database session + model_id: The model whose checkpoints to manage + max_count: Maximum number of sampler checkpoints to keep (default: 3) + """ + import shutil + + engine_config = request.app.state.engine_config + + # Get all sampler checkpoints for this model, ordered by creation time (oldest first) + statement = ( + select(CheckpointDB) + .where(CheckpointDB.model_id == model_id) + .where(CheckpointDB.checkpoint_type == types.CheckpointType.SAMPLER) + .order_by(CheckpointDB.created_at.asc()) + ) + result = await session.exec(statement) + checkpoints = result.all() + + # If we have max_count or more, delete the oldest ones to make room for the new one + if len(checkpoints) >= max_count: + # Delete oldest checkpoints, keeping only (max_count - 1) to make room for new one + to_delete = checkpoints[: len(checkpoints) - max_count + 1] + for checkpoint in to_delete: + checkpoint_id = checkpoint.checkpoint_id + + # Delete checkpoint archive from disk + checkpoint_path = ( + engine_config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz" ) + try: + if checkpoint_path.exists(): + checkpoint_path.unlink() + logger.info(f"Deleted sampler checkpoint file: {checkpoint_path}") + except Exception as e: + logger.warning(f"Failed to delete checkpoint file {checkpoint_path}: {e}") + + # Delete extracted lora directory (used by external inference / vLLM) + if engine_config.external_inference_lora_base: + lora_dir = engine_config.external_inference_lora_base / f"{model_id}_{checkpoint_id}" + try: + if lora_dir.exists(): + shutil.rmtree(lora_dir) + logger.info(f"Deleted extracted lora directory: {lora_dir}") + except Exception as e: + logger.warning(f"Failed to delete lora directory {lora_dir}: {e}") + + # Delete from database + await session.delete(checkpoint) + logger.info(f"Evicted sampler checkpoint: {model_id}/{checkpoint_id}") + + await session.flush() class LoRAConfig(BaseModel): @@ -683,8 +773,13 @@ async def save_weights(request: SaveWeightsRequest, session: AsyncSession = Depe @app.post("/api/v1/save_weights_for_sampler", response_model=FutureResponse) -async def save_weights_for_sampler(request: SaveWeightsForSamplerRequest, session: AsyncSession = Depends(get_session)): +async def save_weights_for_sampler( + request: SaveWeightsForSamplerRequest, req: Request, session: AsyncSession = Depends(get_session) +): """Saves weights in a format compatible with sampling/inference servers.""" + # Evict old sampler checkpoints to keep only the last K + await evict_old_sampler_checkpoints(req, session, request.model_id) + # Create pending checkpoint entry (validates model exists) await create_checkpoint( session=session, diff --git a/skyrl-tx/tx/tinker/backends/__init__.py b/skyrl-tx/tx/tinker/backends/__init__.py new file mode 100644 index 000000000..c1dc5a676 --- /dev/null +++ b/skyrl-tx/tx/tinker/backends/__init__.py @@ -0,0 +1,7 @@ +"""Tinker engine backends.""" + +from tx.tinker.backends.backend import AbstractBackend +from tx.tinker.backends.native import NativeBackend +from tx.tinker.backends.maxtext import MaxTextBackend, parse_maxtext_config + +__all__ = ["AbstractBackend", "NativeBackend", "MaxTextBackend", "parse_maxtext_config"] diff --git a/skyrl-tx/tx/tinker/backends/backend.py b/skyrl-tx/tx/tinker/backends/backend.py new file mode 100644 index 000000000..a2f0cf32e --- /dev/null +++ b/skyrl-tx/tx/tinker/backends/backend.py @@ -0,0 +1,247 @@ +"""Abstract backend interface for TinkerEngine. 
+ +Backends handle all model state and computation. The engine handles file I/O and database operations. + +Design: + 1. AbstractBackend (backend.py) + Clean interface defining what backends must implement: + - register_model, unregister_model (optimizer lifecycle managed internally) + - process_forward_backward_batch, process_forward_batch, process_optim_step, process_sample_batch + - extract_checkpoint_data, insert_checkpoint_data (pure state manipulation) + - extract_sampler_weights, insert_sampler_weights + - save_checkpoint, save_sampler_checkpoint (file I/O in backend) + + 2. NativeBackend (native.py) + - Implements all abstract methods fully for Qwen3 + LoRA + - Uses jax.value_and_grad for gradient computation + - Uses 2D mesh (dp, tp) + - Multi-adapter AccumulatedGradients with counts array + + 3. MaxTextBackend (maxtext.py) + - Single-adapter backend (asserts max_lora_adapters=1) + - Uses MaxText model with LoRA + + 4. TinkerEngine (engine.py) + - Instantiates backend based on config + - Delegates computation to self.backend + - Handles all database operations + - Manages model metadata (backends manage optimizers internally) +""" + +from abc import ABC, abstractmethod + +import jax +from flax import nnx + +from tx.tinker import types +from tx.tinker.config import EngineConfig + + +class AbstractBackend(ABC): + """Abstract base class for TinkerEngine backends. + + Backends handle computation and model state manipulation. + Database operations are handled by TinkerEngine. + """ + + config: EngineConfig + mesh: jax.sharding.Mesh + model: nnx.Module + metrics: types.EngineMetrics + graphdef: nnx.GraphDef + lora_params: nnx.State + non_lora_params: nnx.State + + @abstractmethod + def __init__(self, config: EngineConfig, **kwargs): + """Initialize the backend.""" + pass + + @abstractmethod + def register_model(self, model_id: str, adapter_index: int, lora_config: types.LoraConfig) -> None: + """Register a new model with the backend. + + Creates optimizer and configures LoRA adapter internally. + + Args: + model_id: The model identifier + adapter_index: The adapter slot index to use + lora_config: LoRA configuration with rank and alpha + """ + pass + + @abstractmethod + def unregister_model(self, model_id: str, adapter_index: int) -> None: + """Unregister a model from the backend. + + Removes optimizer and resets adapter weights. + + Args: + model_id: The model identifier + adapter_index: The adapter slot index to reset + """ + pass + + @abstractmethod + def process_forward_backward_batch( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Process forward_backward requests in a batch. + + Args: + prepared_batch: PreparedModelPassBatch with all data extracted from requests + + Returns: + Dict mapping request_id to result or error + """ + pass + + @abstractmethod + def process_forward_batch( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Process forward-only requests in a batch (no gradient computation). + + Args: + prepared_batch: PreparedModelPassBatch with all data extracted from requests + + Returns: + Dict mapping request_id to result or error + """ + pass + + @abstractmethod + def process_optim_step( + self, + model_id: str, + adapter_index: int, + request_data: types.OptimStepInput, + ) -> types.OptimStepOutput: + """Process an optimizer step request. 
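The layering described at the top of this module leaves backend construction to TinkerEngine. A hypothetical wiring sketch (the `make_backend` helper and the exact selection rule are assumptions for illustration, not the engine's actual code):

```python
# Hypothetical sketch: picking a backend from an EngineConfig.
from tx.tinker.backends import MaxTextBackend, NativeBackend, parse_maxtext_config


def make_backend(config):
    # A non-empty maxtext_config_str selects the single-adapter MaxText backend;
    # otherwise the native Qwen3 + LoRA backend is used.
    maxtext_config = parse_maxtext_config(config.maxtext_config_str or "")
    if maxtext_config is not None:
        return MaxTextBackend(config, maxtext_config)
    return NativeBackend(config)
```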
+ + Args: + model_id: The model identifier + adapter_index: The adapter index for this model + request_data: The optimizer step input parameters + + Returns: + OptimStepOutput result + """ + pass + + @abstractmethod + def process_sample_batch( + self, + prepared_batch: types.PreparedSampleBatch, + ) -> dict[str, types.SampleOutput | types.ErrorResponse]: + """Process multiple sample requests in a single batch. + + Args: + prepared_batch: PreparedSampleBatch with all data extracted from requests + + Returns: + Dict mapping request_id to result or error + """ + pass + + @abstractmethod + def save_checkpoint( + self, + output_path, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> None: + """Save training checkpoint to disk. + + Args: + output_path: Path to save the checkpoint + model_id: The model identifier + models: Dict mapping model_id to ModelMetadata + """ + pass + + @abstractmethod + def extract_checkpoint_data( + self, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> dict: + """Extract model state for checkpointing. + + Args: + model_id: The model identifier + models: Dict mapping model_id to ModelMetadata + + Returns: + Dictionary containing checkpoint data (weights, optimizer state, config). + """ + pass + + @abstractmethod + def insert_checkpoint_data( + self, + model_id: str, + checkpoint_data: dict, + models: dict[str, types.ModelMetadata], + ) -> None: + """Insert checkpoint data into model state. + + Args: + model_id: The model identifier + checkpoint_data: Dictionary from extract_checkpoint_data or loaded from disk + models: Dict mapping model_id to ModelMetadata + """ + pass + + @abstractmethod + def save_sampler_checkpoint( + self, + output_path, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> None: + """Save sampler checkpoint to disk as tar.gz. + + Args: + output_path: Path to save the checkpoint tar.gz file + model_id: The model identifier + models: Dict mapping model_id to ModelMetadata + """ + pass + + @abstractmethod + def extract_sampler_weights( + self, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> dict: + """Extract weights for sampler checkpoint. + + Args: + model_id: The model identifier + models: Dict mapping model_id to ModelMetadata + + Returns: + Dictionary containing sampler weights data. + """ + pass + + @abstractmethod + def insert_sampler_weights( + self, + model_id: str, + checkpoint_id: str, + checkpoint_path, + models: dict[str, types.ModelMetadata], + ) -> None: + """Insert sampler weights into model state from checkpoint file. 
+ + Args: + model_id: The model identifier + checkpoint_id: The checkpoint identifier + checkpoint_path: Path to the checkpoint file + models: Dict mapping model_id to ModelMetadata + """ + pass diff --git a/skyrl-tx/tx/tinker/backends/maxtext.py b/skyrl-tx/tx/tinker/backends/maxtext.py new file mode 100644 index 000000000..ca531fa54 --- /dev/null +++ b/skyrl-tx/tx/tinker/backends/maxtext.py @@ -0,0 +1,600 @@ +"""MaxText backend for TinkerEngine.""" + +import time +from contextlib import contextmanager +from dataclasses import dataclass + +import numpy as np +import jax +import jax.numpy as jnp +import optax +from flax import nnx +from flax.linen import partitioning as nn_partitioning + +from tx.tinker import types +from tx.tinker.config import EngineConfig +from tx.tinker.backends.backend import AbstractBackend +from tx.tinker.backends.utils import pad_batch +from tx.tinker.loss_fns import LOSS_FUNCTIONS +from tx.utils.models import round_up_seq_len, convert_maxtext_lora_to_hf +from tx.utils.storage import pack_and_upload +from tx.utils.log import logger + +# MaxText imports +import MaxText +from MaxText import maxtext_utils +from MaxText import model_creation_utils as maxtext_model_creation +from MaxText import sharding as maxtext_sharding +from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter + + + +def reset_adapter_weights(model): + + state = nnx.state(model) + + def update_lora_config(path, value): + normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path) + j_path = "/".join(normalized_path) + if "lora_b" in j_path: + return value.at[...].set(0.0) + return value + + updated_state = jax.tree.map_with_path(update_lora_config, state) + nnx.update(model, updated_state) + +def _get_maxtext_base_config_path() -> str: + """Get the absolute path to MaxText's base.yml config file.""" + import os + from importlib.resources import files + try: + # Try importlib.resources first (works if MaxText packages configs properly) + config_path = str(files("MaxText").joinpath("configs", "base.yml")) + if os.path.exists(config_path): + return config_path + except (TypeError, FileNotFoundError): + pass + # Fallback: derive from package location + maxtext_pkg_dir = os.path.dirname(MaxText.__file__) + maxtext_root = os.path.dirname(os.path.dirname(maxtext_pkg_dir)) + config_path = os.path.join(maxtext_root, "src", "MaxText", "configs", "base.yml") + if not os.path.exists(config_path): + raise FileNotFoundError( + f"Could not find MaxText base.yml config. Tried: {config_path}. " + "Ensure MaxText is installed correctly or set MAXTEXT_CONFIG_PATH environment variable." 
+ ) + return config_path + + +def parse_maxtext_config(config_str: str): + """Parse MaxText config from space-separated key=value string.""" + if not config_str: + return None + config_path = _get_maxtext_base_config_path() + logger.info(f"Using MaxText config: {config_path}") + argv = ["", config_path] + config_str.split() + from MaxText import pyconfig as maxtext_pyconfig + return maxtext_pyconfig.initialize(argv) + + + +def _count_params(pytree) -> int: + """Count total number of parameters in a pytree.""" + def get_numel(x): + if hasattr(x, 'shape'): + return int(np.prod(x.shape)) + return 0 + counts = jax.tree.leaves(jax.tree.map(get_numel, pytree)) + return sum(counts) + + +@jax.tree_util.register_dataclass +@dataclass +class AccumulatedGradients: + """Stores accumulated gradients.""" + + grad_sum: nnx.State + count: jax.Array + + @classmethod + def create(cls, lora_params: nnx.State) -> "AccumulatedGradients": + """Initialize with zeros.""" + return cls( + grad_sum=jax.tree.map(jnp.zeros_like, lora_params), + count=jnp.zeros((1,), dtype=jnp.int32), + ) + + def add(self, lora_grads: nnx.State, batch_size: int) -> "AccumulatedGradients": + """Accumulate gradients and increment count.""" + return AccumulatedGradients( + grad_sum=jax.tree.map(lambda a, b: a + b, self.grad_sum, lora_grads), + count=self.count + batch_size, + ) + + def get_mean(self) -> nnx.State: + """Compute mean gradients.""" + return jax.tree.map( + lambda g: g / self.count.astype(g.dtype), + self.grad_sum, + ) + + def reset(self) -> "AccumulatedGradients": + """Reset gradients and count.""" + return AccumulatedGradients( + grad_sum=jax.tree.map(jnp.zeros_like, self.grad_sum), + count=jnp.zeros((1,), dtype=jnp.int32), + ) + + +class MaxTextBackend(AbstractBackend): + """Backend for MaxText models with context parallelism. + + This is a single-adapter backend (max_lora_adapters must be 1). + """ + + def __init__(self, config: EngineConfig, maxtext_config): + """Initialize MaxText backend.""" + if config.max_lora_adapters != 1: + raise ValueError( + f"MaxTextBackend only supports single adapter (max_lora_adapters=1), " + f"got max_lora_adapters={config.max_lora_adapters}" + ) + + self.config = config + self.maxtext_config = maxtext_config + self.metrics = types.EngineMetrics() + + # Create mesh using MaxText's device mesh creation + devices_array = maxtext_utils.create_device_mesh(maxtext_config) + self.mesh = jax.sharding.Mesh(devices_array, maxtext_config.mesh_axes) + logger.info(f"Created MaxText mesh with shape {self.mesh.shape}, axes {self.mesh.axis_names}") + + # Create model using MaxText's model creation + with jax.set_mesh(self.mesh): + base_model, _ = maxtext_model_creation.create_nnx_model(maxtext_config, mesh=self.mesh) + self.model = TunixMaxTextAdapter(base_model=base_model) + self.model.config = None + + # Extract LoRA params for gradient accumulation + lora_filter = nnx.All(nnx.Param, nnx.Any(nnx.PathContains("lora_a"), nnx.PathContains("lora_b"))) + self.lora_filter = lora_filter + self.graphdef, self.lora_params, self.non_lora_params = nnx.split(self.model, lora_filter, ...) 
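The single-adapter `AccumulatedGradients` above is just a running sum plus an example count. A toy version of its accumulate/mean cycle, with an ordinary dict of arrays standing in for the LoRA `nnx.State`:

```python
# Toy accumulate/mean cycle; a dict of arrays stands in for the LoRA nnx.State.
import jax
import jax.numpy as jnp

grad_sum = {"lora_a": jnp.zeros((2, 2)), "lora_b": jnp.zeros((2, 2))}
count = jnp.zeros((1,), dtype=jnp.int32)

# Two micro-batches of sizes 4 and 2 arrive, each with (toy) summed gradients of all ones.
for batch_size in (4, 2):
    grads = {"lora_a": jnp.ones((2, 2)), "lora_b": jnp.ones((2, 2))}
    grad_sum = jax.tree.map(lambda a, b: a + b, grad_sum, grads)
    count = count + batch_size

mean = jax.tree.map(lambda g: g / count.astype(g.dtype), grad_sum)
# Every entry of mean is 2 / 6, i.e. the per-example average over both micro-batches.
```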
+ + # Initialize accumulated gradients + self.accumulated_grads = AccumulatedGradients.create(self.lora_params) + self._log_accumulated_grads() + + # Per-model optimizer storage (managed internally) + self.optimizers: dict[str, nnx.Optimizer] = {} + + logger.info(f"Initialized MaxText model with context_parallel_size={maxtext_config.context_parallel_size}") + + self._create_loss_and_grad_fn() + + def _log_accumulated_grads(self): + """Log accumulated gradient structure.""" + accum_params = _count_params(self.accumulated_grads.grad_sum) + logger.info(f"[MaxText] Accumulated grads total params: {accum_params / 1e6:.2f}M") + for path, val in jax.tree_util.tree_leaves_with_path(self.accumulated_grads.grad_sum): + path_str = "/".join(str(k.key) if hasattr(k, 'key') else str(k) for k in path) + logger.info(f" {path_str}: {val.shape}") + + def _create_loss_and_grad_fn(self): + """Create loss and gradient functions for MaxText model.""" + + def loss_for_maxtext_model( + model, + input_ids: jax.Array, + positions: jax.Array, + target_ids: jax.Array, + loss_mask: jax.Array, + loss_fn_types: jax.Array, + sampling_logprobs: jax.Array, + advantages: jax.Array, + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Loss for MaxText model with dynamic loss function selection.""" + logits, _ = model(input_ids, positions, None, None, False) + logprobs = jax.nn.log_softmax(logits, axis=-1) + target_logprobs = jnp.take_along_axis(logprobs, target_ids[..., None], axis=-1).squeeze(-1) + + def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): + return jax.lax.switch( + loss_fn_type, + LOSS_FUNCTIONS, + target_logprobs, + loss_mask, + sampling_logprobs, + advantages, + ) + + per_token_losses = jax.vmap(compute_loss_per_example)( + loss_fn_types, + target_logprobs, + loss_mask, + sampling_logprobs, + advantages, + ) + + per_seq_loss = per_token_losses.sum(axis=-1) / jnp.maximum(loss_mask.sum(axis=-1), 1e-9) + total_loss = per_seq_loss.sum() + return total_loss, (target_logprobs, per_token_losses) + + loss_and_grad_fn = nnx.value_and_grad( + loss_for_maxtext_model, + argnums=nnx.DiffState(0, self.lora_filter), + has_aux=True + ) + + def forward_backward_maxtext( + model, input_ids, positions, target_ids, loss_mask, loss_fn_types, sampling_logprobs, advantages, + ) -> tuple[jax.Array, jax.Array, jax.Array, nnx.State]: + """Forward-backward for MaxText model.""" + (loss, (target_logprobs, per_token_losses)), grads = loss_and_grad_fn( + model, input_ids, positions, target_ids, loss_mask, loss_fn_types, sampling_logprobs, advantages, + ) + return loss, target_logprobs, per_token_losses, grads + + data_sharding = maxtext_sharding.get_input_data_sharding(self.maxtext_config, self.mesh) + + if self.config.enforce_eager: + self._forward_backward = forward_backward_maxtext + else: + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(self.maxtext_config.logical_axis_rules): + self._forward_backward = jax.jit( + forward_backward_maxtext, + # model, input_ids, positions, target_ids, loss_mask, loss_fn_types (1D), sampling_logprobs, advantages + in_shardings=(None, data_sharding, data_sharding, data_sharding, data_sharding, None, data_sharding, data_sharding), + ) + + def optim_step(model, optimizer, grads): + """Apply gradients to optimizer.""" + optimizer.update(model, grads) + + if self.config.enforce_eager: + self._optim_step = optim_step + else: + self._optim_step = nnx.jit(optim_step) + + logger.info("Created MaxText loss and gradient functions") + + def 
_micro_batch_size(self, total: int) -> int: + """Return effective micro-batch size.""" + mb = self.config.train_micro_batch_size + return total if mb <= 0 else max(1, min(mb, total)) + + @contextmanager + def _jit_timing_context(self, seq_len: int, mode: str): + """Context manager to track JIT compilation times.""" + jit_times = self.metrics.train_seq_len_jit_times if mode == "train" else self.metrics.sample_seq_len_jit_times + if not self.config.enforce_eager and seq_len not in jit_times: + logger.info(f"JIT compiling for {mode} seq_len={seq_len} in progress...") + start_time = time.time() + yield + elapsed = time.time() - start_time + jit_times[seq_len] = elapsed + logger.info(f"JIT compilation for {mode} seq_len={seq_len} took {elapsed:.2f}s") + else: + yield + + def register_model(self, model_id: str, adapter_index: int, lora_config: types.LoraConfig) -> None: + """Register a new model with the backend. + + Creates optimizer for the model. MaxText is single-adapter so adapter_index is ignored. + """ + tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0) + self.optimizers[model_id] = nnx.Optimizer(self.model, tx, wrt=self.lora_filter) + logger.info(f"Registered model {model_id} with MaxText backend") + + def unregister_model(self, model_id: str, adapter_index: int) -> None: + """Unregister a model from the backend. + + Removes optimizer and resets LoRA weights (only zeros lora_b, preserves lora_a). + """ + self.optimizers.pop(model_id, None) + + # Reset LoRA weights (only zero lora_b, preserve lora_a random init for gradients to flow) + reset_adapter_weights(self.model) + # Re-split to update lora_params reference + self.graphdef, self.lora_params, self.non_lora_params = nnx.split(self.model, self.lora_filter, ...) + logger.info(f"Unregistered model {model_id} from MaxText backend") + + def precompile_kernels(self, seq_lens: list[int]): + """Precompile JIT kernels for specified sequence lengths.""" + if not seq_lens or self.config.enforce_eager: + return + + logger.info(f"Precompiling JIT kernels for sequence lengths: {seq_lens}") + micro_bs = max(1, self.config.train_micro_batch_size) if self.config.train_micro_batch_size > 0 else 1 + + with jax.set_mesh(self.mesh): + for seq_len in seq_lens: + dummy_input_ids = jnp.zeros((micro_bs, seq_len), dtype=jnp.int32) + dummy_target_ids = jnp.zeros((micro_bs, seq_len), dtype=jnp.int32) + dummy_loss_mask = jnp.ones((micro_bs, seq_len), dtype=jnp.float32) + dummy_positions = jnp.broadcast_to(jnp.arange(seq_len), (micro_bs, seq_len)) + dummy_loss_fn_types = jnp.zeros((micro_bs,), dtype=jnp.int32) + dummy_sampling_logprobs = jnp.zeros((micro_bs, seq_len), dtype=jnp.float32) + dummy_advantages = jnp.zeros((micro_bs, seq_len), dtype=jnp.float32) + + data_sharding = maxtext_sharding.get_input_data_sharding(self.maxtext_config, self.mesh) + dummy_input_ids = jax.device_put(dummy_input_ids, data_sharding) + dummy_positions = jax.device_put(dummy_positions, data_sharding) + dummy_target_ids = jax.device_put(dummy_target_ids, data_sharding) + dummy_loss_mask = jax.device_put(dummy_loss_mask, data_sharding) + # dummy_loss_fn_types is 1D, no sharding needed + dummy_sampling_logprobs = jax.device_put(dummy_sampling_logprobs, data_sharding) + dummy_advantages = jax.device_put(dummy_advantages, data_sharding) + + with nn_partitioning.axis_rules(self.maxtext_config.logical_axis_rules): + with self._jit_timing_context(seq_len, mode="train"): + _, _, _, grads = self._forward_backward( + self.model, dummy_input_ids, dummy_positions, dummy_target_ids, 
dummy_loss_mask, + dummy_loss_fn_types, dummy_sampling_logprobs, dummy_advantages, + ) + self.accumulated_grads = self.accumulated_grads.add(grads, micro_bs) + + self.accumulated_grads = AccumulatedGradients.create(self.lora_params) + + logger.info(f"Precompilation complete for {len(seq_lens)} sequence lengths") + + def process_forward_backward_batch( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Process forward_backward requests using MaxText model.""" + all_input_ids = prepared_batch.all_input_ids + all_targets = prepared_batch.all_targets + all_token_weights = prepared_batch.all_token_weights + all_sampling_logprobs = prepared_batch.all_sampling_logprobs + all_advantages = prepared_batch.all_advantages + all_loss_fn_types = prepared_batch.all_loss_fn_types + request_batch_slices = prepared_batch.request_batch_slices + + if not all_input_ids: + return {} + + results = {} + max_len = round_up_seq_len(max(len(seq) for seq in all_input_ids), self.config.min_seq_len) + input_ids = pad_batch(all_input_ids, max_len, np.int32) + target_ids = pad_batch(all_targets, max_len, np.int32) + loss_mask = pad_batch(all_token_weights, max_len, np.float32) + sampling_logprobs = pad_batch(all_sampling_logprobs, max_len, np.float32) + advantages = pad_batch(all_advantages, max_len, np.float32) + loss_fn_types = jnp.array(all_loss_fn_types, dtype=jnp.int32) + + batch_size = input_ids.shape[0] + seq_len = input_ids.shape[1] + positions = jnp.broadcast_to(jnp.arange(seq_len), (batch_size, seq_len)) + seq_lens = [len(seq) for seq in all_input_ids] + + data_sharding = maxtext_sharding.get_input_data_sharding(self.maxtext_config, self.mesh) + + token_losses_device = [] + logprobs_device = [] + total_bs = batch_size + micro_bs = self._micro_batch_size(total_bs) + + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(self.maxtext_config.logical_axis_rules): + with self._jit_timing_context(seq_len, mode="train"): + for mb_start in range(0, total_bs, micro_bs): + mb_end = min(mb_start + micro_bs, total_bs) + print(f"MaxText forward-backward: batch [{mb_start}:{mb_end}], seq_len={seq_len}", flush=True) + tic = time.time() + + mb_input_ids = jax.device_put(input_ids[mb_start:mb_end], data_sharding) + mb_positions = jax.device_put(positions[mb_start:mb_end], data_sharding) + mb_target_ids = jax.device_put(target_ids[mb_start:mb_end], data_sharding) + mb_loss_mask = jax.device_put(loss_mask[mb_start:mb_end], data_sharding) + mb_loss_fn_types = loss_fn_types[mb_start:mb_end] # 1D, no sharding needed + mb_sampling_logprobs = jax.device_put(sampling_logprobs[mb_start:mb_end], data_sharding) + mb_advantages = jax.device_put(advantages[mb_start:mb_end], data_sharding) + + _, target_logprobs, per_token_losses, grads = self._forward_backward( + self.model, + mb_input_ids, + mb_positions, + mb_target_ids, + mb_loss_mask, + mb_loss_fn_types, + mb_sampling_logprobs, + mb_advantages, + ) + + _ = jax.device_get(target_logprobs) + + took = time.time() - tic + tokens_processed = (mb_end - mb_start) * seq_len + tokens_per_sec = tokens_processed / took if took > 0 else float('nan') + print(f"Batch [{mb_start}:{mb_end}] forward-backward time: {took:.3f} sec, tokens/sec: {tokens_per_sec:,.1f}", flush=True) + + micro_batch_size = mb_end - mb_start + self.accumulated_grads = self.accumulated_grads.add(grads, micro_batch_size) + token_losses_device.append(per_token_losses) + logprobs_device.append(target_logprobs) + + token_losses_host, logprobs_host = 
jax.device_get((token_losses_device, logprobs_device)) + + token_losses_out = [] + logprobs_out = [] + idx = 0 + for mb_losses, mb_logprobs in zip(token_losses_host, logprobs_host): + for i in range(mb_losses.shape[0]): + token_losses_out.append(mb_losses[i, :seq_lens[idx]].astype(jnp.float32)) + logprobs_out.append(mb_logprobs[i, :seq_lens[idx]].astype(jnp.float32)) + idx += 1 + + for request_id, _, start_idx, end_idx in request_batch_slices: + loss_fn_outputs = [] + for i in range(start_idx, end_idx): + token_losses = token_losses_out[i] + token_logprobs = logprobs_out[i] + loss_fn_outputs.append({ + "elementwise_loss": { + "data": token_losses.tolist(), + "dtype": "float32", + "shape": [token_losses.shape[0]], + }, + "logprobs": { + "data": token_logprobs.tolist(), + "dtype": "float32", + "shape": [token_logprobs.shape[0]], + }, + }) + + results[request_id] = types.ForwardBackwardOutput( + loss_fn_output_type="scalar", + loss_fn_outputs=loss_fn_outputs, + metrics={}, + ) + + return results + + def process_optim_step( + self, + model_id: str, + adapter_index: int, + request_data: types.OptimStepInput, + ) -> types.OptimStepOutput: + """Process an optim_step request.""" + if self.accumulated_grads.count[0] == 0: + logger.warning(f"No accumulated gradients for model {model_id}, skipping optimizer step") + return types.OptimStepOutput() + + optimizer = self.optimizers[model_id] + hp = optimizer.opt_state.hyperparams + hp["learning_rate"][...] = request_data.adam_params.learning_rate + hp["b1"][...] = request_data.adam_params.beta1 + hp["b2"][...] = request_data.adam_params.beta2 + hp["eps"][...] = request_data.adam_params.eps + + mean_grads = self.accumulated_grads.get_mean() + + with jax.set_mesh(self.mesh), nn_partitioning.axis_rules(self.maxtext_config.logical_axis_rules): + self._optim_step(self.model, optimizer, mean_grads) + + self.accumulated_grads = self.accumulated_grads.reset() + logger.info(f"Applied MaxText optimizer step for model {model_id}") + + return types.OptimStepOutput() + + def process_forward_batch( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Process forward-only requests - not implemented for MaxText.""" + raise NotImplementedError("Forward-only pass not yet implemented for MaxText backend") + + def process_sample_batch( + self, + prepared_batch: types.PreparedSampleBatch, + ) -> dict[str, types.SampleOutput | types.ErrorResponse]: + """Process sample requests - not implemented for MaxText.""" + raise NotImplementedError("Sampling not yet implemented for MaxText backend") + + def save_checkpoint( + self, + output_path, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> None: + """Save training checkpoint in HuggingFace PEFT format as tar.gz.""" + with pack_and_upload(output_path) as temp_dir: + convert_maxtext_lora_to_hf( + lora_state=self.lora_params, + output_path=temp_dir, + base_model_name=self.config.base_model, + lora_rank=self.maxtext_config.lora_rank, + lora_alpha=self.maxtext_config.lora_alpha, + ) + logger.info(f"Saved MaxText training checkpoint to {output_path}") + + def extract_checkpoint_data( + self, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> dict: + """Extract LoRA state and optimizer state for checkpointing. + + Creates copies of the arrays to ensure the cached state is independent + from the live model state (which may be zeroed on eviction). 
+ """ + # Copy arrays to avoid caching references that get zeroed + lora_weights_copy = jax.tree.map(jnp.copy, self.lora_params) + optimizer_state_copy = jax.tree.map(jnp.copy, nnx.state(self.optimizers[model_id])) + return { + "lora_weights": lora_weights_copy, + "optimizer_state": optimizer_state_copy, + "lora_config": models[model_id].lora_config.model_dump(), + } + + def insert_checkpoint_data( + self, + model_id: str, + checkpoint_data: dict, + models: dict[str, types.ModelMetadata], + ) -> None: + """Insert checkpoint data into model state. + + Reshards the cached arrays to match the current model's sharding. + """ + optimizer = self.optimizers[model_id] + + # Reshard cached weights to match current model sharding + def reshard_to_match(cached, current): + """Reshard cached array to match current array's sharding.""" + sharding = current.sharding + return jax.device_put(cached, sharding) + + resharded_lora = jax.tree.map( + reshard_to_match, checkpoint_data["lora_weights"], self.lora_params + ) + resharded_optim = jax.tree.map( + reshard_to_match, checkpoint_data["optimizer_state"], nnx.state(optimizer) + ) + + # Update model state + nnx.update(self.lora_params, resharded_lora) + nnx.update(nnx.state(optimizer), resharded_optim) + # Sync model with updated lora_params + self.model = nnx.merge(self.graphdef, self.lora_params, self.non_lora_params) + logger.info(f"Restored checkpoint data for model {model_id}") + + def save_sampler_checkpoint( + self, + output_path, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> None: + """Save sampler checkpoint in HuggingFace PEFT format as tar.gz.""" + with pack_and_upload(output_path) as temp_dir: + convert_maxtext_lora_to_hf( + lora_state=self.lora_params, + output_path=temp_dir, + base_model_name=self.config.base_model, + lora_rank=self.maxtext_config.lora_rank, + lora_alpha=self.maxtext_config.lora_alpha, + ) + logger.info(f"Saved MaxText LoRA sampler checkpoint to {output_path}") + + def extract_sampler_weights( + self, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> dict: + """Extract sampler weights.""" + return { + "lora_params": self.lora_params, + "lora_rank": self.maxtext_config.lora_rank, + "lora_alpha": self.maxtext_config.lora_alpha, + } + + def insert_sampler_weights( + self, + model_id: str, + checkpoint_id: str, + weights_data: dict, + models: dict[str, types.ModelMetadata], + ) -> None: + """Insert sampler weights - not implemented for MaxText.""" + raise NotImplementedError("Loading sampler weights not yet implemented for MaxText backend") + diff --git a/skyrl-tx/tx/tinker/backends/native.py b/skyrl-tx/tx/tinker/backends/native.py new file mode 100644 index 000000000..0c8f1c8d1 --- /dev/null +++ b/skyrl-tx/tx/tinker/backends/native.py @@ -0,0 +1,718 @@ +"""Native LoRA backend for TinkerEngine (Qwen3 + LoRA). + +This backend implements the full training and inference pipeline for Qwen3 models +with LoRA adapters. It uses jax.value_and_grad for gradient computation and supports +multiple LoRA adapters via the AccumulatedGradients dataclass. 
+""" + +import time +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Callable + +import numpy as np +import jax +import jax.numpy as jnp +import optax +from flax import nnx +from flax.training import checkpoints +from transformers import AutoTokenizer, PretrainedConfig + +from tx.models.configs import Qwen3Config +from tx.layers.lora import update_adapter_config +from tx.tinker import types +from tx.tinker.config import EngineConfig +from tx.tinker.backends.backend import AbstractBackend +from tx.tinker.backends.utils import pad, pad_batch +from tx.tinker.loss_fns import LOSS_FUNCTIONS +from tx.utils.models import ( + get_dtype, + get_model_class, + load_safetensors, + load_lora_checkpoint, + save_lora_checkpoint, + extract_adapter_state, + insert_adapter_state, + round_up_seq_len, + resolve_model_path, +) +from tx.utils.storage import pack_and_upload +from tx.utils.log import logger + + +@jax.tree_util.register_dataclass +@dataclass +class AccumulatedGradients: + """Stores accumulated gradients for all LoRA adapters.""" + + grad_sum: nnx.State + counts: jax.Array + + @classmethod + def create(cls, lora_params: nnx.State, max_adapters: int) -> "AccumulatedGradients": + """Initialize with zeros.""" + return cls( + grad_sum=jax.tree.map(jnp.zeros_like, lora_params), + counts=jnp.zeros((max_adapters,), dtype=jnp.int32), + ) + + def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "AccumulatedGradients": + """Accumulate gradients and increment counts.""" + # Count occurrences of each adapter index in the batch + batch_counts = jnp.bincount(adapter_indices, length=self.counts.shape[0]) + return AccumulatedGradients( + grad_sum=jax.tree.map(lambda a, b: a + b, self.grad_sum, lora_grads), + counts=self.counts + batch_counts, + ) + + def get_mean(self, adapter_index: jax.Array) -> nnx.State: + """Compute mean gradients for a specific adapter, with zeros for all other adapters.""" + count = self.counts[adapter_index] + return jax.tree.map( + lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)), + self.grad_sum, + ) + + def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients": + """Reset gradients and count for a specific adapter.""" + return AccumulatedGradients( + grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum), + counts=self.counts.at[adapter_index].set(0), + ) + + +class NativeBackend(AbstractBackend): + """Backend for Qwen3 models with LoRA adapters. 
+ + This backend: + - Uses jax.value_and_grad for gradient computation + - Uses 2D mesh (dp, tp) + - Supports multiple LoRA adapters via AccumulatedGradients with counts array + - Supports both FORWARD and FORWARD_BACKWARD request types + """ + + def __init__(self, config: EngineConfig): + """Initialize Native LoRA backend.""" + self.config = config + self.metrics = types.EngineMetrics() + + # Initialize the shared base model with LoRA config + checkpoint_path = resolve_model_path(config.base_model) + self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path) + base_config = PretrainedConfig.from_pretrained(checkpoint_path) + self.model_config = Qwen3Config( + base_config, + max_lora_adapters=config.max_lora_adapters, + max_lora_rank=config.max_lora_rank, + shard_attention_heads=config.shard_attention_heads, + ) + + model_class = get_model_class(self.model_config) + + # Create model and load weights + self.mesh = jax.make_mesh((1, config.tensor_parallel_size), ("dp", "tp")) + + with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True): + self.model = model_class( + self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0) + ) + load_safetensors(checkpoint_path, self.model_config, self.model) + + # Split model into LoRA and non-LoRA parameters + self.graphdef, self.lora_params, self.non_lora_params = nnx.split( + self.model, self.model.is_lora_param, ... + ) + + # Initialize adapter 0 with dummy config (required for base model sampling path) + update_adapter_config(self.model, adapter_index=0, lora_config=types.LoraConfig(rank=1, alpha=1.0)) + + # Initialize global accumulated gradients + self.accumulated_grads = AccumulatedGradients.create( + self.lora_params, config.max_lora_adapters + ) + + # Per-model optimizer storage (managed internally) + self.optimizers: dict[str, nnx.Optimizer] = {} + + logger.info( + f"Initialized base model {config.base_model} with " + f"max_lora_adapters={config.max_lora_adapters}, max_lora_rank={config.max_lora_rank}" + ) + + self._create_loss_and_grad_fn() + + def _micro_batch_size(self, total: int) -> int: + """Return effective micro-batch size; 0/absent => disabled (use full fused batch).""" + mb = self.config.train_micro_batch_size + return total if mb <= 0 else max(1, min(mb, total)) + + @contextmanager + def _jit_timing_context(self, seq_len: int, mode: str): + """Context manager to track JIT compilation times for different sequence lengths. 
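The multi-adapter `AccumulatedGradients` defined earlier in this file keeps one running sum per adapter slot, and `get_mean` zeroes every other slot so an optimizer step only touches the requested adapter. A toy check with plain arrays, where the leading axis stands in for the adapter dimension (an assumed layout for illustration):

```python
# Toy check of get_mean: only the requested adapter slot ends up non-zero.
import jax.numpy as jnp

grad_sum = jnp.array([[2.0, 2.0],    # adapter 0: gradients summed over 1 example
                      [6.0, 6.0]])   # adapter 1: gradients summed over 3 examples
counts = jnp.array([1, 3], dtype=jnp.int32)

adapter_index = 1
count = counts[adapter_index]
mean = jnp.zeros_like(grad_sum).at[adapter_index].set(
    grad_sum[adapter_index] / count.astype(grad_sum.dtype)
)
# mean == [[0., 0.], [2., 2.]]; adapter 0 is untouched by this optimizer step.
```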
+ + Args: + seq_len: The sequence length being compiled + mode: Either 'train' or 'sample' to track separately + """ + jit_times = ( + self.metrics.train_seq_len_jit_times + if mode == "train" + else self.metrics.sample_seq_len_jit_times + ) + if not self.config.enforce_eager and seq_len not in jit_times: + logger.info(f"JIT compiling for {mode} seq_len={seq_len} in progress...") + start_time = time.time() + yield + elapsed = time.time() - start_time + jit_times[seq_len] = elapsed + logger.info(f"JIT compilation for {mode} seq_len={seq_len} took {elapsed:.2f}s") + else: + yield + + def _create_loss_and_grad_fn(self): + """Compile and cache the loss function to avoid re-jitting on every call.""" + + # Wrap the model forward call to use nnx.remat for gradient checkpointing + def _model_forward( + graphdef: nnx.GraphDef, + lora_params: nnx.State, + non_lora_params: nnx.State, + input_ids: jax.Array, + attention_mask: jax.Array, + adapter_indices: jax.Array, + ) -> jax.Array: + model = nnx.merge(graphdef, lora_params, non_lora_params) + output = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) + return output.logits + + if self.config.gradient_checkpointing: + # policy=None corresponds to full activation recomputation + _model_forward = jax.checkpoint(_model_forward, policy=None) + + def loss_for_lora( + lora_params: nnx.State, + non_lora_params: nnx.State, + input_ids: jax.Array, + attention_mask: jax.Array, + adapter_indices: jax.Array, + target_ids: jax.Array, + loss_mask: jax.Array, + loss_fn_types: jax.Array, + sampling_logprobs: jax.Array, + advantages: jax.Array, + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + logits = _model_forward( + self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices + ) # [B, T, V] + + log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) + target_logprobs = (target_logits - log_sum_exp).squeeze(-1) + + def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): + return jax.lax.switch( + loss_fn_type, + LOSS_FUNCTIONS, + target_logprobs, + loss_mask, + sampling_logprobs, + advantages, + ) + + per_token_losses = jax.vmap(compute_loss_per_example)( + loss_fn_types, + target_logprobs, + loss_mask, + sampling_logprobs, + advantages, + ) + + per_seq_loss = per_token_losses.sum(axis=-1) / jnp.maximum(loss_mask.sum(axis=-1), 1e-9) + # Return sum of losses (we'll divide gradients by per-adapter batch size later) + return per_seq_loss.sum(), (target_logprobs, per_token_losses) + + # Only differentiate with respect to lora_params (argnums=0) + loss_and_grad_fn = jax.value_and_grad(loss_for_lora, argnums=0, has_aux=True) + + def forward_only( + accumulated_grads: AccumulatedGradients, + lora_params: nnx.State, + non_lora_params: nnx.State, + input_ids: jax.Array, + attention_mask: jax.Array, + adapter_indices: jax.Array, + target_ids: jax.Array, + loss_mask: jax.Array, + loss_fn_types: jax.Array, + sampling_logprobs: jax.Array, + advantages: jax.Array, + ) -> tuple[AccumulatedGradients, jax.Array, jax.Array]: + _, (target_logprobs, per_token_losses) = loss_for_lora( + lora_params, + non_lora_params, + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) + return accumulated_grads, per_token_losses, target_logprobs + + def forward_backward_and_accumulate( + accumulated_grads: AccumulatedGradients, 
+ lora_params: nnx.State, + non_lora_params: nnx.State, + input_ids: jax.Array, + attention_mask: jax.Array, + adapter_indices: jax.Array, + target_ids: jax.Array, + loss_mask: jax.Array, + loss_fn_types: jax.Array, + sampling_logprobs: jax.Array, + advantages: jax.Array, + ) -> tuple[AccumulatedGradients, jax.Array, jax.Array]: + """Fused forward-backward-accumulate operation.""" + # Forward-backward + (_, (target_logprobs, per_token_losses)), lora_grads = loss_and_grad_fn( + lora_params, + non_lora_params, + input_ids, + attention_mask, + adapter_indices, + target_ids, + loss_mask, + loss_fn_types, + sampling_logprobs, + advantages, + ) + # Accumulate gradients + new_accumulated_grads = accumulated_grads.add(lora_grads, adapter_indices) + return new_accumulated_grads, per_token_losses, target_logprobs + + if self.config.enforce_eager: + # Disable JIT compilation for debugging + self._forward_backward_and_accumulate = forward_backward_and_accumulate + self._forward = forward_only + + else: + # Retrieve the sharding of lora and non_lora params and compute the sharding of inputs and outputs + lora_shardings = jax.tree.map( + lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.lora_params) + ) + non_lora_shardings = jax.tree.map( + lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.non_lora_params) + ) + # Get sharding for AccumulatedGradients + accumulated_grads_shardings = jax.tree.map( + lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.accumulated_grads) + ) + + replicated = jax.NamedSharding(self.mesh, jax.P(None)) + + # JIT the fused function + self._forward_backward_and_accumulate = jax.jit( + forward_backward_and_accumulate, + in_shardings=(accumulated_grads_shardings, lora_shardings, non_lora_shardings) + (replicated,) * 8, + out_shardings=(accumulated_grads_shardings, replicated, replicated), + donate_argnames=("accumulated_grads",), + ) + self._forward = jax.jit( + forward_only, + in_shardings=(accumulated_grads_shardings, lora_shardings, non_lora_shardings) + (replicated,) * 8, + out_shardings=(accumulated_grads_shardings, replicated, replicated), + ) + + # JIT-compiled function to compute full gradients and apply optimizer update + def compute_grads_and_update( + accumulated_grads: AccumulatedGradients, + lora_params: nnx.State, + optimizer: nnx.Optimizer, + adapter_index: jax.Array, + ) -> AccumulatedGradients: + """Compute full gradients, apply optimizer update, and reset accumulated grads.""" + optimizer.update(lora_params, accumulated_grads.get_mean(adapter_index)) + return accumulated_grads.reset_adapter(adapter_index) + + if self.config.enforce_eager: + self._compute_grads_and_update = compute_grads_and_update + else: + self._compute_grads_and_update = nnx.jit(compute_grads_and_update) + + def register_model(self, model_id: str, adapter_index: int, lora_config: types.LoraConfig) -> None: + """Register a new model with the backend. + + Creates optimizer and configures LoRA adapter. + """ + # Create optimizer + with jax.set_mesh(self.mesh): + tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0) + self.optimizers[model_id] = nnx.Optimizer(self.model, tx, wrt=self.model.is_lora_param) + + # Configure adapter + update_adapter_config(self.model, adapter_index, lora_config) + logger.info(f"Registered model {model_id} with adapter_index={adapter_index}") + + def unregister_model(self, model_id: str, adapter_index: int) -> None: + """Unregister a model from the backend. 
+ + Removes optimizer and resets adapter weights. + """ + # Remove optimizer + self.optimizers.pop(model_id, None) + + # Zero out adapter weights + def zero_adapter_slice(path: tuple, p: jnp.ndarray) -> jnp.ndarray: + if len(path) >= 2 and path[-2].key in {"lora_A", "lora_B"}: + return p.at[adapter_index].set(0.0) + return p + + updated_params = jax.tree.map_with_path(zero_adapter_slice, self.lora_params) + nnx.update(self.lora_params, updated_params) + logger.info(f"Unregistered model {model_id} (adapter_index={adapter_index})") + + def _process_model_pass_batch( + self, + prepared_batch: types.PreparedModelPassBatch, + model_pass_fn: Callable, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Common batch processing logic for forward-only and forward-backward operations. + + Args: + prepared_batch: PreparedModelPassBatch with all data extracted from requests + model_pass_fn: Callable to perform the model pass (forward or forward_backward) + + Returns: + Dict mapping request_id to result_data or error info + """ + if not prepared_batch.all_input_ids: + return {} + + results = {} + + # Extract data from prepared batch + all_input_ids = prepared_batch.all_input_ids + all_targets = prepared_batch.all_targets + all_token_weights = prepared_batch.all_token_weights + all_sampling_logprobs = prepared_batch.all_sampling_logprobs + all_advantages = prepared_batch.all_advantages + all_adapter_indices = prepared_batch.all_adapter_indices + all_loss_fn_types = prepared_batch.all_loss_fn_types + request_batch_slices = prepared_batch.request_batch_slices + + # Pad sequences to same length. Also bin it so the JIT has to compile fewer kernels. + max_len = round_up_seq_len(max(len(seq) for seq in all_input_ids), self.config.min_seq_len) + + input_ids = pad_batch(all_input_ids, max_len, np.int32) + target_ids = pad_batch(all_targets, max_len, np.int32) + adapter_indices = jnp.array(all_adapter_indices, dtype=jnp.int32) + loss_fn_types = jnp.array(all_loss_fn_types, dtype=jnp.int32) + + # Create attention mask (1 for real tokens, 0 for padding) + attention_mask = pad_batch([[1] * len(seq) for seq in all_input_ids], max_len, np.int32) + loss_mask = pad_batch(all_token_weights, max_len, np.float32) + sampling_logprobs = pad_batch(all_sampling_logprobs, max_len, np.float32) + advantages = pad_batch(all_advantages, max_len, np.float32) + + total_bs = int(input_ids.shape[0]) + micro_bs = self._micro_batch_size(total_bs) + seq_lens = [len(seq) for seq in all_input_ids] + + # Collect full padded arrays on device, slice after transfer + token_losses_device = [] + logprobs_device = [] + seq_len = input_ids.shape[1] + + with jax.set_mesh(self.mesh), self._jit_timing_context(seq_len, mode="train"): + for mb_start in range(0, total_bs, micro_bs): + mb_end = min(mb_start + micro_bs, total_bs) + self.accumulated_grads, per_token_losses, target_logprobs = model_pass_fn( + self.accumulated_grads, + self.lora_params, + self.non_lora_params, + input_ids[mb_start:mb_end], + attention_mask[mb_start:mb_end], + adapter_indices[mb_start:mb_end], + target_ids[mb_start:mb_end], + loss_mask[mb_start:mb_end], + loss_fn_types[mb_start:mb_end], + sampling_logprobs[mb_start:mb_end], + advantages[mb_start:mb_end], + ) + token_losses_device.append(per_token_losses) + logprobs_device.append(target_logprobs) + + # Single batched device-to-host transfer for all arrays + token_losses_host, logprobs_host = jax.device_get((token_losses_device, logprobs_device)) + + # Flatten microbatches and slice to actual sequence 
lengths + token_losses_out = [] + logprobs_out = [] + idx = 0 + for mb_losses, mb_logprobs in zip(token_losses_host, logprobs_host): + for i in range(mb_losses.shape[0]): + token_losses_out.append(mb_losses[i, : seq_lens[idx]].astype(jnp.float32)) + logprobs_out.append(mb_logprobs[i, : seq_lens[idx]].astype(jnp.float32)) + idx += 1 + + # Compute per-request results + for request_id, _, start_idx, end_idx in request_batch_slices: + loss_fn_outputs = [] + # Compute per-example losses + for i in range(start_idx, end_idx): + # Extract losses for this example's tokens + token_losses = token_losses_out[i] + token_logprobs = logprobs_out[i] + loss_fn_outputs.append( + { + "elementwise_loss": { + "data": token_losses.tolist(), + "dtype": "float32", + "shape": [token_losses.shape[0]], + }, + "logprobs": { + "data": token_logprobs.tolist(), + "dtype": "float32", + "shape": [token_logprobs.shape[0]], + }, + } + ) + + results[request_id] = types.ForwardBackwardOutput( + loss_fn_output_type="scalar", + loss_fn_outputs=loss_fn_outputs, + metrics={}, + ) + + return results + + def process_forward_backward_batch( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Process forward_backward requests in a batch.""" + return self._process_model_pass_batch(prepared_batch, self._forward_backward_and_accumulate) + + def process_forward_batch( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Process forward-only requests in a batch (no gradient computation).""" + return self._process_model_pass_batch(prepared_batch, self._forward) + + def process_optim_step( + self, + model_id: str, + adapter_index: int, + request_data: types.OptimStepInput, + ) -> types.OptimStepOutput: + """Process an optim_step request and apply accumulated gradients.""" + adapter_index_arr = jnp.int32(adapter_index) + optimizer = self.optimizers[model_id] + + # Check if we have any gradients accumulated (count > 0) + if self.accumulated_grads.counts[adapter_index] == 0: + logger.warning(f"No accumulated gradients for model {model_id}, skipping optimizer step") + return types.OptimStepOutput() + + # Update hyperparameters from the request + hp = optimizer.opt_state.hyperparams + hp["learning_rate"][...] = request_data.adam_params.learning_rate + hp["b1"][...] = request_data.adam_params.beta1 + hp["b2"][...] = request_data.adam_params.beta2 + hp["eps"][...] = request_data.adam_params.eps + + # JIT-compiled: compute full gradients, apply optimizer update, and reset accumulated grads + with jax.set_mesh(self.mesh): + self.accumulated_grads = self._compute_grads_and_update( + self.accumulated_grads, + self.lora_params, + optimizer, + adapter_index_arr, + ) + + logger.info(f"Applied optimizer step for model {model_id} (adapter {adapter_index})") + return types.OptimStepOutput() + + def process_sample_batch( + self, + prepared_batch: types.PreparedSampleBatch, + ) -> dict[str, types.SampleOutput | types.ErrorResponse]: + """Process multiple sample requests in a single batch. 
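Both backends build their optimizer with `optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)` and then overwrite the hyperparameters per request, as `process_optim_step` does above. A standalone sketch of that pattern with plain optax and toy parameters (no nnx involved):

```python
# inject_hyperparams exposes adamw's hyperparameters as mutable entries in the optimizer state.
import jax.numpy as jnp
import optax

tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)
params = {"w": jnp.zeros((2,))}
opt_state = tx.init(params)

# Per-request Adam parameters arrive later; set them before applying the update.
opt_state.hyperparams["learning_rate"] = jnp.asarray(3e-4)
opt_state.hyperparams["b1"] = jnp.asarray(0.9)

grads = {"w": jnp.ones((2,))}
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```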
+ + Args: + prepared_batch: PreparedSampleBatch with all data extracted from requests + + Returns: + Dict mapping request_id --> result_data or error info + """ + if not prepared_batch.all_prompts: + return {} + + results = {} + + # Extract data from prepared batch + all_prompts = prepared_batch.all_prompts + all_sampling_params = prepared_batch.all_sampling_params + all_adapter_indices = prepared_batch.all_adapter_indices + request_batch_slices = prepared_batch.request_batch_slices + needs_prompt_logprobs = prepared_batch.needs_prompt_logprobs + + total_batch_size = len(all_prompts) + max_batch_size = ( + self.config.sample_max_num_sequences if self.config.sample_max_num_sequences > 0 else total_batch_size + ) + # Collect generated sequences and prompt logprobs across batches + all_sequences: list[types.GeneratedSequence] = [] + all_prompt_logprobs: list[list[float]] = [] + + with jax.set_mesh(self.mesh): + model = nnx.merge(self.graphdef, self.lora_params, self.non_lora_params) + for batch_start in range(0, total_batch_size, max_batch_size): + batch_end = min(batch_start + max_batch_size, total_batch_size) + batch_prompts = pad(all_prompts[batch_start:batch_end], max_batch_size, fill=[]) + batch_adapter_indices = pad(all_adapter_indices[batch_start:batch_end], max_batch_size, fill=0) + sampling_params = pad( + all_sampling_params[batch_start:batch_end], max_batch_size, fill=all_sampling_params[batch_start] + ) + + # Pad sequences to same length within the batch to minimize memory usage. + # Also bin it so the JIT has to compile fewer kernels. + # Use left-padding for sampling so the last position is always the last real token. + max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0), self.config.min_seq_len) + input_ids = pad_batch(batch_prompts, max_len, np.int32, left=True) + attention_mask = pad_batch([[1] * len(seq) for seq in batch_prompts], max_len, np.int32, left=True) + + with self._jit_timing_context(max_len, mode="sample"): + result = model.generate( + input_ids, + attention_mask, + sampling_params=sampling_params, + adapter_indices=jnp.array(batch_adapter_indices, dtype=jnp.int32), + prompt_logprobs=needs_prompt_logprobs, + tokenizer=self.tokenizer, + ) + # Only take the actual results, not the padded ones + batch_size = batch_end - batch_start + all_sequences.extend( + types.GeneratedSequence(stop_reason=stop_reason, tokens=tokens, logprobs=logprobs) + for stop_reason, tokens, logprobs in zip( + result.stop_reasons[:batch_size], + result.generated_ids[:batch_size], + result.logprobs[:batch_size], + ) + ) + if needs_prompt_logprobs and result.prompt_logprobs: + all_prompt_logprobs.extend(result.prompt_logprobs[:batch_size]) + + for request_id, _, start_idx, end_idx, prompt_logprobs_requested in request_batch_slices: + sequences = [all_sequences[i] for i in range(start_idx, end_idx)] + # Each of `num_samples` samples in a request share the same prompt; use the first's prompt logprobs + prompt_logprobs = ( + all_prompt_logprobs[start_idx] if prompt_logprobs_requested and all_prompt_logprobs else None + ) + results[request_id] = types.SampleOutput(sequences=sequences, prompt_logprobs=prompt_logprobs) + + return results + + def save_checkpoint( + self, + output_path, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> None: + """Save training checkpoint as tar.gz using Flax checkpoints.""" + with pack_and_upload(output_path) as temp_dir: + checkpoint_data = self.extract_checkpoint_data(model_id, models) + checkpoints.save_checkpoint( + 
target=checkpoint_data, + ckpt_dir=temp_dir, + step=0, + prefix="checkpoint_", + overwrite=True, + ) + logger.info(f"Saved training checkpoint to {output_path}") + + def extract_checkpoint_data( + self, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> dict: + """Extract adapter state and optimizer state for checkpointing.""" + adapter_index = models[model_id].adapter_index + rank = models[model_id].lora_config.rank + lora_weights = extract_adapter_state(adapter_index, self.lora_params, rank) + optimizer_state = extract_adapter_state(adapter_index, nnx.state(self.optimizers[model_id]), rank) + return { + "lora_weights": lora_weights, + "optimizer_state": optimizer_state, + "lora_config": models[model_id].lora_config.model_dump(), + } + + def insert_checkpoint_data( + self, + model_id: str, + checkpoint_data: dict, + models: dict[str, types.ModelMetadata], + ) -> None: + """Insert checkpoint data into model state.""" + adapter_index = models[model_id].adapter_index + rank = checkpoint_data["lora_config"]["rank"] + + if models[model_id].lora_config.rank != rank: + raise ValueError( + f"Rank mismatch: checkpoint has rank {rank}, " + f"model configured with rank {models[model_id].lora_config.rank}" + ) + + insert_adapter_state(adapter_index, self.lora_params, checkpoint_data["lora_weights"], rank) + insert_adapter_state(adapter_index, nnx.state(self.optimizers[model_id]), checkpoint_data["optimizer_state"], rank) + + def save_sampler_checkpoint( + self, + output_path, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> None: + """Save sampler checkpoint as tar.gz using save_lora_checkpoint.""" + lora_model = models[model_id] + save_lora_checkpoint( + self.model, + self.config.base_model, + lora_model.lora_config, + lora_model.adapter_index, + output_path, + ) + logger.info(f"Saved LoRA sampler checkpoint to {output_path}") + + def extract_sampler_weights( + self, + model_id: str, + models: dict[str, types.ModelMetadata], + ) -> dict: + """Extract weights for sampler checkpoint. + + Returns data needed for save_lora_checkpoint. 
+ """ + return { + "model": self.model, + "base_model": self.config.base_model, + "lora_config": models[model_id].lora_config, + "adapter_index": models[model_id].adapter_index, + } + + def insert_sampler_weights( + self, + model_id: str, + checkpoint_id: str, + checkpoint_path, + models: dict[str, types.ModelMetadata], + ) -> None: + """Insert sampler weights from checkpoint file.""" + adapter_index = models[model_id].adapter_index + adapter_config = models[model_id].lora_config + load_lora_checkpoint(self.model, adapter_config, adapter_index, checkpoint_path) + models[model_id].loaded_checkpoint_id = checkpoint_id + logger.info(f"Loaded LoRA sampler weights for model {model_id} at adapter index {adapter_index}") diff --git a/skyrl-tx/tx/tinker/backends/utils.py b/skyrl-tx/tx/tinker/backends/utils.py new file mode 100644 index 000000000..46e0fd8c2 --- /dev/null +++ b/skyrl-tx/tx/tinker/backends/utils.py @@ -0,0 +1,50 @@ +"""Shared helper utilities for TinkerEngine backends.""" + +import time +from contextlib import contextmanager + +import numpy as np +import jax.numpy as jnp + +from tx.utils.log import logger + + +@contextmanager +def log_timing(request: str): + """Context manager to log execution time for a request.""" + start_time = time.perf_counter() + try: + yield + finally: + elapsed = time.perf_counter() - start_time + logger.info(f"(timing) {request} took {elapsed:.3f}s") + + +def pad(xs, pad_to: int, *, fill): + """Pad a list to a specified length with a fill value.""" + return xs + ([fill] * (pad_to - len(xs))) + + +def pad_batch(sequences: list[list], max_length: int, dtype, left: bool = False): + """Pad a batch of sequences to max_length. + + Args: + sequences: List of sequences to pad. + max_length: Target length for all sequences. + dtype: NumPy dtype for the output array. + left: If True, use left-padding (tokens at end). Required for autoregressive + generation so the last position corresponds to the last real token. + If False (default), use right-padding (tokens at start). + + Returns: + A JAX array of shape (batch_size, max_length) with the padded sequences. + """ + batch_size = len(sequences) + padded = np.zeros((batch_size, max_length), dtype=dtype) + for i, seq in enumerate(sequences): + assert len(seq) <= max_length, f"Sequence length {len(seq)} exceeds max_length {max_length}" + if left: + padded[i, max_length - len(seq) :] = seq + else: + padded[i, : len(seq)] = seq + return jnp.asarray(padded) diff --git a/skyrl-tx/tx/tinker/config.py b/skyrl-tx/tx/tinker/config.py index acfcb5709..ff1629bee 100644 --- a/skyrl-tx/tx/tinker/config.py +++ b/skyrl-tx/tx/tinker/config.py @@ -54,6 +54,15 @@ class EngineConfig(BaseModel): default=Path("/tmp/lora_models"), description="Directory where LoRA models will be extracted for external inference engines", ) + min_seq_len: int = Field( + default=32, + description="Minimum sequence length for padding buckets (sequences shorter than this are padded up to this length)", + ) + maxtext_config_str: str | None = Field( + default=None, + description="MaxText config as space-separated key=value string. 
If set, uses MaxTextBackend instead of NativeBackend.", + json_schema_extra={"argparse_type": str}, + ) def convert_env_var(env_name: str, env_value: str, expected_type: type): diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index e953e8716..f4f8eb89b 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -3,128 +3,161 @@ import argparse import time from contextlib import contextmanager -from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import Callable + from pydantic import BaseModel from sqlmodel import create_engine, Session, select, update, func -import numpy as np -import jax -import jax.numpy as jnp from flax import nnx from flax.training import checkpoints - -import optax -from transformers import AutoTokenizer, PretrainedConfig - -from tx.models.configs import Qwen3Config from tx.tinker.db_models import FutureDB, RequestStatus, CheckpointDB, CheckpointStatus from tx.tinker import types from tx.tinker.config import EngineConfig, add_model -from tx.tinker.loss_fns import LOSS_TYPES, LOSS_FUNCTIONS -from tx.utils.storage import download_and_unpack, pack_and_upload -from tx.utils.models import ( - get_dtype, - get_model_class, - save_lora_checkpoint, - load_lora_checkpoint, - load_safetensors, - extract_adapter_state, - insert_adapter_state, - round_up_seq_len, - resolve_model_path, -) -from tx.layers.lora import update_adapter_config +from tx.tinker.backends import NativeBackend, MaxTextBackend, parse_maxtext_config +from tx.tinker.backends.utils import log_timing +from tx.tinker.loss_fns import LOSS_TYPES +from tx.utils.storage import download_and_unpack from tx.utils.log import logger -@contextmanager -def log_timing(request: str): - """Context manager to log execution time for a request.""" - start_time = time.perf_counter() - try: - yield - finally: - elapsed = time.perf_counter() - start_time - logger.info(f"(timing) {request} took {elapsed:.3f}s") +class TinkerEngine: + """Background engine for processing training requests. + The engine handles: + - Database operations (futures, checkpoints) + - Request finding/scheduling + - File I/O (download/upload checkpoints) + - Storing models and optimizers dicts + - Validating requests against loaded models -def pad(xs, pad_to: int, *, fill): - """Pad a list to a specified length with a fill value.""" - return xs + ([fill] * (pad_to - len(xs))) + Computation is delegated to the backend (NativeBackend or MaxTextBackend). + """ + def _filter_valid_requests( + self, + requests: dict[str, tuple[str, any]], + ) -> tuple[dict[str, any], dict[str, tuple[str, any]]]: + """Filter out requests with invalid model_ids and return error results for them. -def pad_batch(sequences: list[list], max_length: int, dtype, left: bool = False) -> jax.Array: - """Pad a batch of sequences to max_length. + Args: + requests: Dict mapping request_id to (model_id, request_data) tuples - Args: - sequences: List of sequences to pad. - max_length: Target length for all sequences. - dtype: NumPy dtype for the output array. - left: If True, use left-padding (tokens at end). Required for autoregressive - generation so the last position corresponds to the last real token. - If False (default), use right-padding (tokens at start). + Returns: + Tuple of (error_results, valid_requests) + """ + results = {} + valid_requests = {} - Returns: - A JAX array of shape (batch_size, max_length) with the padded sequences. 
- """ - batch_size = len(sequences) - padded = np.zeros((batch_size, max_length), dtype=dtype) - for i, seq in enumerate(sequences): - assert len(seq) <= max_length, f"Sequence length {len(seq)} exceeds max_length {max_length}" - if left: - padded[i, max_length - len(seq) :] = seq - else: - padded[i, : len(seq)] = seq - return jnp.asarray(padded) + for request_id, (model_id, request_data) in requests.items(): + if model_id and model_id not in self.models: + results[request_id] = types.ErrorResponse(error=f"Model {model_id} not loaded", status="failed") + else: + valid_requests[request_id] = (model_id, request_data) + return results, valid_requests + + def _prepare_model_pass_batch( + self, + requests: dict[str, tuple[str, types.ForwardBackwardInput]], + ) -> types.PreparedModelPassBatch: + """Prepare batch data for forward/forward_backward operations. -@jax.tree_util.register_dataclass -@dataclass -class AccumulatedGradients: - """Stores accumulated gradients for all LoRA adapters.""" + Extracts tokens, targets, and metadata from requests into lists + that the backend will convert to JAX arrays. - grad_sum: nnx.State - counts: jax.Array + Args: + requests: Dict mapping request_id to (model_id, request_data) tuples (pre-validated) - @classmethod - def create(cls, lora_params: nnx.State, max_adapters: int) -> "AccumulatedGradients": - """Initialize with zeros.""" - return cls( - grad_sum=jax.tree.map(jnp.zeros_like, lora_params), - counts=jnp.zeros((max_adapters,), dtype=jnp.int32), - ) + Returns: + PreparedModelPassBatch with all data extracted from requests + """ + all_input_ids = [] + all_targets = [] + all_token_weights = [] + all_adapter_indices = [] + all_sampling_logprobs = [] + all_advantages = [] + all_loss_fn_types = [] + request_batch_slices = [] - def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "AccumulatedGradients": - """Accumulate gradients and increment counts.""" - # Count occurrences of each adapter index in the batch - batch_counts = jnp.bincount(adapter_indices, length=self.counts.shape[0]) - return AccumulatedGradients( - grad_sum=jax.tree.map(lambda a, b: a + b, self.grad_sum, lora_grads), - counts=self.counts + batch_counts, - ) + for request_id, (model_id, request_data) in requests.items(): + adapter_index = self.models[model_id].adapter_index + loss_fn_type = LOSS_TYPES[request_data.loss_fn] - def get_mean(self, adapter_index: jax.Array) -> nnx.State: - """Compute mean gradients for a specific adapter, with zeros for all other adapters.""" - count = self.counts[adapter_index] - return jax.tree.map( - lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)), - self.grad_sum, + request_start = len(all_input_ids) + for item in request_data.data: + tokens = [t for chunk in item.model_input.chunks for t in chunk.tokens] + all_input_ids.append(tokens) + loss_fn_inputs = item.loss_fn_inputs + all_targets.append(loss_fn_inputs.target_tokens.data) + all_token_weights.append(loss_fn_inputs.weights.data) + all_sampling_logprobs.append(loss_fn_inputs.logprobs.data) + all_advantages.append(loss_fn_inputs.advantages.data) + all_adapter_indices.append(adapter_index) + all_loss_fn_types.append(loss_fn_type) + + request_batch_slices.append((request_id, model_id, request_start, len(all_input_ids))) + + return types.PreparedModelPassBatch( + all_input_ids=all_input_ids, + all_targets=all_targets, + all_token_weights=all_token_weights, + all_sampling_logprobs=all_sampling_logprobs, + all_advantages=all_advantages, + 
all_adapter_indices=all_adapter_indices, + all_loss_fn_types=all_loss_fn_types, + request_batch_slices=request_batch_slices, ) - def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients": - """Reset gradients and count for a specific adapter.""" - return AccumulatedGradients( - grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum), - counts=self.counts.at[adapter_index].set(0), + def _prepare_sample_batch( + self, + requests: dict[str, tuple[str, types.SampleInput]], + adapter_indices: list[int], + ) -> types.PreparedSampleBatch: + """Prepare batch data for sample operations. + + Extracts prompts and sampling params from requests into lists + that the backend will convert to JAX arrays. + + Args: + requests: Dict mapping request_id to (model_id, request_data) tuples (pre-validated) + adapter_indices: List of adapter indices corresponding to each request + + Returns: + PreparedSampleBatch with all data extracted from requests + """ + all_prompts = [] + all_sampling_params = [] + all_adapter_indices = [] + request_batch_slices = [] + + needs_prompt_logprobs = any( + request_data.prompt_logprobs for (_, request_data) in requests.values() ) + for i, (request_id, (model_id, request_data)) in enumerate(requests.items()): + request_start = len(all_prompts) -class TinkerEngine: - """Background engine for processing training requests.""" + # Expand requests for num_samples + for _ in range(request_data.num_samples): + prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens] + all_prompts.append(prompt_tokens) + all_sampling_params.append(request_data.sampling_params) + all_adapter_indices.append(adapter_indices[i]) + + request_batch_slices.append(( + request_id, model_id, request_start, len(all_prompts), request_data.prompt_logprobs + )) + + return types.PreparedSampleBatch( + all_prompts=all_prompts, + all_sampling_params=all_sampling_params, + all_adapter_indices=all_adapter_indices, + needs_prompt_logprobs=needs_prompt_logprobs, + request_batch_slices=request_batch_slices, + ) def __init__( self, @@ -136,54 +169,70 @@ def __init__( # Store LoRA model metadata (model_id -> metadata) self.models: dict[str, types.ModelMetadata] = {} - # Store optimizer instances per LoRA adapter (model_id -> optimizer) - self.optimizers: dict[str, nnx.Optimizer] = {} - # Metrics recorded in the engine - self.metrics = types.EngineMetrics() - - # Initialize the shared base model with LoRA config - checkpoint_path = resolve_model_path(self.config.base_model) - self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path) - base_config = PretrainedConfig.from_pretrained(checkpoint_path) - self.model_config = Qwen3Config( - base_config, - max_lora_adapters=self.config.max_lora_adapters, - max_lora_rank=self.config.max_lora_rank, - shard_attention_heads=self.config.shard_attention_heads, - ) + # Cache for evicted model states (model_id -> checkpoint_data) + # Allows restoring weights when same model_id is recreated + self._model_cache: dict[str, dict] = {} + + # Initialize the backend (handles model state and computation) + if config.maxtext_config_str: + maxtext_config = parse_maxtext_config(config.maxtext_config_str) + self.backend = MaxTextBackend(config, maxtext_config) + else: + self.backend = NativeBackend(config) - model_class = get_model_class(self.model_config) + logger.info( + f"Initialized TinkerEngine with backend={type(self.backend).__name__}, " + f"max_lora_adapters={config.max_lora_adapters}, max_lora_rank={config.max_lora_rank}" 
+ ) - # Create model and load weights - self.mesh = jax.make_mesh((1, self.config.tensor_parallel_size), ("dp", "tp")) - with jax.set_mesh(self.mesh): - self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0)) - load_safetensors(checkpoint_path, self.model_config, self.model) + @property + def metrics(self) -> types.EngineMetrics: + """Pass-through to backend metrics for backwards compatibility.""" + return self.backend.metrics - # Split model into LoRA and non-LoRA parameters - self.graphdef, self.lora_params, self.non_lora_params = nnx.split(self.model, self.model.is_lora_param, ...) - update_adapter_config(self.model, adapter_index=0, lora_config=types.LoraConfig(rank=1, alpha=1.0)) + def _find_lru_model(self) -> str | None: + """Find the least recently used model. - # Initialize global accumulated gradients - self.accumulated_grads = AccumulatedGradients.create(self.lora_params, self.config.max_lora_adapters) + Returns: + model_id of the LRU model, or None if no models exist + """ + if not self.models: + return None - logger.info( - f"Initialized base model {self.config.base_model} with max_lora_adapters={self.config.max_lora_adapters}, max_lora_rank={self.config.max_lora_rank}" + # Find model with oldest last_used timestamp (None treated as oldest) + lru_model_id = min( + self.models.keys(), + key=lambda mid: self.models[mid].last_used or 0.0 ) + return lru_model_id - self._create_loss_and_grad_fn() + def _evict_model(self, model_id: str) -> int: + """Evict a model and return its adapter index for reuse. - def _extract_checkpoint_data(self, model_id: str) -> dict: - """Extract adapter state and optimizer state for checkpointing.""" - adapter_index = self.models[model_id].adapter_index - rank = self.models[model_id].lora_config.rank - lora_weights = extract_adapter_state(adapter_index, self.lora_params, rank) - optimizer_state = extract_adapter_state(adapter_index, nnx.state(self.optimizers[model_id]), rank) - return { - "lora_weights": lora_weights, - "optimizer_state": optimizer_state, - "lora_config": self.models[model_id].lora_config.model_dump(), - } + Caches the model's LoRA weights and optimizer state before eviction, + allowing restoration if the same model_id is created again. 
+ + Args: + model_id: The model to evict + + Returns: + The adapter_index that was freed up + """ + # Cache the model state before evicting (for potential restoration) + if model_id in self.models: + cached_state = self.backend.extract_checkpoint_data(model_id, self.models) + self._model_cache[model_id] = cached_state + logger.info(f"Cached state for model {model_id}") + + metadata = self.models.pop(model_id) + self.backend.unregister_model(model_id, metadata.adapter_index) + logger.info(f"Evicted model {model_id} (adapter_index={metadata.adapter_index}) to make room for new model") + return metadata.adapter_index + + def _touch_model(self, model_id: str) -> None: + """Update last_used timestamp for a model.""" + if model_id in self.models: + self.models[model_id].last_used = time.time() @contextmanager def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType): @@ -212,225 +261,6 @@ def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoi session.add(checkpoint_db) session.commit() - def _create_loss_and_grad_fn(self): - """Compile and cache the loss function to avoid re-jitting on every call.""" - - # Wrap the model forward call to use nnx.remat for gradient checkpointing - def _model_forward( - graphdef: nnx.GraphDef, - lora_params: nnx.State, - non_lora_params: nnx.State, - input_ids: jax.Array, - attention_mask: jax.Array, - adapter_indices: jax.Array, - ) -> jax.Array: - model = nnx.merge(graphdef, lora_params, non_lora_params) - output = model(input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices) - return output.logits - - if self.config.gradient_checkpointing: - # policy=None corresponds to full activation recomputation - _model_forward = jax.checkpoint(_model_forward, policy=None) - - def loss_for_lora( - lora_params: nnx.State, - non_lora_params: nnx.State, - input_ids: jax.Array, - attention_mask: jax.Array, - adapter_indices: jax.Array, - target_ids: jax.Array, - loss_mask: jax.Array, - loss_fn_types: jax.Array, - sampling_logprobs: jax.Array, - advantages: jax.Array, - ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - logits = _model_forward( - self.graphdef, lora_params, non_lora_params, input_ids, attention_mask, adapter_indices - ) # [B, T, V] - - log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) - target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) - target_logprobs = (target_logits - log_sum_exp).squeeze(-1) - - def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): - return jax.lax.switch( - loss_fn_type, - LOSS_FUNCTIONS, - target_logprobs, - loss_mask, - sampling_logprobs, - advantages, - ) - - per_token_losses = jax.vmap(compute_loss_per_example)( - loss_fn_types, - target_logprobs, - loss_mask, - sampling_logprobs, - advantages, - ) - - per_seq_loss = per_token_losses.sum(axis=-1) / loss_mask.sum(axis=-1) - # Return sum of losses (we'll divide gradients by per-adapter batch size later) - return per_seq_loss.sum(), (target_logprobs, per_token_losses) - - # Only differentiate with respect to lora_params (argnums=0) - loss_and_grad_fn = jax.value_and_grad(loss_for_lora, argnums=0, has_aux=True) - - def forward_only( - accumulated_grads: AccumulatedGradients, - lora_params: nnx.State, - non_lora_params: nnx.State, - input_ids: jax.Array, - attention_mask: jax.Array, - adapter_indices: jax.Array, - target_ids: jax.Array, - loss_mask: jax.Array, - loss_fn_types: jax.Array, - 
sampling_logprobs: jax.Array, - advantages: jax.Array, - ) -> tuple[AccumulatedGradients, jax.Array, jax.Array]: - _, (target_logprobs, per_token_losses) = loss_for_lora( - lora_params, - non_lora_params, - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) - return accumulated_grads, per_token_losses, target_logprobs - - def forward_backward_and_accumulate( - accumulated_grads: AccumulatedGradients, - lora_params: nnx.State, - non_lora_params: nnx.State, - input_ids: jax.Array, - attention_mask: jax.Array, - adapter_indices: jax.Array, - target_ids: jax.Array, - loss_mask: jax.Array, - loss_fn_types: jax.Array, - sampling_logprobs: jax.Array, - advantages: jax.Array, - ) -> tuple[AccumulatedGradients, jax.Array, jax.Array]: - """Fused forward-backward-accumulate operation.""" - # Forward-backward - (_, (target_logprobs, per_token_losses)), lora_grads = loss_and_grad_fn( - lora_params, - non_lora_params, - input_ids, - attention_mask, - adapter_indices, - target_ids, - loss_mask, - loss_fn_types, - sampling_logprobs, - advantages, - ) - # Accumulate gradients - new_accumulated_grads = accumulated_grads.add(lora_grads, adapter_indices) - return new_accumulated_grads, per_token_losses, target_logprobs - - if self.config.enforce_eager: - # Disable JIT compilation for debugging - self._forward_backward_and_accumulate = forward_backward_and_accumulate - self._forward = forward_only - - else: - # Retrieve the sharding of lora and non_lora params and compute the sharding of inputs and outputs - lora_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.lora_params) - ) - non_lora_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.non_lora_params) - ) - # Get sharding for AccumulatedGradients - accumulated_grads_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.accumulated_grads) - ) - - replicated = jax.NamedSharding(self.mesh, jax.P(None)) - - # JIT the fused function - self._forward_backward_and_accumulate = jax.jit( - forward_backward_and_accumulate, - in_shardings=(accumulated_grads_shardings, lora_shardings, non_lora_shardings) + (replicated,) * 8, - out_shardings=(accumulated_grads_shardings, replicated, replicated), - donate_argnames=("accumulated_grads",), - ) - self._forward = jax.jit( - forward_only, - in_shardings=(accumulated_grads_shardings, lora_shardings, non_lora_shardings) + (replicated,) * 8, - out_shardings=(accumulated_grads_shardings, replicated, replicated), - ) - - # JIT-compiled function to compute full gradients and apply optimizer update - def compute_grads_and_update( - accumulated_grads: AccumulatedGradients, - lora_params: nnx.State, - optimizer: nnx.Optimizer, - adapter_index: jax.Array, - ) -> AccumulatedGradients: - """Compute full gradients, apply optimizer update, and reset accumulated grads.""" - optimizer.update(lora_params, accumulated_grads.get_mean(adapter_index)) - return accumulated_grads.reset_adapter(adapter_index) - - if self.config.enforce_eager: - self._compute_grads_and_update = compute_grads_and_update - else: - self._compute_grads_and_update = nnx.jit(compute_grads_and_update) - - def _micro_batch_size(self, total: int) -> int: - """Return effective micro-batch size; 0/absent => disabled (use full fused batch).""" - mb = self.config.train_micro_batch_size - return total if mb <= 0 else max(1, min(mb, total)) - - 
@contextmanager - def _jit_timing_context(self, seq_len: int, mode: str): - """Context manager to track JIT compilation times for different sequence lengths. - - Args: - seq_len: The sequence length being compiled - mode: Either 'train' or 'sample' to track separately - """ - jit_times = self.metrics.train_seq_len_jit_times if mode == "train" else self.metrics.sample_seq_len_jit_times - if not self.config.enforce_eager and seq_len not in jit_times: - logger.info(f"JIT compiling for {mode} seq_len={seq_len} in progress...") - start_time = time.time() - yield - elapsed = time.time() - start_time - jit_times[seq_len] = elapsed - logger.info(f"JIT compilation for {mode} seq_len={seq_len} took {elapsed:.2f}s") - else: - yield - - def _filter_valid_requests( - self, - requests: dict[str, tuple[str, any]], - ) -> tuple[dict[str, any], dict[str, tuple[str, any]]]: - """Filter out requests with invalid model_ids and return error results for them. - - Args: - requests: Dict mapping request_id to (model_id, request_data) tuples - - Returns: - Tuple of (error_results, valid_requests) - """ - results = {} - valid_requests = {} - - for request_id, (model_id, request_data) in requests.items(): - if model_id and model_id not in self.models: - results[request_id] = types.ErrorResponse(error=f"Model {model_id} not loaded", status="failed") - else: - valid_requests[request_id] = (model_id, request_data) - - return results, valid_requests - def find_batchable_model_passes( self, session: Session, request_type: types.RequestType ) -> dict[str, tuple[str, types.ForwardBackwardInput]]: @@ -536,12 +366,6 @@ def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.R def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput: """Create and initialize a model.""" - # Assign adapter index for this model_id - adapter_index = max((m.adapter_index for m in self.models.values()), default=0) + 1 - - if adapter_index >= self.config.max_lora_adapters: - raise ValueError(f"Maximum number of LoRA adapters ({self.config.max_lora_adapters}) reached") - # Extract LoRA configuration lora_config = request_data.lora_config @@ -549,18 +373,36 @@ def process_create_model(self, model_id: str, request_data: types.CreateModelInp if not (0 < lora_config.rank <= self.config.max_lora_rank): raise ValueError(f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}") + # Determine adapter index - either get a new one or reuse from evicted model + if len(self.models) >= self.config.max_lora_adapters: + # Evict LRU model and reuse its adapter index + lru_model_id = self._find_lru_model() + adapter_index = self._evict_model(lru_model_id) + else: + # Assign new adapter index + adapter_index = max((m.adapter_index for m in self.models.values()), default=0) + 1 + self.models[model_id] = types.ModelMetadata( adapter_index=adapter_index, lora_config=lora_config, + last_used=time.time(), ) - with jax.set_mesh(self.mesh): - # These values are always overridden by the hyperparams in the optim_step request. 
- tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0) - self.optimizers[model_id] = nnx.Optimizer(self.model, tx, wrt=self.model.is_lora_param) + # Register model with backend (creates optimizer and configures adapter) + self.backend.register_model(model_id, adapter_index, lora_config) - # Update the adapter's rank and scaling in all LoRA layers - update_adapter_config(self.model, adapter_index, lora_config) + # Check if we have cached state for this model_id and restore if rank matches + cached_state = self._model_cache.pop(model_id, None) + if cached_state is not None: + cached_rank = cached_state["lora_config"]["rank"] + if cached_rank == lora_config.rank: + self.backend.insert_checkpoint_data(model_id, cached_state, self.models) + logger.info(f"Restored cached state for model {model_id}") + else: + logger.info( + f"Skipped cache restore for {model_id}: rank mismatch " + f"(cached={cached_rank}, requested={lora_config.rank})" + ) logger.info(f"Created LoRA model {model_id} with adapter index {adapter_index}, config {lora_config}") @@ -570,151 +412,52 @@ def process_create_model(self, model_id: str, request_data: types.CreateModelInp lora_config=request_data.lora_config, ) - def _process_model_pass_batch( - self, - requests: dict[str, tuple[str, types.ForwardBackwardInput]], - model_pass_fn: Callable, + def process_forward_backward_batch( + self, requests: dict[str, tuple[str, types.ForwardBackwardInput]] ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: - """Common batch processing logic for forward-only and forward-backward operations. - - Args: - requests: Dict mapping request_id to (model_id, request_data) tuples - model_pass_fn: Callable to perform the model pass (forward or forward_backward) - - Returns: - Dict mapping request_id to result_data or error info - """ - results, valid_requests = self._filter_valid_requests(requests) - + """Process forward_backward requests by delegating to backend.""" + # Filter invalid requests before delegating to backend + error_results, valid_requests = self._filter_valid_requests(requests) if not valid_requests: - return results + return error_results - # Collect all examples and their metadata - all_input_ids = [] - all_targets = [] - all_token_weights = [] - all_adapter_indices = [] - example_model_ids = [] # map each example to its model_id - request_batch_slices = [] # Track which examples belong to which request - all_sampling_logprobs = [] - all_advantages = [] - all_loss_fn_types = [] + # Update last_used for all models in this batch + for model_id, _ in valid_requests.values(): + self._touch_model(model_id) - for request_id, (model_id, request_data) in valid_requests.items(): - adapter_index = self.models[model_id].adapter_index - loss_fn_type = LOSS_TYPES[request_data.loss_fn] + # Prepare batch data + prepared_batch = self._prepare_model_pass_batch(valid_requests) - request_start = len(all_input_ids) - for item in request_data.data: - tokens = [t for chunk in item.model_input.chunks for t in chunk.tokens] - all_input_ids.append(tokens) - loss_fn_inputs = item.loss_fn_inputs - all_targets.append(loss_fn_inputs.target_tokens.data) - all_token_weights.append(loss_fn_inputs.weights.data) - all_sampling_logprobs.append(loss_fn_inputs.logprobs.data) - all_advantages.append(loss_fn_inputs.advantages.data) - all_adapter_indices.append(adapter_index) - example_model_ids.append(model_id) - all_loss_fn_types.append(loss_fn_type) + # Delegate computation to backend + results = 
self.backend.process_forward_backward_batch(prepared_batch) + results.update(error_results) + return results - request_batch_slices.append((request_id, model_id, request_start, len(all_input_ids))) + def process_forward_batch( + self, requests: dict[str, tuple[str, types.ForwardBackwardInput]] + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Process forward-only requests by delegating to backend.""" + # Filter invalid requests before delegating to backend + error_results, valid_requests = self._filter_valid_requests(requests) + if not valid_requests: + return error_results - # Pad sequences to same length. Also bin it so the JIT has to compile fewer kernels. - max_len = round_up_seq_len(max(len(seq) for seq in all_input_ids)) - - input_ids = pad_batch(all_input_ids, max_len, np.int32) - target_ids = pad_batch(all_targets, max_len, np.int32) - adapter_indices = jnp.array(all_adapter_indices, dtype=jnp.int32) - loss_fn_types = jnp.array(all_loss_fn_types, dtype=jnp.int32) - - # Create attention mask (1 for real tokens, 0 for padding) - attention_mask = pad_batch([[1] * len(seq) for seq in all_input_ids], max_len, np.int32) - loss_mask = pad_batch(all_token_weights, max_len, np.float32) - sampling_logprobs = pad_batch(all_sampling_logprobs, max_len, np.float32) - advantages = pad_batch(all_advantages, max_len, np.float32) - - total_bs = int(input_ids.shape[0]) - micro_bs = self._micro_batch_size(total_bs) - seq_lens = [len(seq) for seq in all_input_ids] - - # Collect full padded arrays on device, slice after transfer - token_losses_device = [] - logprobs_device = [] - seq_len = input_ids.shape[1] - - with jax.set_mesh(self.mesh), self._jit_timing_context(seq_len, mode="train"): - for mb_start in range(0, total_bs, micro_bs): - mb_end = min(mb_start + micro_bs, total_bs) - self.accumulated_grads, per_token_losses, target_logprobs = model_pass_fn( - self.accumulated_grads, - self.lora_params, - self.non_lora_params, - input_ids[mb_start:mb_end], - attention_mask[mb_start:mb_end], - adapter_indices[mb_start:mb_end], - target_ids[mb_start:mb_end], - loss_mask[mb_start:mb_end], - loss_fn_types[mb_start:mb_end], - sampling_logprobs[mb_start:mb_end], - advantages[mb_start:mb_end], - ) - token_losses_device.append(per_token_losses) - logprobs_device.append(target_logprobs) - - # Single batched device-to-host transfer for all arrays - token_losses_host, logprobs_host = jax.device_get((token_losses_device, logprobs_device)) - - # Flatten microbatches and slice to actual sequence lengths - token_losses_out = [] - logprobs_out = [] - idx = 0 - for mb_losses, mb_logprobs in zip(token_losses_host, logprobs_host): - for i in range(mb_losses.shape[0]): - token_losses_out.append(mb_losses[i, : seq_lens[idx]].astype(jnp.float32)) - logprobs_out.append(mb_logprobs[i, : seq_lens[idx]].astype(jnp.float32)) - idx += 1 - - # Compute per-request results - for request_id, _, start_idx, end_idx in request_batch_slices: - loss_fn_outputs = [] - # Compute per-example losses - for i in range(start_idx, end_idx): - # Extract losses for this example's tokens - token_losses = token_losses_out[i] - token_logprobs = logprobs_out[i] - loss_fn_outputs.append( - { - "elementwise_loss": { - "data": token_losses.tolist(), - "dtype": "float32", - "shape": [token_losses.shape[0]], - }, - "logprobs": { - "data": token_logprobs.tolist(), - "dtype": "float32", - "shape": [token_logprobs.shape[0]], - }, - } - ) + # Update last_used for all models in this batch + for model_id, _ in valid_requests.values(): + 
self._touch_model(model_id) - results[request_id] = types.ForwardBackwardOutput( - loss_fn_output_type="scalar", - loss_fn_outputs=loss_fn_outputs, - metrics={}, - ) + # Prepare batch data + prepared_batch = self._prepare_model_pass_batch(valid_requests) + # Delegate computation to backend + results = self.backend.process_forward_batch(prepared_batch) + results.update(error_results) return results - def process_forward_backward_batch(self, requests): - return self._process_model_pass_batch(requests, self._forward_backward_and_accumulate) - - def process_forward_batch(self, requests): - return self._process_model_pass_batch(requests, self._forward) - def process_sample_batch( self, requests: dict[str, tuple[str, types.SampleInput]] ) -> dict[str, types.SampleOutput | types.ErrorResponse]: - """Process multiple sample requests in a single batch + """Process multiple sample requests in a single batch. Args: requests: Dict mapping request_id to (model_id, request_data) tuples @@ -722,89 +465,27 @@ def process_sample_batch( Returns: Dict mapping request_id --> result_data or error info """ - results, valid_requests = self._filter_valid_requests(requests) + if not requests: + return {} + # Filter invalid requests before delegating to backend + error_results, valid_requests = self._filter_valid_requests(requests) if not valid_requests: - return results + return error_results - # Computes prompt_logprobs for the whole batch if any request asked for them - needs_prompt_logprobs = any(request_data.prompt_logprobs for (_, request_data) in valid_requests.values()) + # Update last_used for all models in this batch + for model_id, _ in valid_requests.values(): + self._touch_model(model_id) - all_prompts = [] - all_sampling_params = [] - all_adapter_indices = [] - request_batch_slices = [] - - adapter_indices_batch = self.load_sampler_weights(valid_requests) - - for i, (request_id, (model_id, request_data)) in enumerate(valid_requests.items()): - request_start = len(all_prompts) - - # Expand requests for num_samples (TODO: Once we have continuous batching / - # paged attention, we should do the prefill only once and share the kv cache) - for _ in range(request_data.num_samples): - prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens] - all_prompts.append(prompt_tokens) - all_sampling_params.append(request_data.sampling_params) - all_adapter_indices.append(adapter_indices_batch[i]) - - request_batch_slices.append((request_id, model_id, request_start, len(all_prompts), request_data)) - - total_batch_size = len(all_prompts) - max_batch_size = ( - self.config.sample_max_num_sequences if self.config.sample_max_num_sequences > 0 else total_batch_size - ) - # Collect generated sequences and prompt logprobs across batches - all_sequences: list[types.GeneratedSequence] = [] - all_prompt_logprobs: list[list[float]] = [] - - with jax.set_mesh(self.mesh): - model = nnx.merge(self.graphdef, self.lora_params, self.non_lora_params) - for batch_start in range(0, total_batch_size, max_batch_size): - batch_end = min(batch_start + max_batch_size, total_batch_size) - batch_prompts = pad(all_prompts[batch_start:batch_end], max_batch_size, fill=[]) - adapter_indices = pad(all_adapter_indices[batch_start:batch_end], max_batch_size, fill=0) - sampling_params = pad( - all_sampling_params[batch_start:batch_end], max_batch_size, fill=all_sampling_params[batch_start] - ) + # Load sampler weights and get adapter indices + adapter_indices = self.load_sampler_weights(valid_requests) - # Pad 
sequences to same length within the batch to minimize memory usage. - # Also bin it so the JIT has to compile fewer kernels. - # Use left-padding for sampling so the last position is always the last real token. - max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0)) - input_ids = pad_batch(batch_prompts, max_len, np.int32, left=True) - attention_mask = pad_batch([[1] * len(seq) for seq in batch_prompts], max_len, np.int32, left=True) - - with self._jit_timing_context(max_len, mode="sample"): - result = model.generate( - input_ids, - attention_mask, - sampling_params=sampling_params, - adapter_indices=jnp.array(adapter_indices, dtype=jnp.int32), - prompt_logprobs=needs_prompt_logprobs, - tokenizer=self.tokenizer, - ) - # Only take the actual results, not the padded ones - batch_size = batch_end - batch_start - all_sequences.extend( - types.GeneratedSequence(stop_reason=stop_reason, tokens=tokens, logprobs=logprobs) - for stop_reason, tokens, logprobs in zip( - result.stop_reasons[:batch_size], - result.generated_ids[:batch_size], - result.logprobs[:batch_size], - ) - ) - if needs_prompt_logprobs and result.prompt_logprobs: - all_prompt_logprobs.extend(result.prompt_logprobs[:batch_size]) - - for request_id, _, start_idx, end_idx, request_data in request_batch_slices: - sequences = [all_sequences[i] for i in range(start_idx, end_idx)] - # Each of `num_samples` samples in a request share the same prompt; use the first's prompt logprobs - prompt_logprobs = ( - all_prompt_logprobs[start_idx] if request_data.prompt_logprobs and all_prompt_logprobs else None - ) - results[request_id] = types.SampleOutput(sequences=sequences, prompt_logprobs=prompt_logprobs) + # Prepare batch data + prepared_batch = self._prepare_sample_batch(valid_requests, adapter_indices) + # Delegate computation to backend + results = self.backend.process_sample_batch(prepared_batch) + results.update(error_results) return results def process_optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput: @@ -812,60 +493,33 @@ def process_optim_step(self, model_id: str, request_data: types.OptimStepInput) if model_id not in self.models: raise ValueError(f"Model {model_id} not loaded") - adapter_index = jnp.int32(self.models[model_id].adapter_index) - - # Check if we have any gradients accumulated (count > 0) - if self.accumulated_grads.counts[adapter_index] == 0: - logger.warning(f"No accumulated gradients for model {model_id}, skipping optimizer step") - return types.OptimStepOutput() - - # Update hyperparameters from the request - hp = self.optimizers[model_id].opt_state.hyperparams - hp["learning_rate"][...] = request_data.adam_params.learning_rate - hp["b1"][...] = request_data.adam_params.beta1 - hp["b2"][...] = request_data.adam_params.beta2 - hp["eps"][...] 
= request_data.adam_params.eps - - # JIT-compiled: compute full gradients, apply optimizer update, and reset accumulated grads - with jax.set_mesh(self.mesh): - self.accumulated_grads = self._compute_grads_and_update( - self.accumulated_grads, - self.lora_params, - self.optimizers[model_id], - adapter_index, - ) + self._touch_model(model_id) + adapter_index = self.models[model_id].adapter_index - logger.info(f"Applied optimizer step for model {model_id} (adapter {adapter_index})") - return types.OptimStepOutput() + return self.backend.process_optim_step(model_id, adapter_index, request_data) def process_load_weights(self, model_id: str, request_data: types.LoadWeightsInput) -> types.LoadWeightsOutput: """Loads a clean, trimmed training checkpoint.""" if model_id not in self.models: raise ValueError("Model not loaded. Create the model before loading a checkpoint.") - adapter_index = self.models[model_id].adapter_index + self._touch_model(model_id) checkpoint_dir = ( self.config.checkpoints_base / request_data.source_model_id / f"{request_data.checkpoint_id}.tar.gz" ) with download_and_unpack(checkpoint_dir) as temp_dir: checkpoint = checkpoints.restore_checkpoint( - ckpt_dir=temp_dir, target=self._extract_checkpoint_data(model_id), prefix="checkpoint_" + ckpt_dir=temp_dir, + target=self.backend.extract_checkpoint_data(model_id, self.models), + prefix="checkpoint_", ) if checkpoint is None: raise FileNotFoundError(f"Training checkpoint not found in {checkpoint_dir}") - # Validate rank - rank = checkpoint["lora_config"]["rank"] - if self.models[model_id].lora_config.rank != rank: - raise ValueError( - f"Rank mismatch: checkpoint has rank {rank}, model configured with rank {self.models[model_id].lora_config.rank}" - ) - - # Update both LoRA weights and optimizer state - insert_adapter_state(adapter_index, self.lora_params, checkpoint["lora_weights"], rank) - insert_adapter_state(adapter_index, nnx.state(self.optimizers[model_id]), checkpoint["optimizer_state"], rank) + # Insert checkpoint data into model state via backend + self.backend.insert_checkpoint_data(model_id, checkpoint, self.models) logger.info(f"Loaded training checkpoint for model {model_id} from {checkpoint_dir}") return types.LoadWeightsOutput(type="load_weights") @@ -878,19 +532,12 @@ def process_save_weights(self, model_id: str, request_data: types.SaveWeightsInp if model_id not in self.models: raise ValueError(f"Model {model_id} not loaded") + self._touch_model(model_id) checkpoint_id = request_data.path output_path = self.config.checkpoints_base / model_id / f"{checkpoint_id}.tar.gz" with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.TRAINING): - with pack_and_upload(output_path) as temp_dir: - checkpoints.save_checkpoint( - target=self._extract_checkpoint_data(model_id), - ckpt_dir=temp_dir, - step=0, - prefix="checkpoint_", - overwrite=True, - ) - + self.backend.save_checkpoint(output_path, model_id, self.models) logger.info(f"Saved trimmed training checkpoint for model {model_id} to {output_path}") return types.SaveWeightsOutput( @@ -905,6 +552,7 @@ def process_save_weights_for_sampler( if model_id not in self.models: raise ValueError(f"Model {model_id} not loaded") + self._touch_model(model_id) lora_model = self.models[model_id] # Make sure the user cannot store checkpoints in places like ../../ @@ -912,10 +560,7 @@ def process_save_weights_for_sampler( output_path = self.config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz" with 
self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER): - # Save the LoRA adapter weights and LoRA config as tar.gz - save_lora_checkpoint( - self.model, self.config.base_model, lora_model.lora_config, lora_model.adapter_index, output_path - ) + self.backend.save_sampler_checkpoint(output_path, model_id, self.models) logger.info( f"Saved LoRA adapter weights for model {model_id} (adapter {lora_model.adapter_index}) to {output_path}" @@ -957,11 +602,9 @@ def load_sampler_weights(self, requests: dict[str, tuple[str, types.SampleInput] self.config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz" ) logger.info(f"Loading LoRA sampler checkpoint from {checkpoint_path}") - adapter_config = self.models[model_id].lora_config - load_lora_checkpoint(self.model, adapter_config, adapter_index, checkpoint_path) - self.models[model_id].loaded_checkpoint_id = checkpoint_id - logger.info(f"Loaded LoRA sampler weights for model {model_id} at adapter index {adapter_index}") + # Use backend to insert sampler weights + self.backend.insert_sampler_weights(model_id, checkpoint_id, checkpoint_path, self.models) adapter_indices.append(adapter_index) else: # This code path is for sampling from the base model diff --git a/skyrl-tx/tx/tinker/extra/external_inference.py b/skyrl-tx/tx/tinker/extra/external_inference.py index c94700f89..86f974c50 100644 --- a/skyrl-tx/tx/tinker/extra/external_inference.py +++ b/skyrl-tx/tx/tinker/extra/external_inference.py @@ -1,4 +1,7 @@ +import shutil + import httpx +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception from datetime import datetime, timezone from sqlmodel.ext.asyncio.session import AsyncSession @@ -9,6 +12,15 @@ from tx.utils.storage import download_and_unpack +def _is_retryable_error(exc: BaseException) -> bool: + """Check if exception is retryable (connection errors or 5xx status).""" + if isinstance(exc, (httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadError)): + return True + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code >= 500: + return True + return False + + class ExternalInferenceClient: """Client for calling external inference engines (e.g., vLLM).""" @@ -31,7 +43,7 @@ async def call_and_store_result( async with httpx.AsyncClient( base_url=self.base_url, headers={"Authorization": f"Bearer {self.api_key}"}, - timeout=httpx.Timeout(300.0, connect=10.0), # 5 minutes for inference, 10s for connect + timeout=httpx.Timeout(600.0, connect=10.0), # 10 minutes for inference, 10s for connect ) as http_client: result = await self._forward_to_engine(sample_req, model_id, checkpoint_id, http_client) result_data = result.model_dump() @@ -48,6 +60,12 @@ async def call_and_store_result( future.completed_at = datetime.now(timezone.utc) await session.commit() + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=2, min=2, max=30), + retry=retry_if_exception(_is_retryable_error), + reraise=True, + ) async def _forward_to_engine( self, request, model_id: str, checkpoint_id: str, http_client: httpx.AsyncClient ) -> types.SampleOutput: @@ -66,10 +84,10 @@ async def _forward_to_engine( if not target_dir.exists(): try: with download_and_unpack(checkpoint_path) as extracted_path: - extracted_path.rename(target_dir) + # shutil.move allows moving between filesystems. + shutil.move(str(extracted_path), str(target_dir)) except FileExistsError: - # This could happen if two processes try to download the file. 
- # In that case the other process won the race and created target_dir. + # Race condition: another process created target_dir pass payload = { diff --git a/skyrl-tx/tx/tinker/types.py b/skyrl-tx/tx/tinker/types.py index 33c1917c6..06c7b0393 100644 --- a/skyrl-tx/tx/tinker/types.py +++ b/skyrl-tx/tx/tinker/types.py @@ -166,6 +166,7 @@ class ModelMetadata(BaseModel): adapter_index: int lora_config: LoraConfig loaded_checkpoint_id: str | None = None + last_used: float | None = None # timestamp from time.time() class SampleInput(BaseModel): @@ -192,3 +193,51 @@ class SampleOutput(BaseModel): class EngineMetrics(BaseModel): train_seq_len_jit_times: dict[int, float] = {} sample_seq_len_jit_times: dict[int, float] = {} + + +# Prepared batch data for backend processing +# These are prepared by the engine and passed to the backend + +class PreparedModelPassBatch(BaseModel): + """Prepared batch data for forward/forward_backward operations. + + Engine extracts this from requests, backend converts to JAX arrays and computes. + """ + + # Per-example data (list of lists) + all_input_ids: list[list[int]] + all_targets: list[list[int]] + all_token_weights: list[list[float]] + all_sampling_logprobs: list[list[float]] + all_advantages: list[list[float]] + + # Per-example scalars + all_adapter_indices: list[int] + all_loss_fn_types: list[int] + + # Mapping from examples back to requests: (request_id, model_id, start_idx, end_idx) + request_batch_slices: list[tuple[int, str, int, int]] + + class Config: + arbitrary_types_allowed = True + + +class PreparedSampleBatch(BaseModel): + """Prepared batch data for sample operations. + + Engine extracts this from requests, backend converts to JAX arrays and computes. + """ + + # Per-sample data + all_prompts: list[list[int]] + all_sampling_params: list[SamplingParams] + all_adapter_indices: list[int] + + # Whether any request needs prompt logprobs + needs_prompt_logprobs: bool + + # Mapping from samples back to requests: (request_id, model_id, start_idx, end_idx, prompt_logprobs_requested) + request_batch_slices: list[tuple[str, str, int, int, bool]] + + class Config: + arbitrary_types_allowed = True diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index ceeb73ad4..275619617 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -269,14 +269,159 @@ def insert_state(path: tuple, p: jax.Array, new: jax.Array): nnx.update(lora_params, updated) -def round_up_seq_len(seq_len: int) -> int: +def convert_maxtext_lora_to_hf( + lora_state: nnx.State, + output_path: Path, + base_model_name: str = "", + lora_rank: int = 8, + lora_alpha: int = 32, +) -> None: + """Convert MaxText LoRA tensors to HuggingFace PEFT format. 
+ + MaxText LoRA shapes (layer axis in middle, heads sometimes factored): + - query/key/value lora_a: (hidden_size, num_layers, rank) + - query lora_b: (rank, num_layers, num_heads, head_dim) + - key/value lora_b: (rank, num_layers, num_kv_heads, head_dim) + - out lora_a: (num_heads, num_layers, head_dim, rank) + - out lora_b: (rank, num_layers, hidden_size) + + HuggingFace PEFT format (per layer): + - base_model.model.model.layers.{i}.self_attn.{proj}.lora_A.weight: (rank, in_features) + - base_model.model.model.layers.{i}.self_attn.{proj}.lora_B.weight: (out_features, rank) + + Args: + lora_state: NNX state containing MaxText LoRA parameters + output_path: Path to save the PEFT checkpoint (directory) + base_model_name: Name of the base model for PEFT config + lora_rank: LoRA rank + lora_alpha: LoRA alpha scaling factor + """ + # Map MaxText projection names to HuggingFace names + proj_name_map = { + "query": "q_proj", + "key": "k_proj", + "value": "v_proj", + "out": "o_proj", + } + + # Collect all LoRA tensors by path + lora_tensors = {} + for path, val in jax.tree_util.tree_leaves_with_path(lora_state): + path_str = "/".join(str(k.key) if hasattr(k, 'key') else str(k) for k in path) + lora_tensors[path_str] = np.asarray(val) + + # Determine num_layers from any tensor (layer axis is at position 1 for most) + sample_tensor = next(iter(lora_tensors.values())) + # Find layer dimension - it's 48 in the examples + num_layers = None + for tensor in lora_tensors.values(): + # Layer axis is typically at position 1 for lora_a, position 1 for lora_b + if tensor.ndim >= 2 and tensor.shape[1] == 48: + num_layers = 48 + break + if tensor.ndim >= 2 and tensor.shape[0] == 48: + num_layers = 48 + break + + if num_layers is None: + # Try to infer from shapes + for tensor in lora_tensors.values(): + for dim in tensor.shape: + if dim in [36, 48, 64, 80, 96]: # Common layer counts + num_layers = dim + break + if num_layers: + break + + if num_layers is None: + raise ValueError("Could not determine num_layers from tensor shapes") + + logger.info(f"Converting MaxText LoRA to HuggingFace format: {num_layers} layers, rank={lora_rank}") + + # Output tensors in HuggingFace format + hf_tensors = {} + + for path_str, tensor in lora_tensors.items(): + # Parse the path to identify projection and lora_a/lora_b + # Example: base/decoder/layers/self_attention/query/lora_a/.value + parts = path_str.split("/") + + # Find projection name and lora type + proj_name = None + lora_type = None + for i, part in enumerate(parts): + if part in proj_name_map: + proj_name = proj_name_map[part] + if part in ("lora_a", "lora_b"): + lora_type = "lora_A" if part == "lora_a" else "lora_B" + + if proj_name is None or lora_type is None: + logger.warning(f"Skipping unrecognized path: {path_str}") + continue + + logger.info(f"Converting {path_str}: shape {tensor.shape} -> {proj_name}/{lora_type}") + + # Convert based on projection and lora type + # MaxText shapes vary, need to handle each case + for layer_idx in range(num_layers): + if lora_type == "lora_A": + if proj_name == "o_proj": + # out lora_a: (num_heads, num_layers, head_dim, rank) -> (rank, num_heads * head_dim) + # Layer axis at position 1 + layer_tensor = tensor[:, layer_idx, :, :] # (num_heads, head_dim, rank) + # Flatten heads: (num_heads * head_dim, rank) then transpose to (rank, in_features) + layer_tensor = layer_tensor.reshape(-1, layer_tensor.shape[-1]) # (in_features, rank) + layer_tensor = layer_tensor.T # (rank, in_features) + else: + # query/key/value lora_a: 
(hidden_size, num_layers, rank) -> (rank, hidden_size) + # Layer axis at position 1 + layer_tensor = tensor[:, layer_idx, :] # (hidden_size, rank) + layer_tensor = layer_tensor.T # (rank, hidden_size) + else: # lora_B + if proj_name == "o_proj": + # out lora_b: (rank, num_layers, hidden_size) -> (hidden_size, rank) + # Layer axis at position 1 + layer_tensor = tensor[:, layer_idx, :] # (rank, hidden_size) + layer_tensor = layer_tensor.T # (hidden_size, rank) + else: + # query lora_b: (rank, num_layers, num_heads, head_dim) -> (num_heads * head_dim, rank) + # key/value lora_b: (rank, num_layers, num_kv_heads, head_dim) -> (num_kv_heads * head_dim, rank) + # Layer axis at position 1 + layer_tensor = tensor[:, layer_idx, ...] # (rank, num_heads, head_dim) or (rank, num_kv_heads, head_dim) + # Flatten heads and transpose: (rank, out_features) -> (out_features, rank) + layer_tensor = layer_tensor.reshape(layer_tensor.shape[0], -1) # (rank, out_features) + layer_tensor = layer_tensor.T # (out_features, rank) + + # HuggingFace PEFT key format + hf_key = f"base_model.model.model.layers.{layer_idx}.self_attn.{proj_name}.{lora_type}.weight" + hf_tensors[hf_key] = layer_tensor.astype(np.float32) + + # Save as safetensors + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + safetensors.numpy.save_file(hf_tensors, output_path / "adapter_model.safetensors") + logger.info(f"Saved {len(hf_tensors)} tensors to {output_path / 'adapter_model.safetensors'}") + + # Save PEFT config + peft_config = peft.LoraConfig( + base_model_name_or_path=base_model_name, + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ) + peft_config.save_pretrained(output_path) + logger.info(f"Saved PEFT config to {output_path}") + + +def round_up_seq_len(seq_len: int, min_seq_len: int = 32) -> int: """ Rounds a sequence length up to roughly two significant binary digits. We do this to pad sequences, so the Jax JIT compiler needs to compile fewer different shapes. """ - if seq_len <= 32: - return 32 + if seq_len <= min_seq_len: + return min_seq_len # Find the position of the most significant bit. msb_pos = seq_len.bit_length() - 1 diff --git a/ttl.md b/ttl.md new file mode 100644 index 000000000..1e7098e77 --- /dev/null +++ b/ttl.md @@ -0,0 +1,40 @@ +``` +# Storage structure +store = { + key: { + value: any, + expires_at: timestamp | null + } +} + +# SET with TTL +function set(key, value, ttl_seconds=null): + expires_at = null + if ttl_seconds: + expires_at = now() + ttl_seconds + store[key] = {value: value, expires_at: expires_at} + +# GET with lazy expiration check +function get(key): + if key not in store: + return null + entry = store[key] + if entry.expires_at and now() > entry.expires_at: + delete store[key] + return null + return entry.value + +# Optional: background cleanup (run periodically) +function cleanup_expired(): + for key in store.keys(): + if store[key].expires_at and now() > store[key].expires_at: + delete store[key] +``` + +--- + +Two common strategies: +- **Lazy expiration** (check on read) - simple, no background work, but dead keys linger until accessed +- **Active expiration** (background sweep) - cleaner memory, but adds complexity/overhead + +Most production systems do both - lazy on read + periodic background cleanup for keys nobody's asking for. \ No newline at end of file
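
As a concrete illustration of the "do both" approach described above, here is a minimal Python sketch (the `TTLStore` class and its method names are illustrative only, not part of any code in this PR): reads lazily evict stale keys, and `cleanup_expired()` is the sweep a periodic background task could call.

```python
import time
import threading


class TTLStore:
    """In-memory key/value store with per-key TTL.

    Combines lazy expiration (checked on get) with an optional
    periodic background sweep (cleanup_expired).
    """

    def __init__(self):
        self._store = {}  # key -> (value, expires_at | None)
        self._lock = threading.Lock()

    def set(self, key, value, ttl_seconds=None):
        expires_at = time.monotonic() + ttl_seconds if ttl_seconds else None
        with self._lock:
            self._store[key] = (value, expires_at)

    def get(self, key, default=None):
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return default
            value, expires_at = entry
            # Lazy expiration: evict on read if the entry is stale.
            if expires_at is not None and time.monotonic() > expires_at:
                del self._store[key]
                return default
            return value

    def cleanup_expired(self):
        """Active expiration: sweep out all stale entries, return how many were removed."""
        now = time.monotonic()
        with self._lock:
            stale = [k for k, (_, exp) in self._store.items() if exp is not None and now > exp]
            for k in stale:
                del self._store[k]
        return len(stale)


if __name__ == "__main__":
    store = TTLStore()
    store.set("session", "abc123", ttl_seconds=0.1)
    print(store.get("session"))      # "abc123"
    time.sleep(0.2)
    print(store.get("session"))      # None (lazily evicted on this read)
    print(store.cleanup_expired())   # 0 (the stale key was already gone)
```

Using `time.monotonic()` instead of wall-clock time keeps expirations correct across system clock changes; the lock makes the sketch safe to share with a background sweeper thread.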