diff --git a/skyrl-tx/tests/torch/models/test_qwen3.py b/skyrl-tx/tests/torch/models/test_qwen3.py new file mode 100644 index 000000000..7c184cb57 --- /dev/null +++ b/skyrl-tx/tests/torch/models/test_qwen3.py @@ -0,0 +1,347 @@ +import tempfile + +import pytest +import safetensors.torch +import torch +from peft import LoraConfig, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig + +from tx.models import Qwen3Config +from tx.torch.models.qwen3 import Qwen3ForCausalLM +from tx.torch.layers.lora import LoRAMixin + +pytestmark = pytest.mark.torch # Mark all tests in this file as torch tests + + +@pytest.fixture +def device(): + """Return the device to use for tests.""" + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def load_lora_weights( + module: LoRAMixin, + adapter_idx: int, + lora_A_weights: torch.Tensor, + lora_B_weights: torch.Tensor, + scaling: float, + rank: int, +) -> None: + """Load LoRA weights from tensors to a module with LoRA support. + + This is a generic helper that works with any LoRAMixin module. + + Args: + module: Module with LoRA support (LoRAMixin) + adapter_idx: Index of the adapter to load weights into + lora_A_weights: Weights for lora_A matrix [in_features, rank] + lora_B_weights: Weights for lora_B matrix [rank, out_features] + scaling: Scaling factor (typically lora_alpha / rank) + rank: Rank of the LoRA adapter + """ + assert module.lora_A is not None and module.lora_B is not None + assert module.lora_scaling is not None and module.lora_ranks is not None + + with torch.no_grad(): + # Copy lora_A and lora_B weights + module.lora_A[adapter_idx, :, :rank].copy_(lora_A_weights) + module.lora_B[adapter_idx, :rank, :].copy_(lora_B_weights) + + # Set scaling and rank + module.lora_scaling[adapter_idx] = scaling + module.lora_ranks[adapter_idx] = rank + + +def load_model_from_hf_checkpoint( + checkpoint_dir: str, config: Qwen3Config, device, dtype: torch.dtype = torch.float32 +) -> Qwen3ForCausalLM: + """Load our model from a HuggingFace checkpoint directory.""" + model = Qwen3ForCausalLM(config, dtype=dtype, device=device) + + # Load all safetensors files + state_dict = {} + from pathlib import Path + + for file in Path(checkpoint_dir).glob("*.safetensors"): + state_dict.update(safetensors.torch.load_file(file)) + + # Load into model (strict=False because we may have LoRA params that HF doesn't have) + model.load_state_dict(state_dict, strict=False) + return model.to(device) + + +def load_lora_adapter_from_hf(our_model: Qwen3ForCausalLM, hf_peft_model, adapter_idx: int, lora_config: LoraConfig): + """Load LoRA adapter weights from HuggingFace PEFT model to our model. + + This iterates through all layers and uses the generic load_lora_weights helper + to load weights from the HF PEFT model structure. 
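+
+    Note: PEFT stores lora_A.weight as [rank, in_features] and lora_B.weight as
+    [out_features, rank] (they are nn.Linear modules), while our load_lora_weights
+    helper expects [in_features, rank] and [rank, out_features], hence the
+    transposes below.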
+ """ + scaling = lora_config.lora_alpha / lora_config.r + rank = lora_config.r + + for i, layer in enumerate(our_model.model.layers): + hf_layer = hf_peft_model.base_model.model.model.layers[i] + + # Attention projections + for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]: + hf_proj = getattr(hf_layer.self_attn, proj_name) + our_proj = getattr(layer.self_attn, proj_name) + load_lora_weights( + our_proj, + adapter_idx=adapter_idx, + lora_A_weights=hf_proj.lora_A["default"].weight.T, + lora_B_weights=hf_proj.lora_B["default"].weight.T, + scaling=scaling, + rank=rank, + ) + + # MLP projections + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + hf_proj = getattr(hf_layer.mlp, proj_name) + our_proj = getattr(layer.mlp, proj_name) + load_lora_weights( + our_proj, + adapter_idx=adapter_idx, + lora_A_weights=hf_proj.lora_A["default"].weight.T, + lora_B_weights=hf_proj.lora_B["default"].weight.T, + scaling=scaling, + rank=rank, + ) + + +def test_qwen3_basic_shapes(device): + """Test that the model initializes and produces correct output shapes.""" + base_config = PretrainedConfig.from_pretrained("Qwen/Qwen3-0.6B") + config = Qwen3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=False) + + model = Qwen3ForCausalLM(config, dtype=torch.float32, device=device).to(device) + + # Create dummy input + batch_size, seq_len = 2, 10 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones(batch_size, seq_len, device=device) + + # Forward pass + with torch.no_grad(): + outputs = model(input_ids, attention_mask=attention_mask) + + # Check shapes + assert outputs.logits.shape == (batch_size, seq_len, config.vocab_size) + assert outputs.last_hidden_state.shape == (batch_size, seq_len, config.hidden_size) + assert outputs.kv_cache is not None + assert len(outputs.kv_cache.keys) == config.num_hidden_layers + + +def test_qwen3_vs_hf(device): + """Test that our PyTorch implementation matches HuggingFace outputs.""" + model_name = "Qwen/Qwen3-0.6B" + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Prepare input + inputs = ["The capital of France is", "The most popular programming language is"] + batch = tokenizer(inputs, return_tensors="pt", padding=True) + input_ids = batch.input_ids.to(device) + attention_mask = batch.attention_mask.to(device) + + with tempfile.TemporaryDirectory() as tmp: + # Load and save HF model + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + hf_model.save_pretrained(tmp, safe_serialization=True) + hf_model = hf_model.to(device) + hf_model.eval() + + # Get HF outputs + with torch.no_grad(): + hf_outputs = hf_model(input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True) + + # Load our model from saved checkpoint + base_config = PretrainedConfig.from_pretrained(model_name) + config = Qwen3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=False) + model = load_model_from_hf_checkpoint(tmp, config, device) + model.eval() + + # Get our outputs + with torch.no_grad(): + outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) + + # Compare outputs + assert outputs.hidden_states is not None + hf_hidden_states = hf_outputs.hidden_states + our_hidden_states = outputs.hidden_states + + # Check first layer (after embedding) + assert torch.allclose( + hf_hidden_states[0], our_hidden_states[0], rtol=1e-4, atol=1e-4 + ), f"First hidden 
state mismatch: max diff = {(hf_hidden_states[0] - our_hidden_states[0]).abs().max()}" + + # Check middle layer + mid_idx = len(hf_hidden_states) // 2 + assert torch.allclose( + hf_hidden_states[mid_idx], our_hidden_states[mid_idx], rtol=1e-3, atol=1e-3 + ), f"Middle hidden state mismatch: max diff = {(hf_hidden_states[mid_idx] - our_hidden_states[mid_idx]).abs().max()}" + + # Check final layer + assert torch.allclose( + hf_hidden_states[-1], our_hidden_states[-1], rtol=1e-3, atol=1e-3 + ), f"Final hidden state mismatch: max diff = {(hf_hidden_states[-1] - our_hidden_states[-1]).abs().max()}" + + # Check logits + assert torch.allclose( + hf_outputs.logits, outputs.logits, rtol=1e-3, atol=1e-3 + ), f"Logits mismatch: max diff = {(hf_outputs.logits - outputs.logits).abs().max()}" + + +def test_qwen3_lora_adapters(device): + """Test multiple LoRA adapters by comparing with HuggingFace PEFT models using two different adapters.""" + base_model_name = "Qwen/Qwen3-0.6B" + lora_adapters = ["pcmoritz/qwen3-0.6b-lora-random", "pcmoritz/qwen3-0.6b-lora-random2"] + + tokenizer = AutoTokenizer.from_pretrained(base_model_name) + # Use two different inputs to test with different adapters + inputs = ["The capital of France is", "My name is"] + batch = tokenizer(inputs, return_tensors="pt", padding=True) + input_ids = batch.input_ids.to(device) + attention_mask = batch.attention_mask.to(device) + + with tempfile.TemporaryDirectory() as base_tmp: + # Save base model checkpoint + base_hf_model = AutoModelForCausalLM.from_pretrained( + base_model_name, attn_implementation="eager", use_safetensors=True + ) + base_hf_model.save_pretrained(base_tmp, safe_serialization=True) + + # Create HF PEFT models with different adapters + hf_lora_models = [] + lora_configs = [] + for adapter_name in lora_adapters: + lora_config = LoraConfig.from_pretrained(adapter_name) + lora_config.target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + lora_configs.append(lora_config) + + hf_base = AutoModelForCausalLM.from_pretrained( + base_model_name, attn_implementation="eager", use_safetensors=True + ) + hf_model = get_peft_model(hf_base, lora_config) + hf_model.load_adapter(adapter_name, adapter_name="default") + hf_model.to(device) + hf_model.eval() + hf_lora_models.append(hf_model) + + # Get outputs from all HF models + hf_outputs_list = [] + with torch.no_grad(): + for idx in range(len(lora_adapters)): + hf_output = hf_lora_models[idx]( + input_ids[idx : idx + 1], + attention_mask=attention_mask[idx : idx + 1], + output_hidden_states=True, + return_dict=True, + ) + hf_outputs_list.append(hf_output) + + # Create our model with LoRA support and load base weights from checkpoint + base_config = PretrainedConfig.from_pretrained(base_model_name) + config = Qwen3Config( + base_config, + max_lora_adapters=len(lora_adapters), + max_lora_rank=max(cfg.r for cfg in lora_configs), + shard_attention_heads=False, + ) + model = load_model_from_hf_checkpoint(base_tmp, config, device) + + # Load LoRA adapter weights from all adapters + for adapter_idx, (hf_model, lora_config) in enumerate(zip(hf_lora_models, lora_configs)): + load_lora_adapter_from_hf(model, hf_model, adapter_idx, lora_config) + + model.eval() + + # Use different adapter indices for each input + adapter_indices = torch.arange(len(lora_adapters), dtype=torch.long, device=device) + with torch.no_grad(): + outputs = model( + input_ids, + attention_mask=attention_mask, + adapter_indices=adapter_indices, + 
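+                # row i of the batch is routed through LoRA adapter i (adapter_indices = arange)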
output_hidden_states=True, + ) + + # Compare outputs with corresponding adapters + for idx in range(len(lora_adapters)): + max_diff = (hf_outputs_list[idx].logits[0] - outputs.logits[idx]).abs().max().item() + assert torch.allclose( + hf_outputs_list[idx].logits[0], outputs.logits[idx], rtol=1e-3, atol=1e-3 + ), f"Adapter {idx} logits mismatch: max diff = {max_diff}" + + +def test_qwen3_kv_cache(device): + """Test that KV cache works correctly for generation.""" + model_name = "Qwen/Qwen3-0.6B" + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Prepare input + input_text = "The capital of France is" + batch = tokenizer([input_text], return_tensors="pt") + input_ids = batch.input_ids.to(device) + attention_mask = batch.attention_mask.to(device) + + with tempfile.TemporaryDirectory() as tmp: + # Save HF model checkpoint + hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True) + hf_model.save_pretrained(tmp, safe_serialization=True) + + # Load our model from checkpoint + base_config = PretrainedConfig.from_pretrained(model_name) + config = Qwen3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=False) + model = load_model_from_hf_checkpoint(tmp, config, device) + model.eval() + + # Test 1: Prefill phase (no cache) + with torch.no_grad(): + output_no_cache = model(input_ids, attention_mask=attention_mask) + + # Test 2: Using cache for next token + + with torch.no_grad(): + # Prefill + prefill_output = model(input_ids, attention_mask=attention_mask) + kv_cache = prefill_output.kv_cache + + # Pad cache to accommodate more tokens (e.g., 20 tokens total) + max_length = 20 + kv_cache = kv_cache.pad_to_length(max_length) + + # Next token (simulate getting next token) + next_token = output_no_cache.logits[:, -1:, :].argmax(dim=-1) + + # Build attention mask for the full sequence (actual tokens + new token + padding) + extended_attention_mask = torch.cat([attention_mask, torch.ones(1, 1, device=device)], dim=1) + + # Pad attention mask to match cache size + mask_padding = max_length - extended_attention_mask.shape[1] + if mask_padding > 0: + extended_attention_mask = torch.cat( + [extended_attention_mask, torch.zeros(1, mask_padding, device=device)], dim=1 + ) + + # Compute position for the new token explicitly (matching JAX implementation) + # The new token is at position cache_position (5 in this case) + next_position = torch.tensor([[kv_cache.cache_position]], device=device) + + # Generate with cache + cache_output = model( + next_token, attention_mask=extended_attention_mask, positions=next_position, kv_cache=kv_cache + ) + + # The cache output should be valid (no NaNs) + assert not torch.isnan(cache_output.logits).any(), "KV cache produced NaN values" + assert ( + cache_output.kv_cache.cache_position == input_ids.shape[1] + 1 + ), f"Cache position should be {input_ids.shape[1] + 1}, got {cache_output.kv_cache.cache_position}" diff --git a/skyrl-tx/tx/torch/layers/lora.py b/skyrl-tx/tx/torch/layers/lora.py new file mode 100644 index 000000000..999add9cb --- /dev/null +++ b/skyrl-tx/tx/torch/layers/lora.py @@ -0,0 +1,203 @@ +from __future__ import annotations +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from .util import Param, prepare_routing + + +class LoRAMixin(nn.Module): + """A mixin for PyTorch modules to add multi-adapter LoRA support. 
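+
+    Note: init_lora must be called after nn.Module.__init__ has run, since it
+    registers parameters and buffers on self (see LoRALinear below for the pattern).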
+ + This mixin adds LoRA parameters (lora_A, lora_B) and methods to apply + the low-rank adaptation to a base module's output. + + Provides: + - init_lora(...) -> allocate lora_A, lora_B, lora_scaling, lora_ranks + - apply_lora(x, base_output, adapter_indices) -> apply LoRA adaptation + + Stored tensors (when enabled): + lora_A: [A, in_features, r_max] (Param, He-uniform) + lora_B: [A, r_max, out_features] (Param, zeros) + lora_scaling:[A] (Buffer, alpha/rank per adapter) + lora_ranks: [A] (Buffer, int rank per adapter) + """ + + lora_scaling: torch.Tensor | None + lora_ranks: torch.Tensor | None + lora_A: nn.Parameter | None + lora_B: nn.Parameter | None + + def init_lora( + self, + *, + max_lora_adapters: int, + max_lora_rank: int, + shape_A: tuple[int, ...], + shape_B: tuple[int, ...], + dtype: torch.dtype, + device: torch.device | str, + ) -> None: + self.max_lora_adapters = int(max_lora_adapters) + self.max_lora_rank = int(max_lora_rank) + + if self.max_lora_adapters == 0: + self.lora_scaling = None + self.lora_ranks = None + self.lora_A = None + self.lora_B = None + return + + self.register_buffer( + "lora_scaling", + torch.full((self.max_lora_adapters,), 1.0, dtype=dtype, device=device), + persistent=True, + ) + self.register_buffer( + "lora_ranks", + torch.full((self.max_lora_adapters,), self.max_lora_rank, dtype=torch.int32, device=device), + persistent=True, + ) + + self.lora_A = Param( + *shape_A, + dtype=dtype, + kernel_init=lambda t: nn.init.kaiming_uniform_(t, a=math.sqrt(5)), + device=device, + ) + self.lora_B = Param( + *shape_B, + dtype=dtype, + kernel_init=nn.init.zeros_, + device=device, + ) + + def apply_lora( + self, + x: torch.Tensor, + base_output: torch.Tensor, + adapter_indices: torch.Tensor | None, + ) -> torch.Tensor: + """Apply multi-adapter LoRA to base module output. 
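+
+        For each token routed to adapter a with rank r = lora_ranks[a], the added term is
+        lora_scaling[a] * (x @ lora_A[a, :, :r]) @ lora_B[a, :r, :], so a single batch can
+        mix tokens that use different adapters.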
+ + Args: + x: Input tensor [B, T, in_features] + base_output: Base module output [B, T, out_features] + adapter_indices: Adapter index per batch element [B], broadcasted over sequence length + + Returns: + base_output + lora_output with per-adapter routing and scaling + """ + if self.max_lora_adapters == 0 or adapter_indices is None: + return base_output + + if x.dim() != 3: + raise ValueError("x must be [B, T, in_features].") + B, T, in_features = x.shape + if adapter_indices.dim() != 1 or adapter_indices.size(0) != B: + raise ValueError("adapter_indices must be shape [B].") + + # Flatten tokens to [N, in] + x_flat = x.reshape(-1, in_features) # [B*T, in] + # Broadcast adapter ids across sequence length + adapters_flat = adapter_indices.repeat_interleave(T) # [B*T] + + # Route by adapter (ragged groups) + x_sorted, group_sizes, unsort_idx, _ = prepare_routing( + tokens=x_flat, + indices=adapters_flat, + num_groups=self.max_lora_adapters, + adapter_indices=None, + ) + + # Compute LoRA: (x @ A) @ B per-adapter group + N = x_sorted.size(0) + out_features = base_output.size(-1) + y_sorted = torch.empty(N, out_features, dtype=base_output.dtype, device=base_output.device) + + offset = 0 + for adapter_index, group_size in enumerate(group_sizes.tolist()): + if group_size == 0: + continue + start_idx, end_idx = offset, offset + group_size + adapter_input = x_sorted[start_idx:end_idx] # [group_size, in_features] + adapter_rank = int(self.lora_ranks[adapter_index].item()) + if adapter_rank > 0: + lora_A_matrix = self.lora_A[adapter_index, :, :adapter_rank] # [in_features, adapter_rank] + lora_B_matrix = self.lora_B[adapter_index, :adapter_rank, :] # [adapter_rank, out_features] + intermediate_result = adapter_input.matmul(lora_A_matrix) # [group_size, adapter_rank] + adapter_output = intermediate_result.matmul(lora_B_matrix) # [group_size, out_features] + else: + adapter_output = torch.zeros(group_size, out_features, dtype=y_sorted.dtype, device=y_sorted.device) + y_sorted[start_idx:end_idx] = adapter_output + offset = end_idx + + # Unsort back to original token order -> [B*T, out] + y_flat = y_sorted[unsort_idx] + + # Reshape and scale: lora_output * self.lora_scaling[adapter_indices, None, None] + y = y_flat.view(B, T, out_features) + y = y * self.lora_scaling[adapter_indices].view(B, 1, 1) + + return base_output + y + + +class LoRALinear(LoRAMixin, nn.Linear): + """An nn.Linear layer with multi-adapter LoRA support. + + Combines base linear transformation with optional per-adapter low-rank updates. 
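+
+    Illustrative usage (shapes and values are arbitrary):
+        layer = LoRALinear(16, 32, max_lora_adapters=2, max_lora_rank=4,
+                           dtype=torch.float32, use_bias=False, device="cpu")
+        y = layer(x, adapter_indices=torch.tensor([0, 1]))  # x: [2, T, 16] -> y: [2, T, 32]
+    Passing adapter_indices=None (or max_lora_adapters=0) reduces to a plain nn.Linear.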
+ + Forward pass: + base_out = F.linear(x, weight, bias) + return self.apply_lora(x, base_out, adapter_indices) + """ + + def __init__( + self, + in_features: int, + out_features: int, + *, + max_lora_adapters: int, + max_lora_rank: int, + dtype: torch.dtype, + use_bias: bool, + device: torch.device | str, + ): + nn.Linear.__init__(self, in_features, out_features, bias=use_bias, device=device, dtype=dtype) + LoRAMixin.init_lora( + self, + max_lora_adapters=max_lora_adapters, + max_lora_rank=max_lora_rank, + shape_A=(max_lora_adapters, in_features, max_lora_rank), + shape_B=(max_lora_adapters, max_lora_rank, out_features), + dtype=self.weight.dtype, + device=self.weight.device, + ) + + def forward(self, x: torch.Tensor, adapter_indices: torch.Tensor | None = None) -> torch.Tensor: + base_out = F.linear(x, self.weight, self.bias) + return self.apply_lora(x, base_out, adapter_indices) + + +def update_adapter_config(model: nn.Module, adapter_index: int, lora_rank: int, lora_alpha: float): + """Update lora_ranks and lora_scaling for a specific adapter across all LoRA layers. + + Note: This method needs to be called BEFORE any training happens, you should not update + the config for the same adapter index multiple times throughout training (e.g. it will + invalidate your current training progress and also violate the assumption that lora_B + is zero). + + Args: + model: The model containing LoRA layers + adapter_index: Index of the adapter to update + lora_rank: Rank to set for this adapter + lora_alpha: Alpha value to use for computing scaling (alpha / rank) + """ + scaling = lora_alpha / lora_rank + with torch.no_grad(): + for m in model.modules(): + if isinstance(m, LoRAMixin) and m.max_lora_adapters > 0: + m.lora_ranks[adapter_index] = lora_rank + m.lora_scaling[adapter_index] = scaling + # Zero out columns beyond the rank for this adapter; lora_B is already zero + m.lora_A.data[adapter_index, :, lora_rank:] = 0.0 diff --git a/skyrl-tx/tx/torch/layers/util.py b/skyrl-tx/tx/torch/layers/util.py new file mode 100644 index 000000000..e4bf7d152 --- /dev/null +++ b/skyrl-tx/tx/torch/layers/util.py @@ -0,0 +1,59 @@ +from __future__ import annotations +import torch +import torch.nn as nn + + +def Param( + *shape: int, + dtype: torch.dtype, + kernel_init: callable, + device: torch.device | str, +) -> nn.Parameter: + """Create an initialized parameter tensor. + + Args: + *shape: Shape of the parameter tensor + dtype: Data type of the tensor + kernel_init: Initialization function that modifies tensor in-place + device: Device to place tensor on + + Returns: + Initialized nn.Parameter + """ + tensor = torch.empty(*shape, dtype=dtype, device=device) + kernel_init(tensor) + return nn.Parameter(tensor, requires_grad=True) + + +def prepare_routing( + tokens: torch.Tensor, + indices: torch.Tensor, + num_groups: int, + adapter_indices: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Prepare inputs for ragged operations by sorting tokens by group. + + Args: + tokens: Tensor of shape (num_tokens, ...) 
to be sorted by group + indices: Tensor of shape (num_tokens,) indicating group assignment for each token + num_groups: Total number of groups + adapter_indices: Optional tensor of shape (num_tokens,) to be sorted together with tokens + + Returns: + sorted_tokens: Tokens sorted by group index + group_sizes: Number of tokens in each group + unsort_indices: Indices to restore original order after ragged operations + sorted_adapter_indices: Adapter indices sorted with tokens (or None if not provided) + """ + # Sort by group index + sort_idx = torch.argsort(indices) + sorted_tokens = tokens[sort_idx] + sorted_adapter_indices = None if adapter_indices is None else adapter_indices[sort_idx] + + # Compute group sizes (minlength guarantees output length) + group_sizes = torch.bincount(indices, minlength=num_groups) + + # Inverse permutation to restore original order + unsort_indices = torch.argsort(sort_idx) + + return sorted_tokens, group_sizes, unsort_indices, sorted_adapter_indices diff --git a/skyrl-tx/tx/torch/models/__init__.py b/skyrl-tx/tx/torch/models/__init__.py new file mode 100644 index 000000000..2b1b0e596 --- /dev/null +++ b/skyrl-tx/tx/torch/models/__init__.py @@ -0,0 +1,10 @@ +"""PyTorch model implementations.""" + +from tx.torch.models.outputs import CausalLMOutput, ModelOutput +from tx.torch.models.qwen3 import Qwen3ForCausalLM + +__all__ = [ + "Qwen3ForCausalLM", + "CausalLMOutput", + "ModelOutput", +] diff --git a/skyrl-tx/tx/torch/models/outputs.py b/skyrl-tx/tx/torch/models/outputs.py new file mode 100644 index 000000000..8777fda00 --- /dev/null +++ b/skyrl-tx/tx/torch/models/outputs.py @@ -0,0 +1,41 @@ +"""Model output dataclasses for PyTorch.""" + +from __future__ import annotations +from dataclasses import dataclass +from typing import List + +import torch + +from tx.torch.utils.generator import KVCache + + +@dataclass +class ModelOutput: + """Output type for models like Qwen3Model. + + Attributes: + last_hidden_state: The last hidden state from the model [B, T, hidden_size] + kv_cache: The updated key-value cache + hidden_states: All hidden states if output_hidden_states=True + """ + + last_hidden_state: torch.Tensor + kv_cache: KVCache + hidden_states: List[torch.Tensor] | None = None + + +@dataclass +class CausalLMOutput: + """Output type for causal language models like Qwen3ForCausalLM. + + Attributes: + logits: The language modeling logits [B, T, vocab_size] + last_hidden_state: The last hidden state from the model [B, T, hidden_size] + kv_cache: The updated key-value cache + hidden_states: All hidden states, if output_hidden_states=True + """ + + logits: torch.Tensor + last_hidden_state: torch.Tensor + kv_cache: KVCache + hidden_states: List[torch.Tensor] | None = None diff --git a/skyrl-tx/tx/torch/models/qwen3.py b/skyrl-tx/tx/torch/models/qwen3.py new file mode 100644 index 000000000..19acc3369 --- /dev/null +++ b/skyrl-tx/tx/torch/models/qwen3.py @@ -0,0 +1,354 @@ +from __future__ import annotations +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tx.models.configs import Qwen3Config +from tx.torch.layers.lora import LoRALinear +from tx.torch.models.outputs import ModelOutput, CausalLMOutput +from tx.torch.utils.generator import KVCache, compute_positions + + +def apply_rope(inputs: torch.Tensor, position_ids: torch.Tensor, head_dim: int, theta: int) -> torch.Tensor: + """Apply rotary position embeddings to input tensor. 
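+
+    Uses the "rotate half" formulation: the head dimension is split into halves (a, b)
+    and rotated as (a*cos - b*sin, b*cos + a*sin), matching the HuggingFace Qwen3
+    implementation rather than the interleaved RoPE variant.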
+ + Args: + inputs: Input tensor of shape [B, n_head, T, head_dim] + position_ids: Position IDs of shape [B, T] (can include negative positions for left-padding) + head_dim: Dimension of each attention head + theta: RoPE theta parameter + + Returns: + Tensor with rotary embeddings applied + """ + fraction = 2 * torch.arange(0, head_dim // 2, dtype=torch.float32, device=inputs.device) / head_dim + timescale = theta**fraction + x = position_ids[:, None, :, None] / timescale[None, None, None, :] # [B, 1, T, dim/2] + sin, cos = x.sin().to(dtype=inputs.dtype), x.cos().to(dtype=inputs.dtype) # [B, 1, T, dim/2] + a, b = inputs.chunk(2, dim=-1) + return torch.cat([a * cos - b * sin, b * cos + a * sin], dim=-1) + + +class Qwen3Attention(nn.Module): + def __init__( + self, + config: Qwen3Config, + *, + dtype: torch.dtype, + device: torch.device | str, + max_lora_adapters: int, + max_lora_rank: int, + ): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // self.num_heads + self.gqa_groups = self.num_heads // self.num_kv_heads + + self.q_proj = LoRALinear( + config.hidden_size, + self.num_heads * self.head_dim, + use_bias=False, + dtype=dtype, + device=device, + max_lora_adapters=max_lora_adapters, + max_lora_rank=max_lora_rank, + ) + self.k_proj = LoRALinear( + config.hidden_size, + self.num_kv_heads * self.head_dim, + use_bias=False, + dtype=dtype, + device=device, + max_lora_adapters=max_lora_adapters, + max_lora_rank=max_lora_rank, + ) + self.v_proj = LoRALinear( + config.hidden_size, + self.num_kv_heads * self.head_dim, + use_bias=False, + dtype=dtype, + device=device, + max_lora_adapters=max_lora_adapters, + max_lora_rank=max_lora_rank, + ) + self.o_proj = LoRALinear( + self.num_heads * self.head_dim, + config.hidden_size, + use_bias=False, + dtype=dtype, + device=device, + max_lora_adapters=max_lora_adapters, + max_lora_rank=max_lora_rank, + ) + + self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps).to(dtype=dtype) + self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps).to(dtype=dtype) + + def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor: + # [B, n_kv, T, H] -> [B, n_head, T, H] + return x if self.num_kv_heads == self.num_heads else x.repeat_interleave(self.gqa_groups, dim=1) + + def forward( + self, + x: torch.Tensor, # [B, T, D] + *, + attention_mask: torch.Tensor, # [B, T_kv] (1=keep, 0=mask) + positions: torch.Tensor, # [B, T] + adapter_indices: torch.Tensor | None = None, # [B] or None + kv_cache: tuple[torch.Tensor, torch.Tensor, int] | None = None, # (k_cache, v_cache, cache_position) + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + B, T, _ = x.shape + + # Project and reshape to [B, T, num_heads, head_dim] + q = self.q_norm(self.q_proj(x, adapter_indices=adapter_indices).view(B, T, self.num_heads, self.head_dim)) + k = self.k_norm(self.k_proj(x, adapter_indices=adapter_indices).view(B, T, self.num_kv_heads, self.head_dim)) + v = self.v_proj(x, adapter_indices=adapter_indices).view(B, T, self.num_kv_heads, self.head_dim) + + # Transpose to [B, n, T, H] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Apply RoPE + q = apply_rope(q, positions, self.head_dim, self.config.rope_theta) + k = apply_rope(k, positions, self.head_dim, self.config.rope_theta) + + # Handle KV cache update + if kv_cache is not None: + k_cache, v_cache, cache_pos = kv_cache + + T_new = 
k.size(2) + with torch.no_grad(): + k_cache[:, :, cache_pos : cache_pos + T_new, :].copy_(k) + v_cache[:, :, cache_pos : cache_pos + T_new, :].copy_(v) + + k, v = k_cache, v_cache + + updated_cache = (k, v) + + # Attention (causal only during prefill, GQA handled via repeat) + k_full = self._repeat_kv(k) + v_full = self._repeat_kv(v) + + # Use SDPA with bool mask + attn_mask_bool = attention_mask[:, None, None, :].to(dtype=torch.bool) + attn_output = F.scaled_dot_product_attention( + q, + k_full, + v_full, + attn_mask=attn_mask_bool, + dropout_p=0.0, + is_causal=(kv_cache is None), + ) + + output = attn_output.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_dim) + return self.o_proj(output, adapter_indices=adapter_indices), updated_cache + + +class Qwen3MLP(nn.Module): + def __init__( + self, + config: Qwen3Config, + *, + dtype: torch.dtype, + device: torch.device | str, + max_lora_adapters: int, + max_lora_rank: int, + ): + super().__init__() + self.gate_proj = LoRALinear( + config.hidden_size, + config.intermediate_size, + use_bias=False, + dtype=dtype, + device=device, + max_lora_adapters=max_lora_adapters, + max_lora_rank=max_lora_rank, + ) + self.up_proj = LoRALinear( + config.hidden_size, + config.intermediate_size, + use_bias=False, + dtype=dtype, + device=device, + max_lora_adapters=max_lora_adapters, + max_lora_rank=max_lora_rank, + ) + self.down_proj = LoRALinear( + config.intermediate_size, + config.hidden_size, + use_bias=False, + dtype=dtype, + device=device, + max_lora_adapters=max_lora_adapters, + max_lora_rank=max_lora_rank, + ) + + def forward(self, x: torch.Tensor, adapter_indices: torch.Tensor | None = None) -> torch.Tensor: + gate_out = self.gate_proj(x, adapter_indices) + up_out = self.up_proj(x, adapter_indices) + return self.down_proj(F.silu(gate_out) * up_out, adapter_indices) + + +class Qwen3DecoderLayer(nn.Module): + def __init__( + self, + config: Qwen3Config, + *, + dtype: torch.dtype, + device: torch.device | str, + max_lora_adapters: int, + max_lora_rank: int, + ): + super().__init__() + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps).to(dtype=dtype) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps).to(dtype=dtype) + self.self_attn = Qwen3Attention( + config, dtype=dtype, device=device, max_lora_adapters=max_lora_adapters, max_lora_rank=max_lora_rank + ) + self.mlp = Qwen3MLP( + config, dtype=dtype, device=device, max_lora_adapters=max_lora_adapters, max_lora_rank=max_lora_rank + ) + + def forward( + self, + hidden_states: torch.Tensor, + *, + attention_mask: torch.Tensor, + positions: torch.Tensor, + adapter_indices: torch.Tensor | None = None, + kv_cache: tuple[torch.Tensor, torch.Tensor, int] | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, updated_cache = self.self_attn( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, adapter_indices=adapter_indices) + hidden_states = residual + hidden_states + + return hidden_states, updated_cache + + +class Qwen3Model(nn.Module): + def __init__(self, config: Qwen3Config, *, dtype: torch.dtype, device: torch.device | str): + super().__init__() + 
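+        # max_lora_adapters == 0 disables LoRA entirely; max_lora_rank is only an upper
+        # bound per adapter (actual ranks can be set later, e.g. via update_adapter_config).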
self.config = config + max_lora_adapters = getattr(config, "max_lora_adapters", 0) + max_lora_rank = getattr(config, "max_lora_rank", 8) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype) + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer( + config, dtype=dtype, device=device, max_lora_adapters=max_lora_adapters, max_lora_rank=max_lora_rank + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps).to(dtype=dtype) + + def forward( + self, + input_ids: torch.Tensor, # [B, T] + *, + attention_mask: torch.Tensor, # [B, T_kv] (1=keep, 0=mask) + positions: torch.Tensor, # [B, T] + output_hidden_states: bool | None = None, + adapter_indices: torch.Tensor | None = None, # [B] or None + kv_cache: KVCache | None = None, + ) -> ModelOutput: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embed_tokens(input_ids) + all_hidden_states = [] + updated_keys, updated_values = [], [] + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(hidden_states) + + hidden_states, (k, v) = layer( + hidden_states, + attention_mask=attention_mask, + positions=positions, + adapter_indices=adapter_indices, + kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position), + ) + updated_keys.append(k) + updated_values.append(v) + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + # Increment cache_position if cache exists, or use sequence length for new cache + new_cache_position = ( + kv_cache.cache_position + input_ids.shape[1] if kv_cache is not None else input_ids.shape[1] + ) + + return ModelOutput( + last_hidden_state=hidden_states, + kv_cache=KVCache(keys=updated_keys, values=updated_values, cache_position=new_cache_position), + hidden_states=all_hidden_states if output_hidden_states else None, + ) + + +class Qwen3ForCausalLM(nn.Module): + def __init__(self, config: Qwen3Config, *, dtype: torch.dtype, device: torch.device | str): + super().__init__() + self.config = config + self.model = Qwen3Model(config, dtype=dtype, device=device) + if not self.config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, dtype=dtype) + + @staticmethod + def is_lora_param(name: str) -> bool: + """Return True if a parameter name corresponds to LoRA weights.""" + return "lora_A" in name or "lora_B" in name + + def forward( + self, + input_ids: torch.Tensor, + *, + attention_mask: torch.Tensor, + positions: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + adapter_indices: torch.Tensor | None = None, + kv_cache: KVCache | None = None, + ) -> CausalLMOutput: + if positions is None: + positions = compute_positions(attention_mask) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + positions=positions, + output_hidden_states=output_hidden_states, + adapter_indices=adapter_indices, + kv_cache=kv_cache, + ) + hidden_states = outputs.last_hidden_state + if self.config.tie_word_embeddings: + logits = F.linear(hidden_states, self.model.embed_tokens.weight) + else: + logits = self.lm_head(hidden_states) + + return CausalLMOutput( + logits=logits, + last_hidden_state=outputs.last_hidden_state, + kv_cache=outputs.kv_cache, + hidden_states=outputs.hidden_states, + ) diff --git 
a/skyrl-tx/tx/torch/utils/generator.py b/skyrl-tx/tx/torch/utils/generator.py new file mode 100644 index 000000000..32de1ad8d --- /dev/null +++ b/skyrl-tx/tx/torch/utils/generator.py @@ -0,0 +1,64 @@ +"""Generator utilities for autoregressive text generation with KV caching.""" + +from __future__ import annotations +from dataclasses import dataclass +from typing import List +import torch + + +@dataclass +class KVCache: + """Key-value cache for all layers, each entry in the list corresponds to one layer. + + Attributes: + keys: List of key tensors, one per layer [B, n_heads, T_cache, head_dim] + values: List of value tensors, one per layer [B, n_heads, T_cache, head_dim] + cache_position: Current position in the cache (next position to write to) + """ + + keys: List[torch.Tensor] + values: List[torch.Tensor] + cache_position: int + + def pad_to_length(self, max_length: int) -> KVCache: + """Pad KV cache to a specified maximum length. + + Args: + max_length: Target length to pad the cache to. + + Returns: + New KVCache with padded keys and values. + """ + # k and v have shape [B, n_heads, T, head_dim] + cache_pad_length = max_length - self.keys[0].shape[2] + if cache_pad_length <= 0: + return self + + padded_keys = [] + padded_values = [] + for k, v in zip(self.keys, self.values): + # Pad along the sequence dimension (dim=2) + k_padded = torch.nn.functional.pad(k, (0, 0, 0, cache_pad_length)) + v_padded = torch.nn.functional.pad(v, (0, 0, 0, cache_pad_length)) + padded_keys.append(k_padded) + padded_values.append(v_padded) + + return KVCache(keys=padded_keys, values=padded_values, cache_position=self.cache_position) + + +def compute_positions(attention_mask: torch.Tensor) -> torch.Tensor: + """Compute positions from attention mask. + + Positions start at 0 from the first non-zero value in the attention mask + and increment sequentially. Supports left-padding with negative positions. + + Args: + attention_mask: [B, T] tensor with 1=valid token, 0=padding + + Returns: + positions: [B, T] tensor with positions (can be negative for left-padding) + """ + first_token_idx = attention_mask.argmax(dim=1, keepdim=True) # [B, 1] + seq_len = attention_mask.shape[1] + positions = torch.arange(seq_len, device=attention_mask.device)[None, :] - first_token_idx + return positions