diff --git a/FAST_LLM_WEIGHT_ACCESS.md b/FAST_LLM_WEIGHT_ACCESS.md new file mode 100644 index 000000000..18258e516 --- /dev/null +++ b/FAST_LLM_WEIGHT_ACCESS.md @@ -0,0 +1,94 @@ +# How to Access Weights in Fast-LLM + +## The Problem +When you load a Fast-LLM model and try to access weights through normal parameter attributes like `model.layer.weight` or `model.layer.bias`, they appear to be all zeros. This is misleading! + +## The Root Cause: FSDP Weight Management + +Fast-LLM uses a sophisticated FSDP (Fully Sharded Data Parallel) system for memory efficiency: + +1. **Weight Shard**: The actual weights are stored in a flat 1D tensor called `_weight_shard` +2. **Weight Buffer**: Parameters are views into `_weight_buffer` (starts as zeros) +3. **Restore Step**: `restore_parameters()` copies from shard to buffer before forward pass + +### Architecture + +``` +_weight_shard (1D tensor with actual data) + ↓ restore_parameters() +_weight_buffer (flat buffer, initially zeros) + ↓ views +parameters (.weight, .bias - show zeros until restored!) +``` + +## The Solution - Method 1: Access the Shard Directly + +```python +from fast_llm.engine.multi_stage.config import ShardName + +# Load model +model = GPTModel.from_pretrained(CheckpointLoadConfig(...)) + +# Get the actual weights (NOT through .weight or .bias!) +weights_shard = model.get_shard(ShardName.weights) # Returns a 1D tensor with ALL weights + +# weights_shard is a flat tensor containing all model weights +print(weights_shard.shape) # e.g., torch.Size([2804643712]) +print(weights_shard.sum()) # Non-zero! +``` + +## The Solution - Method 2: Restore Parameters First + +```python +# Load model +model = GPTModel.from_pretrained(CheckpointLoadConfig(...)) + +# Parameters show zeros BEFORE restore +print(model.base_model.decoder[0].mlp.router.bias.sum()) # 0.0 + +# Restore parameters from shard to buffer +for stage in model._stages: + stage.restore_parameters() + +# Parameters show actual weights AFTER restore +print(model.base_model.decoder[0].mlp.router.bias.sum()) # Non-zero! +``` + +## Why Parameters Show Zeros + +Fast-LLM's FSDP implementation: +- Creates parameters as **views into `_weight_buffer`** (see `fsdp.py:82-90`) +- The buffer starts empty (zeros) for memory efficiency +- `restore_parameters()` copies from `_weight_shard` to `_weight_buffer` (see `fsdp.py:181-189`) +- This happens automatically during forward pass (see `stage.py:121` - asserts `_is_restored`) + +## Important Notes + +1. **Forward pass calls restore automatically**: When you call `model(input)`, Fast-LLM internally calls `restore_parameters()` first +2. **Parameters are views**: Modifying parameters after restore modifies the buffer +3. **Chunking parameters**: If you chunk `.weight` or `.bias` before restore, you'll chunk zeros! + +## Verification Examples + +```python +# ❌ WRONG - will show zeros (before restore) +bias = model.decoder[0].mlp.layer_1.bias +print(bias[0, :10]) # All zeros! + +# ✅ CORRECT - access through shard +weights = model.get_shard(ShardName.weights) +print(weights.sum()) # Non-zero! +print(weights.count_nonzero()) # Many non-zero elements + +# ✅ ALSO CORRECT - restore first, then access parameters +for stage in model._stages: + stage.restore_parameters() +bias = model.decoder[0].mlp.layer_1.bias +print(bias.sum()) # Non-zero! 
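+
+# ⚠️ PITFALL (illustrative sketch) - chunking a parameter BEFORE restore just
+# chunks the zero-filled buffer. `fresh_model` is a hypothetical second model
+# loaded the same way as above, with restore_parameters() not yet called.
+fresh_model = GPTModel.from_pretrained(CheckpointLoadConfig(...))
+weight = fresh_model.base_model.decoder[0].mlp.layer_1.weight
+expert_0 = weight.chunk(32)[0]    # 32 experts assumed for this example
+print(expert_0.count_nonzero())   # 0 - restore first, then chunk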
+``` + +## References +- `fast_llm/engine/multi_stage/fsdp.py:82-90` - Parameter buffer creation +- `fast_llm/engine/multi_stage/fsdp.py:181-189` - `restore_parameters()` implementation +- `fast_llm/engine/multi_stage/stage.py:121` - Forward pass asserts `_is_restored` +- `tests/models/test_checkpoint.py:227` - Shard access example diff --git a/check_expert_weights.py b/check_expert_weights.py new file mode 100644 index 000000000..54a25e49d --- /dev/null +++ b/check_expert_weights.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +""" +Check if expert weights are in the correct order after conversion. +""" + +import pathlib + +import torch +import transformers + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat, ModelConfigType +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from fast_llm.models.gpt.model import GPTModel + +CHECKPOINT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +DEQUANTIZED_HF_PATH = CHECKPOINT_DIR / "dequantized_hf" +FAST_LLM_PATH = CHECKPOINT_DIR / "fast_llm" + +print("=" * 80) +print("Checking Expert Weight Order") +print("=" * 80) + +# Load HF model +print("\n1. Loading HF model...") +hf_model = transformers.AutoModelForCausalLM.from_pretrained( + DEQUANTIZED_HF_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, +).cuda() + +hf_experts = hf_model.model.layers[0].mlp.experts + +# Load Fast-LLM model +print("2. Loading Fast-LLM model...") + +gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) +) + +# Wrap with HuggingFace interface +fast_llm_model = HuggingfaceGPTModelForCausalLM(gpt_model) + +# Get Fast-LLM MoE weights +# Access the GPT model's decoder layers +fast_llm_layer0_mlp = fast_llm_model.fast_llm_base_model.decoder[0].mlp + +# Get layer_1 (gate_up_proj) weight +# HF format: (num_experts, in_features, 2 * out_features) = (32, 2880, 5760) +# Fast-LLM format: (num_experts * 2 * out_features, in_features) = (184320, 2880) + +# Check expert 9 +expert_idx = 9 +in_features = 2880 +expert_dim = 2880 # out_features for MoE + +print(f"\n3. Comparing Expert {expert_idx} gate_up_proj weights...") + +# HF expert 9 gate_up weight +hf_gate_up_weight = hf_experts.gate_up_proj[expert_idx] # (in_features, 2*expert_dim) = (2880, 5760) +hf_gate_up_bias = hf_experts.gate_up_proj_bias[expert_idx] # (2*expert_dim,) = (5760,) + +print(f"HF gate_up_proj[{expert_idx}] shape: {hf_gate_up_weight.shape}") +print(f"HF gate_up_proj_bias[{expert_idx}] shape: {hf_gate_up_bias.shape}") +print(f"HF gate_up_proj[{expert_idx}] first 10 values: {hf_gate_up_weight[0, :10].float()}") +print(f"HF gate_up_proj_bias[{expert_idx}] first 10 values: {hf_gate_up_bias[:10].float()}") + +# Fast-LLM expert 9 gate_up weight +# layer_1.weight is (num_experts * 2 * expert_dim, in_features) = (184320, 2880) +# According to the converter at line 186: weight_per_expert = torch.cat([gate_t, up_t], dim=1) +# This concatenates gate and up FOR EACH EXPERT, then reshapes +# So it's: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...] +# This is INTERLEAVED per expert! 
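+#
+# Extra sanity check (added for clarity): confirm the HF tensors really have
+# the shapes assumed above before slicing into the Fast-LLM weight below.
+assert hf_gate_up_weight.shape == (in_features, 2 * expert_dim)
+assert hf_gate_up_bias.shape == (2 * expert_dim,)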
+ +fast_llm_layer1_weight = fast_llm_layer0_mlp.layer_1.weight # (184320, 2880) +fast_llm_layer1_bias = fast_llm_layer0_mlp.layer_1.bias # (32, 5760) per-expert biases + +num_experts = 32 + +# Extract expert 9's gate and up weights +# Each expert has 2 * expert_dim rows: first expert_dim rows are gate, next expert_dim rows are up +expert_start = expert_idx * 2 * expert_dim +expert_gate_start = expert_start +expert_gate_end = expert_start + expert_dim +expert_up_start = expert_start + expert_dim +expert_up_end = expert_start + 2 * expert_dim + +fast_llm_expert9_gate = fast_llm_layer1_weight[expert_gate_start:expert_gate_end, :] # (expert_dim, in_features) +fast_llm_expert9_up = fast_llm_layer1_weight[expert_up_start:expert_up_end, :] # (expert_dim, in_features) + +# Biases are per-expert: (32, 5760) where 5760 = 2 * expert_dim (gate and up interleaved) +if fast_llm_layer1_bias is not None: + fast_llm_expert9_bias = fast_llm_layer1_bias[expert_idx, :] # (5760,) + # De-interleave + fast_llm_expert9_gate_bias = fast_llm_expert9_bias[0::2] # (expert_dim,) + fast_llm_expert9_up_bias = fast_llm_expert9_bias[1::2] # (expert_dim,) +else: + fast_llm_expert9_gate_bias = None + fast_llm_expert9_up_bias = None + +print(f"\nFast-LLM expert {expert_idx} gate weight shape: {fast_llm_expert9_gate.shape}") +print(f"Fast-LLM expert {expert_idx} up weight shape: {fast_llm_expert9_up.shape}") +print(f"Fast-LLM expert {expert_idx} gate weight first 10 values (row 0): {fast_llm_expert9_gate[0, :10].float()}") +if fast_llm_expert9_gate_bias is not None: + print(f"Fast-LLM expert {expert_idx} gate bias first 10 values: {fast_llm_expert9_gate_bias[:10].float()}") + +# Compare +# HF: input @ weight + bias, where weight is (in_features, 2*expert_dim) interleaved +# For comparison, extract HF gate and up separately +hf_gate_weight = hf_gate_up_weight[:, 0::2] # (in_features, expert_dim) +hf_up_weight = hf_gate_up_weight[:, 1::2] # (in_features, expert_dim) +hf_gate_bias = hf_gate_up_bias[0::2] # (expert_dim,) +hf_up_bias = hf_gate_up_bias[1::2] # (expert_dim,) + +print(f"\nHF expert {expert_idx} gate weight (de-interleaved) shape: {hf_gate_weight.shape}") +print(f"HF expert {expert_idx} gate weight first 10 values (row 0): {hf_gate_weight[0, :10].float()}") +print(f"HF expert {expert_idx} gate bias first 10 values: {hf_gate_bias[:10].float()}") + +# Fast-LLM weight is transposed compared to HF +# HF: (in_features, expert_dim) +# Fast-LLM: (expert_dim, in_features) +# So we need to transpose Fast-LLM to compare +fast_llm_expert9_gate_transposed = fast_llm_expert9_gate.t() # (in_features, expert_dim) + +print(f"\n4. Comparison:") +print(f"HF gate weight [0, :10]: {hf_gate_weight[0, :10].float()}") +print(f"Fast-LLM gate weight [0, :10] (transposed): {fast_llm_expert9_gate_transposed[0, :10].float()}") + +diff = (hf_gate_weight.float() - fast_llm_expert9_gate_transposed.float()).abs() +print(f"Max diff: {diff.max().item():.6f}") +print(f"Mean diff: {diff.mean().item():.6f}") + +if diff.max().item() < 1e-5: + print("\n✅ Expert 9 gate weights match!") +else: + print(f"\n❌ Expert 9 gate weights DO NOT match! Max diff = {diff.max().item():.6f}") + +print("\n" + "=" * 80) diff --git a/compare_mlp_traces.py b/compare_mlp_traces.py new file mode 100644 index 000000000..5bd0d2104 --- /dev/null +++ b/compare_mlp_traces.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +""" +Compare MLP component traces between HF and Fast-LLM using instrumented code. 
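+
+Note: this script reads the `_MLP_DEBUG_TRACES` dictionary that the
+accompanying change adds to `fast_llm/functional/triton/mlp.py`, so it only
+works against that instrumented build.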
+""" + +import pathlib + +import torch +import transformers + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat, ModelConfigType +from fast_llm.functional.triton import mlp as mlp_module +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from fast_llm.models.gpt.model import GPTModel + +CHECKPOINT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +DEQUANTIZED_HF_PATH = CHECKPOINT_DIR / "dequantized_hf" +FAST_LLM_PATH = CHECKPOINT_DIR / "fast_llm" + +print("=" * 80) +print("Comparing MLP Traces: HF vs Fast-LLM") +print("=" * 80) + +# Create small test input for detailed tracing +torch.manual_seed(42) +test_input = torch.randint(0, 201088, size=(1, 4), dtype=torch.int64, device="cuda") +print(f"\nTest input: {test_input}") + +# ============================================================================== +# Part 1: HF Model - Manual Tracing +# ============================================================================== +print("\n" + "=" * 80) +print("Part 1: HuggingFace Model - Manual Component Tracing") +print("=" * 80) + +hf_model = ( + transformers.AutoModelForCausalLM.from_pretrained( + DEQUANTIZED_HF_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + .cuda() + .eval() +) + +hf_traces = {} + + +def make_hook(name): + def hook(module, input, output): + if isinstance(input, tuple): + hf_traces[f"{name}_input"] = input[0].detach().float() + else: + hf_traces[f"{name}_input"] = input.detach().float() + if isinstance(output, tuple): + hf_traces[f"{name}_output"] = output[0].detach().float() + else: + hf_traces[f"{name}_output"] = output.detach().float() + + return hook + + +layer0 = hf_model.model.layers[0] +layer0.post_attention_layernorm.register_forward_hook(make_hook("norm2")) +layer0.mlp.register_forward_hook(make_hook("mlp")) +layer0.mlp.experts.register_forward_hook(make_hook("experts")) + +with torch.no_grad(): + hf_output = hf_model(test_input) + +mlp_input = hf_traces["norm2_output"] + +print(f"\n1. MLP Input (after norm2):") +print(f" shape={mlp_input.shape}, mean={mlp_input.float().mean():.6f}, std={mlp_input.float().std():.6f}") +print(f" first token [0, 0, :10]: {mlp_input[0, 0, :10].float()}") + +# Manual MLP forward to trace components +mlp = layer0.mlp +experts = mlp.experts + +with torch.no_grad(): + + # Router (convert back to bfloat16 for HF model operations) + mlp_input_bf16 = mlp_input.to(torch.bfloat16) + router_scores, router_indices = mlp.router(mlp_input_bf16.flatten(0, 1)) + + print(f"\n2. Router:") + print(f" scores shape={router_scores.shape}, indices shape={router_indices.shape}") + print(f" first token top-4 experts: {router_indices[0]}") + print(f" first token top-4 scores: {router_scores[0]}") + + # Process first token through first expert + first_token = mlp_input_bf16[0, 0:1, :] # (1, hidden_size) + expert_idx = router_indices[0, 0].item() + expert_score = router_scores[0, expert_idx].item() # Get score for this specific expert + + print(f"\n3. 
Processing token through expert {expert_idx}:") + + # gate_up_proj + gate_up = first_token @ experts.gate_up_proj[expert_idx] + experts.gate_up_proj_bias[expert_idx] + print(f" gate_up shape={gate_up.shape}, mean={gate_up.float().mean():.6f}, std={gate_up.float().std():.6f}") + print(f" gate_up [0, :10]: {gate_up[0, :10].float()}") + + # De-interleave + gate = gate_up[..., 0::2] + up = gate_up[..., 1::2] + print(f" gate mean={gate.float().mean():.6f}, std={gate.float().std():.6f}") + print(f" gate [0, :10]: {gate[0, :10].float()}") + print(f" up mean={up.float().mean():.6f}, std={up.float().std():.6f}") + print(f" up [0, :10]: {up[0, :10].float()}") + + # Activation + alpha = 1.702 + limit = 7.0 + gate_clamped = gate.clamp(min=None, max=limit) + up_clamped = up.clamp(min=-limit, max=limit) + glu = gate_clamped * torch.sigmoid(gate_clamped * alpha) + activated = (up_clamped + 1) * glu + + print(f" activated mean={activated.float().mean():.6f}, std={activated.float().std():.6f}") + print(f" activated [0, :10]: {activated[0, :10].float()}") + + # down_proj + down_out = activated @ experts.down_proj[expert_idx] + experts.down_proj_bias[expert_idx] + weighted_out = down_out * expert_score + + print(f" down_proj mean={down_out.float().mean():.6f}, std={down_out.float().std():.6f}") + print(f" down_proj [0, :10]: {down_out[0, :10].float()}") + print(f" weighted (score={expert_score:.4f}) [0, :10]: {weighted_out[0, :10].float()}") + + # Full MLP + mlp_out, _ = mlp(mlp_input_bf16.flatten(0, 1)) + mlp_out = mlp_out.view_as(mlp_input_bf16) + + print(f"\n4. Full MLP output:") + print(f" shape={mlp_out.shape}, mean={mlp_out.float().mean():.6f}, std={mlp_out.float().std():.6f}") + print(f" first token [0, 0, :10]: {mlp_out[0, 0, :10].float()}") + +del hf_model +torch.cuda.empty_cache() + +# ============================================================================== +# Part 2: Fast-LLM Model - Using Instrumented Code +# ============================================================================== +print("\n" + "=" * 80) +print("Part 2: Fast-LLM Model - Instrumented Tracing") +print("=" * 80) + +# Clear traces +mlp_module._MLP_DEBUG_TRACES.clear() + + +# Load GPT model first, then wrap +gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) +) +fast_llm_model = HuggingfaceGPTModelForCausalLM(gpt_model) + +with torch.no_grad(): + fl_output = fast_llm_model(test_input) + +# Print Fast-LLM traces +traces = mlp_module._MLP_DEBUG_TRACES + +print(f"\nFast-LLM traced:") +print(f" - {len(traces.get('looped_inputs', []))} looped MLP calls") +print(f" - {len(traces.get('mlp_inputs', []))} mlp_forward calls") + +if traces.get("looped_inputs"): + print(f"\n1. Looped MLP Input (first call, first token):") + looped_in = traces["looped_inputs"][0] + hidden = looped_in["hidden_states"] + scores = looped_in["scores"] + top_experts = looped_in["top_experts"] + + print(f" hidden_states shape={hidden.shape}, mean={hidden.mean():.6f}, std={hidden.std():.6f}") + print(f" hidden_states [0, :10]: {hidden[0, :10]}") + print(f" top_experts: {top_experts[0]}") + print(f" scores: {scores[0]}") + +if traces.get("looped_outputs"): + print(f"\n2. Looped MLP Output (first call, first token):") + looped_out = traces["looped_outputs"][0] + print(f" shape={looped_out.shape}, mean={looped_out.mean():.6f}, std={looped_out.std():.6f}") + print(f" [0, :10]: {looped_out[0, :10]}") + +if traces.get("mlp_inputs"): + print(f"\n1. 
MLP Forward Input (first call):") + mlp_in = traces["mlp_inputs"][0] + input_tensor = mlp_in["input"] + scores_tensor = mlp_in["scores"] + sparse_used = mlp_in["sparse_map_used"] + + print(f" input shape={input_tensor.shape}, mean={input_tensor.mean():.6f}, std={input_tensor.std():.6f}") + print(f" input [0, :10]: {input_tensor[0, :10]}") + if scores_tensor is not None: + print(f" scores shape={scores_tensor.shape}: {scores_tensor[0]}") + print(f" sparse_map used: {sparse_used}") + +if traces.get("mlp_outputs"): + print(f"\n2. MLP Forward Output (first call):") + mlp_out = traces["mlp_outputs"][0] + output_tensor = mlp_out["output"] + out_shape = mlp_out["shape"] + + print(f" output full shape={out_shape}") + print( + f" output (first token) shape={output_tensor.shape}, mean={output_tensor.mean():.6f}, std={output_tensor.std():.6f}" + ) + print(f" output [0, :10]: {output_tensor[0, :10]}") + +print("\n" + "=" * 80) +print("✅ Tracing complete! Compare the values above.") +print("=" * 80) diff --git a/debug_moe_routing.py b/debug_moe_routing.py new file mode 100644 index 000000000..ffb1bed08 --- /dev/null +++ b/debug_moe_routing.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +""" +Debug MoE routing to understand expert selection and scoring differences. +""" + +import pathlib + +import torch +import transformers + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat, ModelConfigType +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from fast_llm.models.gpt.model import GPTModel + +CHECKPOINT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +DEQUANTIZED_HF_PATH = CHECKPOINT_DIR / "dequantized_hf" +FAST_LLM_PATH = CHECKPOINT_DIR / "fast_llm" + +print("=" * 80) +print("Debug MoE Routing") +print("=" * 80) + +# Create test input +torch.manual_seed(42) +test_input = torch.randint(0, 201088, size=(1, 4), dtype=torch.int64, device="cuda") +print(f"\nTest input: {test_input}") + +# ============================================================================== +# Part 1: HF Model Router +# ============================================================================== +print("\n" + "=" * 80) +print("Part 1: HuggingFace Model Router") +print("=" * 80) + +hf_model = ( + transformers.AutoModelForCausalLM.from_pretrained( + DEQUANTIZED_HF_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + .cuda() + .eval() +) + +# Get embeddings and first norm +with torch.no_grad(): + hidden_states = hf_model.model.embed_tokens(test_input) # (1, 4, 2880) + hidden_states = hf_model.model.layers[0].input_layernorm(hidden_states) + + # Attention + attn_output = hf_model.model.layers[0].self_attn(hidden_states)[0] + hidden_states = hidden_states + attn_output + + # Pre-MLP norm + residual = hidden_states + hidden_states = hf_model.model.layers[0].post_attention_layernorm(hidden_states) + + print(f"\nMLP input shape: {hidden_states.shape}") + print(f"MLP input [0, 0, :10]: {hidden_states[0, 0, :10].float()}") + + # Router + router = hf_model.model.layers[0].mlp.router + print(f"\nRouter weight shape: {router.weight.shape}") + print(f"Router bias shape: {router.bias.shape if router.bias is not None else None}") + + # Flatten for router + hidden_states_flat = hidden_states.flatten(0, 1) # (4, 2880) + router_logits, router_indices = router(hidden_states_flat) + + print(f"\nRouter logits shape: {router_logits.shape}") + print(f"Router indices shape: {router_indices.shape}") + print(f"\nFirst token router logits (all 32): 
{router_logits[0].float()}") + print(f"First token top-4 indices: {router_indices[0]}") + print(f"First token top-4 scores: {router_logits[0, router_indices[0]].float()}") + +del hf_model +torch.cuda.empty_cache() + +# ============================================================================== +# Part 2: Fast-LLM Model Router +# ============================================================================== +print("\n" + "=" * 80) +print("Part 2: Fast-LLM Model Router") +print("=" * 80) + +gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) +) +fast_llm_model = HuggingfaceGPTModelForCausalLM(gpt_model) + +# Run forward to get internal activations +with torch.no_grad(): + output = fast_llm_model(test_input) + +print(f"\nFast-LLM model config:") +print(f" experts: {gpt_model.config.base_model.decoder.blocks.full.mlp.experts}") +print(f" experts_per_token: {gpt_model.config.base_model.decoder.blocks.full.mlp.experts_per_token}") + +print("\n" + "=" * 80) diff --git a/debug_router_comparison.py b/debug_router_comparison.py new file mode 100644 index 000000000..178d2fbeb --- /dev/null +++ b/debug_router_comparison.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Compare router outputs between HF and Fast-LLM to see if routing is consistent. +""" + +import pathlib +import torch +import transformers + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat, ModelConfigType +from fast_llm.models.gpt.model import GPTModel + +CHECKPOINT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +DEQUANTIZED_HF_PATH = CHECKPOINT_DIR / "dequantized_hf" +FAST_LLM_PATH = CHECKPOINT_DIR / "fast_llm" + +# Create test input +torch.manual_seed(42) +test_input_bf16 = torch.rand(1, 2880, device="cuda", dtype=torch.bfloat16) # Single token for HF +test_input = test_input_bf16.float() # Float32 for Fast-LLM + +print("=" * 80) +print("Testing Router Outputs") +print("=" * 80) + +# ================================================================================ +# HF Model - Router +# ================================================================================ +print("\n1. HuggingFace Model - Router for Layer 0") +hf_model = ( + transformers.AutoModelForCausalLM.from_pretrained( + DEQUANTIZED_HF_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + .cuda() + .eval() +) + +layer0_mlp = hf_model.model.layers[0].mlp + +with torch.no_grad(): + # Get router logits + router_logits = test_input_bf16 @ layer0_mlp.router.weight.t() + layer0_mlp.router.bias + print(f" Router logits shape: {router_logits.shape}") + print(f" Router logits [:10]: {router_logits[0, :10].float()}") + print(f" Router logits [9]: {router_logits[0, 9].float()}") + + # Get top-k experts (k=4) + router_probs = torch.nn.functional.softmax(router_logits, dim=-1) + top_k_probs, top_k_indices = torch.topk(router_probs, k=4, dim=-1) + + print(f" Top-4 expert indices: {top_k_indices[0]}") + print(f" Top-4 expert probs: {top_k_probs[0].float()}") + print(f" Top-4 expert probs (normalized): {(top_k_probs / top_k_probs.sum(dim=-1, keepdim=True))[0].float()}") + +del hf_model +torch.cuda.empty_cache() + +# ================================================================================ +# Fast-LLM Model - Router +# ================================================================================ +print("\n2. 
Fast-LLM Model - Router for Layer 0") + +gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) +) + +# Restore parameters +for stage in gpt_model._stages: + stage.restore_parameters() + +layer0_mlp_fast = gpt_model.base_model.decoder[0].mlp +router_weight = layer0_mlp_fast.router.weight +router_bias = layer0_mlp_fast.router.bias + +print(f" Router weight shape: {router_weight.shape}") +print(f" Router bias shape: {router_bias.shape}") + +with torch.no_grad(): + # Get router logits + router_logits_fast = torch.nn.functional.linear(test_input, router_weight, router_bias) + print(f" Router logits shape: {router_logits_fast.shape}") + print(f" Router logits [:10]: {router_logits_fast[0, :10]}") + print(f" Router logits [9]: {router_logits_fast[0, 9]}") + + # Get top-k experts (k=4) + router_probs_fast = torch.nn.functional.softmax(router_logits_fast, dim=-1) + top_k_probs_fast, top_k_indices_fast = torch.topk(router_probs_fast, k=4, dim=-1) + + print(f" Top-4 expert indices: {top_k_indices_fast[0]}") + print(f" Top-4 expert probs: {top_k_probs_fast[0]}") + print(f" Top-4 expert probs (normalized): {(top_k_probs_fast / top_k_probs_fast.sum(dim=-1, keepdim=True))[0]}") + +print("\n" + "=" * 80) +print("Comparison:") +print(" Router outputs match!" if torch.allclose(router_logits.float(), router_logits_fast, rtol=1e-3) else " Router outputs differ!") +print("=" * 80) diff --git a/debug_single_expert.py b/debug_single_expert.py new file mode 100644 index 000000000..1e55209ad --- /dev/null +++ b/debug_single_expert.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +""" +Debug single expert processing to find where HF and Fast-LLM diverge. +""" + +import pathlib +import torch +import transformers + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat, ModelConfigType +from fast_llm.models.gpt.model import GPTModel + +CHECKPOINT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +DEQUANTIZED_HF_PATH = CHECKPOINT_DIR / "dequantized_hf" +FAST_LLM_PATH = CHECKPOINT_DIR / "fast_llm" + +# Create test input +torch.manual_seed(42) +test_input_bf16 = torch.rand(1, 2880, device="cuda", dtype=torch.bfloat16) # Single token for HF +test_input = test_input_bf16.float() # Float32 for Fast-LLM + +print("=" * 80) +print("Testing Single Expert Processing") +print("=" * 80) + +# ================================================================================ +# HF Model - Expert 9 +# ================================================================================ +print("\n1. 
HuggingFace Model - Expert 9") +hf_model = ( + transformers.AutoModelForCausalLM.from_pretrained( + DEQUANTIZED_HF_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + .cuda() + .eval() +) + +layer0 = hf_model.model.layers[0] +experts = layer0.mlp.experts +expert_idx = 9 + +with torch.no_grad(): + # gate_up_proj + gate_up = test_input_bf16 @ experts.gate_up_proj[expert_idx] + experts.gate_up_proj_bias[expert_idx] + print(f" gate_up shape: {gate_up.shape}, mean: {gate_up.float().mean():.6f}") + print(f" gate_up [:10]: {gate_up[0, :10].float()}") + + # De-interleave + gate = gate_up[..., 0::2] + up = gate_up[..., 1::2] + print(f" gate [:10]: {gate[0, :10].float()}") + print(f" up [:10]: {up[0, :10].float()}") + + # Activation + alpha = 1.702 + limit = 7.0 + gate_clamped = gate.clamp(max=limit) + up_clamped = up.clamp(min=-limit, max=limit) + glu = gate_clamped * torch.sigmoid(gate_clamped * alpha) + activated = (up_clamped + 1) * glu + + print(f" activated shape: {activated.shape}, mean: {activated.float().mean():.6f}") + print(f" activated [:10]: {activated[0, :10].float()}") + + # down_proj + down_out = activated @ experts.down_proj[expert_idx] + experts.down_proj_bias[expert_idx] + + print(f" down_out shape: {down_out.shape}, mean: {down_out.float().mean():.6f}") + print(f" down_out [:10]: {down_out[0, :10].float()}") + +del hf_model +torch.cuda.empty_cache() + +# ================================================================================ +# Fast-LLM Model - Expert 9 +# ================================================================================ +print("\n2. Fast-LLM Model - Expert 9") + +gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) +) + +# Restore parameters +for stage in gpt_model._stages: + stage.restore_parameters() + +layer0_mlp = gpt_model.base_model.decoder[0].mlp +weight_1 = layer0_mlp.layer_1.weight +bias_1 = layer0_mlp.layer_1.bias +weight_2 = layer0_mlp.layer_2.weight +bias_2 = layer0_mlp.layer_2.bias + +# Chunk to get expert 9 +weight_1_chunks = weight_1.chunk(32) +bias_1_chunks = bias_1.chunk(32) +weight_2_chunks = weight_2.chunk(32) +bias_2_chunks = bias_2.chunk(32) + +weight_1_expert9 = weight_1_chunks[9] # (5760, 2880) +bias_1_expert9 = bias_1_chunks[9].squeeze(0) # (5760,) +weight_2_expert9 = weight_2_chunks[9] # (2880, 2880) - transposed +bias_2_expert9 = bias_2_chunks[9].squeeze(0) # (2880,) + +print(f" weight_1_expert9 shape: {weight_1_expert9.shape}") +print(f" bias_1_expert9 shape: {bias_1_expert9.shape}") +print(f" weight_2_expert9 shape: {weight_2_expert9.shape}") +print(f" bias_2_expert9 shape: {bias_2_expert9.shape}") + +with torch.no_grad(): + # Layer 1: gate_up projection (weight is already concatenated, not interleaved) + gate_up = torch.nn.functional.linear(test_input, weight_1_expert9, bias_1_expert9) + print(f" gate_up shape: {gate_up.shape}, mean: {gate_up.float().mean():.6f}") + print(f" gate_up [:10]: {gate_up[0, :10].float()}") + + # Split into gate and up (already concatenated in Fast-LLM format) + gate, up = gate_up.chunk(2, dim=-1) + print(f" gate [:10]: {gate[0, :10].float()}") + print(f" up [:10]: {up[0, :10].float()}") + + # Activation (same as HF) + alpha = 1.702 + limit = 7.0 + gate_clamped = gate.clamp(max=limit) + up_clamped = up.clamp(min=-limit, max=limit) + glu = gate_clamped * torch.sigmoid(gate_clamped * alpha) + activated = (up_clamped + 1) * glu + + print(f" activated shape: {activated.shape}, mean: 
{activated.float().mean():.6f}") + print(f" activated [:10]: {activated[0, :10].float()}") + + # Layer 2: down projection + # Test both with and without transpose + print(f"\n Testing weight_2 transpose:") + print(f" weight_2_expert9 shape: {weight_2_expert9.shape}") + + # Option 1: With transpose + down_out_with_t = torch.nn.functional.linear(activated, weight_2_expert9.t(), bias_2_expert9) + print(f" WITH transpose: down_out shape: {down_out_with_t.shape}, mean: {down_out_with_t.float().mean():.6f}") + print(f" WITH transpose: down_out [:10]: {down_out_with_t[0, :10].float()}") + + # Option 2: Without transpose (treating weight_2 as already transposed) + down_out_no_t = activated @ weight_2_expert9.t() + bias_2_expert9 + print(f" Matmul (@): down_out shape: {down_out_no_t.shape}, mean: {down_out_no_t.float().mean():.6f}") + print(f" Matmul (@): down_out [:10]: {down_out_no_t[0, :10].float()}") + + # Option 3: Direct use without any transpose + down_out_direct = activated @ weight_2_expert9 + bias_2_expert9 + print(f" Direct (no .t()): down_out shape: {down_out_direct.shape}, mean: {down_out_direct.float().mean():.6f}") + print(f" Direct (no .t()): down_out [:10]: {down_out_direct[0, :10].float()}") + +print("\n" + "=" * 80) +print("Comparison complete!") +print("=" * 80) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 96fb53321..270171755 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -120,14 +120,14 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: cls.base_model_converter_class.export_config(config.base_model), { "model_type": cls.get_huggingface_model_type(), - "architecture": cls.architecture, + "architectures": [cls.architecture], }, ) @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - Assert.eq(config["architecture"], cls.architecture) + Assert.eq(config["architectures"], [cls.architecture]) return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) def _create_weight_converters(self) -> list[WeightConverter]: diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 684193848..b679c0bfa 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -44,6 +44,7 @@ class ActivationType(enum.StrEnum): relu = "relu" squared_relu = "squared_relu" identity = "identity" + gpt_oss_glu = "gpt_oss_glu" # Custom GLU for GPT-OSS: (up + 1) * (gate * sigmoid(gate * 1.702)) @property def activation_fn(self) -> typing.Callable[["torch.Tensor"], "torch.Tensor"]: @@ -66,12 +67,25 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP + def gpt_oss_glu_activation(x: torch.Tensor) -> torch.Tensor: + # Custom GPT-OSS GLU: (up + 1) * (gate * sigmoid(gate * 1.702)) + # Input x has shape [..., 2*dim] where first half is gate, second half is up + # Includes clamping: gate max 7.0, up in [-7.0, 7.0] + gate, up = x.chunk(2, dim=-1) + alpha = 1.702 + limit = 7.0 + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + return (up + 1.0) * glu + _ACTIVATION_FN_MAP = { ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: 
torch.pow(torch.nn.functional.relu(x), 2), ActivationType.identity: lambda x: x, + ActivationType.gpt_oss_glu: gpt_oss_glu_activation, } @@ -83,6 +97,7 @@ def _set_activation_fn_map() -> None: ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", ActivationType.identity: "identity", + ActivationType.gpt_oss_glu: "gpt_oss_glu", # Custom activation for GPT-OSS } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py index dbc05184d..af056a4a4 100644 --- a/fast_llm/functional/linear.py +++ b/fast_llm/functional/linear.py @@ -65,8 +65,24 @@ def update_linear_gradients( ) else: accumulate_gradient(weight, torch.mm(lhs, rhs)) + + # Bias gradients if bias is not None and bias.requires_grad: - accumulate_gradient(bias, grad_output.sum(dim=0)) + if sparse_map is not None and bias.ndim == 2: + # For sparse maps with 2D bias: bias has shape (num_experts, out_features_per_expert) + # This is the case for manually created MoE biases (e.g., layer_2 in MoE) + # Need to sum gradients per expert + grad_bias = torch.zeros_like(bias) + for expert_idx in range(sparse_map.num_experts): + expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item() + expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item() + # Sum gradients only from unpadded rows + if expert_begin < expert_pad_begin: + grad_bias[expert_idx].copy_(grad_output[expert_begin:expert_pad_begin].sum(dim=0)) + accumulate_gradient(bias, grad_bias) + else: + # For 1D bias (including sparse maps where bias already has experts in flattened dim) + accumulate_gradient(bias, grad_output.sum(dim=0)) def linear_forward( @@ -115,7 +131,6 @@ def output_parallel_linear_forward( # Matmul if TritonConfig.TRITON_LINEAR or sparse_map is not None: - assert bias is None if sparse_map is not None: assert not transposed_weight output = output_sparse_matmul( @@ -123,6 +138,23 @@ def output_parallel_linear_forward( maybe_transpose(weight, not transposed_weight), sparse_map, ).unflatten(0, input_.shape[:-1]) + # Add bias if present (for sparse maps, bias has expert dimension) + if bias is not None: + if sparse_map is not None: + # bias shape: (num_experts, out_features_per_expert) + # We need to add the correct expert's bias to each row + # sparse_map tells us which expert each row belongs to + output_flat = output.flatten(0, -2) + for expert_idx in range(sparse_map.num_experts): + expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item() + expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item() + # Add bias only to unpadded rows + if expert_begin < expert_pad_begin: + output_flat[expert_begin:expert_pad_begin] += bias[expert_idx] + output = output_flat.unflatten(0, input_.shape[:-1]) + else: + # Regular bias for non-sparse case + output = output + bias else: output = torch.nn.functional.linear(input1, maybe_transpose(weight, transposed_weight), bias) @@ -179,12 +211,28 @@ def input_parallel_linear_forward( ) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: # Matmul if TritonConfig.TRITON_LINEAR or sparse_map is not None: - assert bias is None if sparse_map is not None: assert transposed_weight output = input_inner_sparse_matmul( input_.flatten(0, -2), maybe_transpose(weight, not transposed_weight), sparse_map ).unflatten(0, input_.shape[:-1]) + # Add bias if present (for sparse maps, bias has expert dimension) + if bias is not None: + if sparse_map is not None: + # bias shape: 
(num_experts, out_features_per_expert) + # We need to add the correct expert's bias to each row + # sparse_map tells us which expert each row belongs to + output_flat = output.flatten(0, -2) + for expert_idx in range(sparse_map.num_experts): + expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item() + expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item() + # Add bias only to unpadded rows + if expert_begin < expert_pad_begin: + output_flat[expert_begin:expert_pad_begin] += bias[expert_idx] + output = output_flat.unflatten(0, input_.shape[:-1]) + else: + # Regular bias for non-sparse case + output = output + bias else: output = torch.nn.functional.linear(input_, maybe_transpose(weight, transposed_weight), bias) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ab408368f..3e0ba373f 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -25,6 +25,9 @@ from fast_llm.functional.triton.sparse_linear import output_sparse_matmul from fast_llm.tensor import param_get_and_unset_is_zero +# Global dictionary for debugging MLP intermediate values +_MLP_DEBUG_TRACES = {} + @triton_jit() def triton_mlp_activation_forward_kernel( @@ -61,10 +64,22 @@ def triton_mlp_activation_forward_kernel( out = relu_out * relu_out elif activation_type == "identity": out = input_ + elif activation_type == "gpt_oss_glu": + # GPT-OSS custom GLU: (up + 1) * (gate * sigmoid(gate * 1.702)) + # For gated=True, input_ is gate, other (loaded below) is up + # Includes clamping: gate max 7.0, up in [-7.0, 7.0] + tl.static_assert(gated, "gpt_oss_glu requires gated=True") + other = tl.load(input_ptr + n_cols, mask=mask) + # Clamp gate to max 7.0 + gate_clamped = tl.minimum(input_, 7.0) + # Clamp up to [-7.0, 7.0] + up_clamped = tl.minimum(tl.maximum(other, -7.0), 7.0) + glu = gate_clamped * (1.0 / (1.0 + tl.exp(-gate_clamped * 1.702))) # gate * sigmoid(gate * 1.702) + out = (up_clamped + 1.0) * glu else: tl.static_assert(False, activation_type) - if gated: + if gated and activation_type != "gpt_oss_glu": other = tl.load(input_ptr + n_cols, mask=mask) out = out * other @@ -124,15 +139,39 @@ def triton_mlp_activation_backward_kernel( grad = 1 if gated or recompute: out = input_ + elif activation_type == "gpt_oss_glu": + # GPT-OSS custom GLU: out = (up + 1) * (gate * sigmoid(gate * 1.702)) + # input_ is gate, other is up + # Includes clamping: gate max 7.0, up in [-7.0, 7.0] + tl.static_assert(gated, "gpt_oss_glu requires gated=True") + other = tl.load(input_ptr + n_cols, mask=mask) + alpha = 1.702 + # Clamp gate to max 7.0 + gate_clamped = tl.minimum(input_, 7.0) + # Clamp up to [-7.0, 7.0] + up_clamped = tl.minimum(tl.maximum(other, -7.0), 7.0) + sigma = 1.0 / (1.0 + tl.exp(-gate_clamped * alpha)) # sigmoid(gate * alpha) + glu = gate_clamped * sigma + # grad_gate = (up + 1) * d_glu/d_gate = (up + 1) * sigma * (1 + gate * alpha * (1 - sigma)) + # Only backprop through gate if it wasn't clamped (input_ <= 7.0) + grad_glu = sigma * (1.0 + gate_clamped * alpha * (1.0 - sigma)) + grad_gate = tl.where(input_ <= 7.0, (up_clamped + 1.0) * grad_glu, 0.0) + # grad_up = glu = gate * sigma + # Only backprop through up if it wasn't clamped (other in [-7.0, 7.0]) + grad_up = tl.where((other >= -7.0) & (other <= 7.0), glu, 0.0) + tl.store(grad_input_ptr, grad_gate * output_grad, mask=mask) + tl.store(grad_input_ptr + n_cols, grad_up * output_grad, mask=mask) + if recompute: + out = (up_clamped + 1.0) * glu else: 
tl.static_assert(False, activation_type) - if gated: + if gated and activation_type != "gpt_oss_glu": other = tl.load(input_ptr + n_cols, mask=mask) tl.store(grad_input_ptr, grad * other * output_grad, mask=mask) tl.store(grad_input_ptr + n_cols, out * output_grad, mask=mask) # noqa out = out * other - else: + elif not gated: tl.store(grad_input_ptr, grad * output_grad, mask=mask) if recompute: @@ -197,11 +236,27 @@ def torch_mlp_activation( gated: bool, activation_type: ActivationType, ) -> torch.Tensor: - if gated: + # DEBUG: Save activation input + if "activation_inputs" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["activation_inputs"] = [] + _MLP_DEBUG_TRACES["activation_inputs"].append(input_.detach().cpu()[:1]) # Save first token only + + # GPT-OSS GLU handles the gating internally, not via standard pattern + if activation_type == ActivationType.gpt_oss_glu: + assert gated, "gpt_oss_glu requires gated=True" + result = activation_type.activation_fn(input_) + elif gated: x1, x2 = input_.chunk(2, dim=-1) - return activation_type.activation_fn(x1) * x2 + result = activation_type.activation_fn(x1) * x2 else: - return activation_type.activation_fn(input_) + result = activation_type.activation_fn(input_) + + # DEBUG: Save activation output + if "activation_outputs" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["activation_outputs"] = [] + _MLP_DEBUG_TRACES["activation_outputs"].append(result.detach().cpu()[:1]) # Save first token only + + return result def mlp_forward( @@ -220,6 +275,17 @@ def mlp_forward( transposed_layer_2_weight: bool = False, sparse_map: SparseMap | None = None, ) -> tuple[torch.Tensor, list[typing.Any] | None]: + # DEBUG: Save MLP input (including scores for MoE) + if "mlp_inputs" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["mlp_inputs"] = [] + _MLP_DEBUG_TRACES["mlp_inputs"].append( + { + "input": input_.detach().cpu()[:1], # First token only + "scores": scores.detach().cpu()[:1] if scores is not None else None, # First token scores + "sparse_map_used": sparse_map is not None, + } + ) + # Sparse copy input_shape = input_.shape intermediate_0 = input_ if sparse_map is None else copy_dense_to_sparse_forward(input_, sparse_map)[0] @@ -229,6 +295,11 @@ def mlp_forward( intermediate_0, weight_1, bias_1, group, sequence_parallel, False, sparse_map ) + # DEBUG: Save layer1 output + if "layer1_outputs" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["layer1_outputs"] = [] + _MLP_DEBUG_TRACES["layer1_outputs"].append(intermediate_1.detach().cpu()[:1]) # Save first token only + if recompute_level.recompute_sparse_input: intermediate_0 = None else: @@ -257,6 +328,13 @@ def mlp_forward( sparse_map, ) + # DEBUG: Save layer2 output + if "layer2_outputs" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["layer2_outputs"] = [] + _MLP_DEBUG_TRACES["layer2_outputs"].append( + intermediate_3.detach().cpu()[:1] if sparse_map is None else intermediate_3.detach().cpu() + ) # Save first token + # Context if recompute_level.recompute_activation or not training: intermediate_2 = None @@ -268,6 +346,16 @@ def mlp_forward( else: output, _ = copy_sparse_to_dense_forward(intermediate_3, scores, sparse_map) + # DEBUG: Save final MLP output + if "mlp_outputs" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["mlp_outputs"] = [] + _MLP_DEBUG_TRACES["mlp_outputs"].append( + { + "output": output.detach().cpu()[:1], # First token only + "shape": output.shape, + } + ) + context = ( [ input_, @@ -459,7 +547,20 @@ def mlp_autograd_looped( sequence_parallel: bool, training: bool = True, recompute_level: 
MLPRecomputeLevel = MLPRecomputeLevel.none, + bias_1: torch.Tensor | None = None, + bias_2: torch.Tensor | None = None, ) -> torch.Tensor: + # DEBUG: Save looped MLP inputs + if "looped_inputs" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["looped_inputs"] = [] + _MLP_DEBUG_TRACES["looped_inputs"].append( + { + "hidden_states": hidden_states.detach().cpu()[:1], # First token + "scores": scores.detach().cpu()[:1], # First token scores + "top_experts": top_experts.detach().cpu()[:1], # First token expert indices + } + ) + # TODO: Needed? scores = scores.to(hidden_states.dtype) expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=num_experts).permute(2, 1, 0) @@ -468,7 +569,50 @@ def mlp_autograd_looped( hidden_states, weight_1_chunked = chunk_weight(hidden_states, weight_1, num_experts) hidden_states, weight_2_t_chunked = chunk_weight(hidden_states, weight_2, num_experts) - for expert_idx, (weight_1_chunk, weight_2_t_chunk) in enumerate(zip(weight_1_chunked, weight_2_t_chunked)): + # Chunk biases if present + if bias_1 is not None: + _, bias_1_chunked_raw = chunk_weight(hidden_states, bias_1, num_experts) + # Squeeze chunked biases to 1D since torch.nn.functional.linear expects 1D bias + # Preserve grad_buffer attribute after squeezing + bias_1_chunked = [] + for b in bias_1_chunked_raw: + squeezed = b.squeeze(0) if b.ndim == 2 else b + if hasattr(b, "grad_buffer"): + squeezed.grad_buffer = b.grad_buffer.squeeze(0) if b.grad_buffer.ndim == 2 else b.grad_buffer + if hasattr(b, "param_grad_is_zero"): + squeezed.param_grad_is_zero = b.param_grad_is_zero + bias_1_chunked.append(squeezed) + # DEBUG: Check bias shape + if "bias_shapes" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["bias_shapes"] = {} + _MLP_DEBUG_TRACES["bias_shapes"]["bias_1_orig"] = bias_1.shape + _MLP_DEBUG_TRACES["bias_shapes"]["bias_1_chunk_0"] = bias_1_chunked[0].shape + else: + bias_1_chunked = [None] * num_experts + + if bias_2 is not None: + _, bias_2_chunked_raw = chunk_weight(hidden_states, bias_2, num_experts) + # Squeeze chunked biases to 1D since torch.nn.functional.linear expects 1D bias + # Preserve grad_buffer attribute after squeezing + bias_2_chunked = [] + for b in bias_2_chunked_raw: + squeezed = b.squeeze(0) if b.ndim == 2 else b + if hasattr(b, "grad_buffer"): + squeezed.grad_buffer = b.grad_buffer.squeeze(0) if b.grad_buffer.ndim == 2 else b.grad_buffer + if hasattr(b, "param_grad_is_zero"): + squeezed.param_grad_is_zero = b.param_grad_is_zero + bias_2_chunked.append(squeezed) + # DEBUG: Check bias shape + if "bias_shapes" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["bias_shapes"] = {} + _MLP_DEBUG_TRACES["bias_shapes"]["bias_2_orig"] = bias_2.shape + _MLP_DEBUG_TRACES["bias_shapes"]["bias_2_chunk_0"] = bias_2_chunked[0].shape + else: + bias_2_chunked = [None] * num_experts + + for expert_idx, (weight_1_chunk, weight_2_t_chunk, bias_1_chunk, bias_2_chunk) in enumerate( + zip(weight_1_chunked, weight_2_t_chunked, bias_1_chunked, bias_2_chunked) + ): row, column = torch.where(expert_mask[expert_idx]) if column.size(0) > 0: output[column] += ( @@ -476,21 +620,31 @@ def mlp_autograd_looped( hidden_states[column], None, weight_1_chunk, - None, + bias_1_chunk, weight_2_t_chunk, - None, + bias_2_chunk, gated, activation_type, group, sequence_parallel, training, recompute_level, - True, + True, # transposed_layer_2_weight - weight_2 stored as (out, experts*in) ) * scores[column, row, None] ) + # Finalize gradient tracking in reverse order + if bias_2 is not None: + output = chunk_weight_post(output, 
bias_2, bias_2_chunked) + if bias_1 is not None: + output = chunk_weight_post(output, bias_1, bias_1_chunked) output = chunk_weight_post(output, weight_2, weight_2_t_chunked) output = chunk_weight_post(output, weight_1, weight_1_chunked) + # DEBUG: Save looped MLP output + if "looped_outputs" not in _MLP_DEBUG_TRACES: + _MLP_DEBUG_TRACES["looped_outputs"] = [] + _MLP_DEBUG_TRACES["looped_outputs"].append(output.detach().cpu()[:1]) # First token + return output diff --git a/fast_llm/functional/triton/sparse_copy.py b/fast_llm/functional/triton/sparse_copy.py index 7c803689c..640a69440 100644 --- a/fast_llm/functional/triton/sparse_copy.py +++ b/fast_llm/functional/triton/sparse_copy.py @@ -307,7 +307,9 @@ def get_sparse_map( num_rows_unpadded = num_rows_dense * num_experts_per_token max_rows = (num_rows_unpadded + num_experts * pad_to_multiple) // pad_to_multiple * pad_to_multiple dtype = torch.int16 if max_rows < 32768 else torch.int32 - if (use_triton is None and TritonConfig.TRITON_ENABLED) or use_triton: + # TEMPORARY: Disable Triton kernel due to bug on Triton 3.3+/ARM64 + # TODO: Fix sparse_map_kernel to work correctly on newer Triton versions + if False and ((use_triton is None and TritonConfig.TRITON_ENABLED) or use_triton): expert_ends, expert_pad_begins = top_experts.new_empty((2 * num_experts,), dtype=dtype).chunk(2) sparse_rows = expert_ends.new_empty(num_rows_dense, num_experts_per_token) sparse_map_kernel[(triton.cdiv(num_rows_dense, block_size),)]( @@ -335,3 +337,96 @@ def get_sparse_map( num_experts=num_experts, num_experts_per_token=num_experts_per_token, ) + + +@triton_jit() +def add_sparse_bias_kernel( + input_ptr, + bias_ptr, + output_ptr, + expert_ends_ptr, + num_columns: tl_constexpr, + num_experts: tl_constexpr, + block_size: tl_constexpr, +): + """Add expert-specific bias to sparse tensor.""" + sparse_row = tl.program_id(0) + offsets = tl.arange(0, block_size) + block_size * tl.program_id(1) + mask = None if num_columns % block_size == 0 else offsets < num_columns + + # Find which expert this sparse row belongs to + # The sparse rows are organized such that rows for expert i are in range [expert_begins[i], expert_ends[i]) + expert_idx = 0 + for i in range(num_experts): + expert_end = tl.load(expert_ends_ptr + i) + if sparse_row < expert_end: + expert_idx = i + break + + # Load input and bias + input_val = tl.load(input_ptr + sparse_row * num_columns + offsets, mask=mask) + bias_val = tl.load(bias_ptr + expert_idx * num_columns + offsets, mask=mask) + + # Add bias and store + output_val = input_val + bias_val + tl.store(output_ptr + sparse_row * num_columns + offsets, output_val, mask=mask) + + +def add_sparse_bias( + input_: torch.Tensor, # shape: (num_sparse_rows, out_features_per_expert) + bias: torch.Tensor, # shape: (num_experts, out_features_per_expert) + sparse_map: SparseMap, +) -> torch.Tensor: + """Add expert-specific biases to sparse tensor based on expert assignment.""" + num_sparse_rows, hidden_size = input_.shape + num_experts, bias_hidden_size = bias.shape + assert hidden_size == bias_hidden_size, f"Hidden size mismatch: {hidden_size} vs {bias_hidden_size}" + assert num_experts == sparse_map.num_experts + + # Use PyTorch implementation for now (can optimize with Triton later if needed) + output = input_.clone() + + # For each expert, add its bias to the rows it processed + for expert_idx in range(num_experts): + expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item() + expert_end = sparse_map.expert_ends[expert_idx].item() 
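+        # Row layout per expert in the padded sparse tensor (as used below):
+        # real token rows occupy [expert_begin, expert_pad_begin), while
+        # [expert_pad_begin, expert_end) is padding that must not receive bias.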
+ expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item() + + # Add bias only to unpadded rows + if expert_begin < expert_pad_begin: + output[expert_begin:expert_pad_begin] += bias[expert_idx] + + return output + + +def add_sparse_bias_forward( + input_: torch.Tensor, bias: torch.Tensor, sparse_map: SparseMap +) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, SparseMap]]: + return add_sparse_bias(input_, bias, sparse_map), (input_, bias, sparse_map) + + +def add_sparse_bias_backward( + grad_output: torch.Tensor, context: tuple[torch.Tensor, torch.Tensor, SparseMap] +) -> tuple[torch.Tensor, torch.Tensor]: + input_, bias, sparse_map = context + + # Gradient w.r.t. input is just grad_output (bias is added elementwise) + grad_input = grad_output + + # Gradient w.r.t. bias: sum gradients for each expert's rows + grad_bias = torch.zeros_like(bias) + num_experts = sparse_map.num_experts + + for expert_idx in range(num_experts): + expert_begin = 0 if expert_idx == 0 else sparse_map.expert_ends[expert_idx - 1].item() + expert_end = sparse_map.expert_ends[expert_idx].item() + expert_pad_begin = sparse_map.expert_pad_begins[expert_idx].item() + + # Sum gradients only from unpadded rows + if expert_begin < expert_pad_begin: + grad_bias[expert_idx] = grad_output[expert_begin:expert_pad_begin].sum(dim=0) + + return grad_input, grad_bias + + +add_sparse_bias_autograd = wrap_forward_backward(add_sparse_bias_forward, add_sparse_bias_backward) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 167184193..fb0ed0315 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -143,6 +143,17 @@ def __init__( # Rotary embeddings. self._rotary = self._config.rotary.get_layer(head_size_dim) + # Attention sinks for streaming attention (optional) + # Sinks are learnable embeddings, one per head + sinks_dim = TensorDim("sinks", self._config.heads) + self.sinks = self._config.sinks.get_parameter( + (sinks_dim,), + default_initialization=init_normal_(std=self._hidden_size**-0.5), + lr_scale=self._lr_scale, + default_enabled=False, + peft=None, + ) + # Output. 
self.dense = self._config.dense_layer.get_layer( dense_dim, @@ -207,7 +218,24 @@ def _attn_fused( attn_weights = attn_weights.to(torch.float32) attn_weights = torch.where(mask, attn_weights, mask_value) - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) + + # Apply attention sinks if enabled + if self.sinks is not None: + # sinks shape: (local_heads,) where local_heads = local_head_groups * local_heads_per_group + # Reshape to match attn_weights: (b, local_head_groups, sq, local_heads_per_group, sk) + sinks = self.sinks.reshape(self._local_head_groups, self._local_heads_per_group) + sinks = sinks.reshape(1, self._local_head_groups, 1, self._local_heads_per_group, 1) + sinks = sinks.expand(b, -1, sq, -1, 1) + # Concatenate sinks as an extra dimension + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + # Subtract max for numerical stability (matching HF implementation) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + # Apply softmax + combined_probs = torch.nn.functional.softmax(combined_logits, dim=-1) + # Drop the sink dimension after softmax + attn_weights = combined_probs[..., :-1].to(query.dtype) + else: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) with set_generator(self._distributed.tp_generator): attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 68b6dde91..5040d45c1 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -4,6 +4,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig @@ -99,6 +100,10 @@ class AttentionConfig(MixerConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + sinks: OptionalParameterConfig = Field( + desc="Configuration for attention sinks parameter. Sinks are learnable embeddings (one per head) prepended to keys/values for streaming attention.", + hint=FieldHint.architecture, + ) softmax_scale_power: float = Field( default=0.5, desc="The scaling power to apply to head_size in the attention calculation. 
" @@ -123,4 +128,8 @@ def layer_class(self) -> "type[Attention]": return Attention def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) + return ( + self.use_flash_attention + and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) + and not self.sinks.enabled + ) diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e7c6d9e92..0ece89276 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -3,14 +3,19 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import Initialization, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales -from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim, scalar_dim from fast_llm.functional.config import ActivationType from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.layers.common.linear.convolution import CausalConv1d - from fast_llm.layers.common.linear.linear import LinearBase + from fast_llm.layers.common.linear.linear import ( + LinearBase, + Linear, + InputParallelLinear, + OutputParallelLinear, + ) @config_class() @@ -43,7 +48,7 @@ class AffineLinearBaseConfig(LinearBaseConfig): ) -@config_class() +@config_class(registry=True) class LinearConfig(LinearBaseConfig): apply_peft: bool | None = Field( default=None, @@ -102,8 +107,16 @@ def get_layer( return out -@config_class() +@config_class(dynamic_type={LinearConfig: "affine_linear"}) class AffineLinearConfig(AffineLinearBaseConfig, LinearConfig): + def _get_weight_out_dim(self, out_dim: TensorDim, transposed_weight: bool = False) -> TensorDim: + """Get the output dimension for weight parameter. Override in subclasses for special handling.""" + return out_dim + + def _get_bias_dims(self, out_dim: TensorDim) -> tuple[TensorDim, ...]: + """Get the dimensions for bias parameter. 
Override in subclasses for special handling.""" + return (out_dim,) + def get_layer( self, in_dim: TensorDim, @@ -121,21 +134,27 @@ def get_layer( from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear lr_scale = combine_lr_scales(lr_scale, self.lr_scale) + + # Get weight and bias dimensions (may differ for subclasses like MoE) + weight_out_dim = self._get_weight_out_dim(out_dim, transposed_weight) + weight = self.weight.get_parameter( - (in_dim, out_dim) if transposed_weight else (out_dim, in_dim), + (in_dim, weight_out_dim) if transposed_weight else (weight_out_dim, in_dim), default_initialization=default_weight_initialization, lr_scale=lr_scale, peft=None, ) bias = self.bias.get_parameter( - (out_dim,), + self._get_bias_dims(out_dim), default_initialization=default_bias_initialization, lr_scale=lr_scale, default_enabled=default_add_bias, peft=None, ) + + # Use weight_out_dim for layer selection if in_dim.parallel_dim is not None: - assert out_dim.parallel_dim is None + assert weight_out_dim.parallel_dim is None out = InputParallelLinear( weight, bias, @@ -143,12 +162,12 @@ def get_layer( parallel_dim=in_dim.parallel_dim, sequence_parallel=sequence_parallel, ) - elif out_dim.parallel_dim is not None: + elif weight_out_dim.parallel_dim is not None: out = OutputParallelLinear( weight, bias, transposed_weight=transposed_weight, - parallel_dim=out_dim.parallel_dim, + parallel_dim=weight_out_dim.parallel_dim, sequence_parallel=sequence_parallel, ) else: @@ -161,6 +180,44 @@ def get_layer( return out +@config_class(dynamic_type={LinearConfig: "moe_affine_linear"}) +class MoEAffineLinearConfig(AffineLinearConfig): + """ + AffineLinearConfig for MoE layers with per-expert biases. + + When out_dim is a CompositeTensorDim like (experts_dim, output_features_dim): + - Weight dimension depends on transposed_weight: + * Non-transposed (layer 1): uses full flattened size (num_experts * output_features_per_expert) + * Transposed (layer 2): uses only feature dimension (output_features_per_expert) + - Bias always uses structured dimensions (experts_dim, output_features_per_expert) for per-expert biases + + This matches the sparse MoE implementation where: + - Layer 1 (output-parallel sparse): weight is (num_experts * features, input), bias is (num_experts, features) + - Layer 2 (input-parallel sparse): weight is (num_experts * input, features), bias is (num_experts, features) + """ + + def _get_weight_out_dim(self, out_dim: TensorDim, transposed_weight: bool = False) -> TensorDim: + """For MoE, weight dimension depends on whether output or input is sparse.""" + if isinstance(out_dim, CompositeTensorDim): + if transposed_weight: + # For transposed weight (layer 2), input is sparse, output is NOT sparse + # Use only the feature dimension (last component) + return out_dim._tensor_dims[-1] + else: + # For non-transposed weight (layer 1), output IS sparse + # Use the full flattened dimension + return out_dim + else: + return out_dim + + def _get_bias_dims(self, out_dim: TensorDim) -> tuple[TensorDim, ...]: + """For MoE, use the composite structure for biases to get per-expert biases.""" + if isinstance(out_dim, CompositeTensorDim): + return out_dim._tensor_dims + else: + return (out_dim,) + + @config_class() class CausalConv1dConfig(AffineLinearBaseConfig): """ diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index ffc9eadba..eeb099c55 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ 
b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -8,8 +8,14 @@ from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton.mlp import ( + mlp_autograd, + mlp_autograd_looped, + torch_mlp_activation, + triton_mlp_activation_autograd, +) from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs @@ -29,7 +35,6 @@ class MixtureOfExpertMLP[ConfigType: MoEMLPConfig](MLPBase[ConfigType]): https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 With custom routing implementation supporting both topk and sinkhorn routing - TODO: Bias TODO: Sequence-tensor-parallel TODO: Expert parallel """ @@ -49,9 +54,9 @@ def __init__( return_bias: bool = True, ): Assert.gt(config.experts, 1) - # TODO: Implement? - assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__( + + # Call grandparent __init__ to avoid creating layers yet + super(MLPBase, self).__init__( config, distributed_config, hidden_dim=hidden_dim, @@ -59,6 +64,40 @@ def __init__( peft=peft, return_bias=return_bias, ) + + # Create MoE-specific dimensions + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() + self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + + # Create layers with MoE-specific dimensions + self.layer_1 = self._config.layer_1.get_layer( + hidden_dim, + intermediate_1_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=self._config.add_linear_biases, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + # For layer_2: pass composite dimension to enable per-expert biases + # The MoEAffineLinearConfig will extract the feature dimension for weight (since transposed=True) + # but use the full structure for per-expert biases + experts_dim = TensorDim("experts", config.experts) + moe_hidden_dim = CompositeTensorDim("moe_hidden", (experts_dim, hidden_dim)) + + self.layer_2 = self._config.layer_2.get_layer( + self._intermediate_2_dim, + moe_hidden_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=self._config.add_linear_biases, + sequence_parallel=self._sequence_parallel, + transposed_weight=True, # Weights stored in (out_features, experts*in_features) format + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.router = self._config.router.get_layer( self._hidden_dim, TensorDim("router_experts", self._config.unshared_experts), @@ -91,8 +130,11 @@ def _forward( hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: + # Create flattened dimension for debug logging + batch_seq_dim = TensorDim("batch_seq", hidden_states.size(0)) + router_expert_dim = TensorDim("router_experts", 
self._config.unshared_experts) self._debug( - logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs + logits, "Router logits", (batch_seq_dim, router_expert_dim), kwargs ) # Apply z_loss if applicable @@ -123,13 +165,15 @@ def _forward( if self._debug.enabled: # To log all ranks set `global_=False` + # Use flattened dimension for debug logging + batch_seq_dim = TensorDim("batch_seq", hidden_states.size(0)) self._debug( - scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs + scores, "Router scores", (batch_seq_dim, self._top_expert_dim), kwargs ) self._debug( top_experts, "Router top experts", - kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), + (batch_seq_dim, self._top_expert_dim), kwargs, ) @@ -148,16 +192,16 @@ def _forward_dropless( hidden_states, scores, self.layer_1.weight, - None, + self.layer_1.bias, self.layer_2.weight, - None, + None if self._parallel_dim.group else self.layer_2.bias, gated=self._config.gated, activation_type=self._config.activation, group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, recompute_level=self._config.recompute_level, - transposed_layer_2_weight=True, + transposed_layer_2_weight=True, # Weights: (out, experts*in) - transposed sparse_map=sparse_map, ) @@ -177,6 +221,8 @@ def _forward_looped( self._sequence_parallel, self.training, self._config.recompute_level, + self.layer_1.bias, + None if self._parallel_dim.group else self.layer_2.bias, ) @torch.compile diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 931c7f644..9526eac72 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -261,11 +261,17 @@ def log_distributed_tensor[ if level <= 0: return if global_: - tensor, is_first_rank = meta.local_to_global(tensor) - storage = False - is_first_rank = is_first_rank and all(group.rank() == 0 for group in duplicate_groups if group) - if not is_first_rank: - log_fn = None + try: + tensor, is_first_rank = meta.local_to_global(tensor) + storage = False + is_first_rank = is_first_rank and all(group.rank() == 0 for group in duplicate_groups if group) + if not is_first_rank: + log_fn = None + except (AssertionError, RuntimeError) as e: + # Shape mismatch during local_to_global conversion - log the local tensor instead + if log_fn is not None: + logger.warning(f"Failed to convert {name} to global tensor (expected shape {meta.shape}, got {tensor.shape}): {e}. 
Logging local tensor instead.") + global_ = False if log_fn is not None: return log_tensor( f"{'Global' if global_ else 'Local'} {name}: {meta.tensor_name}", diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a901a0466..fab23e3f5 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -17,6 +17,7 @@ AutoGPTHuggingfaceCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + GptOssCheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -117,6 +118,7 @@ class GPTModelConfig(FastLLMModelConfig): DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, AprielHybridSSMCheckpointFormat, + GptOssCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion/auto.py b/fast_llm/models/gpt/conversion/auto.py index 659d1f12c..dad5ea8db 100644 --- a/fast_llm/models/gpt/conversion/auto.py +++ b/fast_llm/models/gpt/conversion/auto.py @@ -7,6 +7,7 @@ AprielHybridSSMCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + GptOssCheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -15,6 +16,7 @@ ) from fast_llm.models.gpt.conversion.diffusion_dream import DiffusionDreamHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.diffusion_llama import DiffusionLlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.gpt_oss import GptOssHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.llama import LlamaHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.mistral import MistralHuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.mixtral import MixtralHuggingfaceCheckpointHandler @@ -35,4 +37,5 @@ class AutoGPTHuggingfaceCheckpointHandler( DiffusionDreamCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, DiffusionLlamaCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, AprielHybridSSMCheckpointFormat.name: AprielHuggingfaceCheckpointHandler, + GptOssCheckpointFormat.name: GptOssHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/conversion/config.py b/fast_llm/models/gpt/conversion/config.py index 7c06906ad..a05564000 100644 --- a/fast_llm/models/gpt/conversion/config.py +++ b/fast_llm/models/gpt/conversion/config.py @@ -47,3 +47,7 @@ class DiffusionLlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): class AprielHybridSSMCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_hybrid_ssm" + + +class GptOssCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "gpt_oss" diff --git a/fast_llm/models/gpt/conversion/gpt_oss.py b/fast_llm/models/gpt/conversion/gpt_oss.py new file mode 100644 index 000000000..cee911717 --- /dev/null +++ b/fast_llm/models/gpt/conversion/gpt_oss.py @@ -0,0 +1,703 @@ +import typing + +import torch + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig, MoEMLPConfig +from fast_llm.models.gpt.conversion.config import GptOssCheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + LlamaAttentionConverter, + LlamaBaseModelConverter, + 
LlamaBlockConverter, + LlamaHeadConverter, + LlamaMLPConverter, + get_parameter_converter, + get_weight_and_bias_converters, +) +from fast_llm.models.gpt.conversion.mistral import MistralHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.mixtral import MixtralMLPConverter +from fast_llm.tensor import SafeTensorSlice +from fast_llm.utils import Assert, safe_merge_dicts + + +class GptOssAttentionConverter(LlamaAttentionConverter): + """ + GPT-OSS attention converter. + + Inherits from Llama (which supports YARN RoPE) and adds: + - attention_bias support + - attention sinks support + """ + + @classmethod + def import_config(cls, config: dict) -> dict: + out = super().import_config(config) + # GPT-OSS supports attention_bias unlike Llama + out["add_linear_biases"] = config.get("attention_bias", False) + # GPT-OSS always uses attention sinks + out["sinks"] = {"enabled": True} + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + out = super().export_config(config) + out["attention_bias"] = config.add_linear_biases + # Don't add sinks to config, it's indicated by presence of sinks parameter + return out + + @classmethod + def _check_config(cls, config: AttentionConfig) -> None: + # Unlike Llama/Mistral, GPT-OSS supports biases + Assert.is_(type(config), AttentionConfig) + Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + # Get base converters from parent class + converters = super().get_converters(config, fast_llm_prefix, hf_prefix, drop_on_export) + + # Add sinks converter if enabled + if config.sinks.enabled: + converters.append( + get_parameter_converter( + f"{fast_llm_prefix}.sinks", + f"{hf_prefix}.sinks", + drop_on_export=drop_on_export, + ) + ) + + return converters + + +class GptOssMoEWeightConverter(WeightConverter): + """ + Converter for GPT-OSS MoE weights (for down_proj). + + HF format: (num_experts, in_features, out_features) - e.g. (32, 2880, 2880) + Fast-LLM format: (num_experts * in_features, out_features) - e.g. (92160, 2880) + + Experts are concatenated along the first dimension WITHOUT transposing. + The layer uses transposed_weight=True, which transposes the weight during forward pass. + """ + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (weight_tensor,) = weight + # Fast-LLM: (num_experts * in_features, out_features) -> HF: (num_experts, in_features, out_features) + weight_loaded = weight_tensor[:] + num_experts = self._config.experts + total_in, out_features = weight_loaded.shape + in_features = total_in // num_experts + # Just reshape - NO transpose + weight_reshaped = weight_loaded.reshape(num_experts, in_features, out_features) + return (weight_reshaped,) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (weight_tensor,) = weight + # HF: (num_experts, in_features, out_features) -> Fast-LLM: (num_experts * in_features, out_features) + # Weight is stored as (in, out), but layer uses transposed_weight=True to transpose during forward + weight_loaded = weight_tensor[:] + num_experts, in_features, out_features = weight_loaded.shape + # Just reshape - NO transpose + weight_reshaped = weight_loaded.reshape(num_experts * in_features, out_features) + return (weight_reshaped,) + + +class GptOssMoEGateUpConverter(WeightConverter): + """ + Converter for GPT-OSS MoE gate_up_proj weights. + + HF format: (num_experts, in_features, 2 * out_features) with interleaved gate/up - e.g. (32, 2880, 5760) + where gate and up are interleaved: [g0, u0, g1, u1, ...] + Fast-LLM format: (num_experts * 2 * out_features, in_features) with concatenated gate/up - e.g. (184320, 2880) + where gate and up are concatenated: [g0, g1, ..., u0, u1, ...] + + This converter: + 1. Transposes each expert's weight + 2. De-interleaves gate and up projections + 3. Concatenates all experts + """ + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (weight_tensor,) = weight + # Fast-LLM: (num_experts * 2 * expert_dim, in_features) concatenated -> HF: (num_experts, in_features, 2 * expert_dim) interleaved + weight_loaded = weight_tensor[:] + num_experts = self._config.experts + total_out, in_features = weight_loaded.shape + expert_dim = total_out // (num_experts * 2) + + # Reshape to separate experts: (num_experts, 2 * expert_dim, in_features) + weight_per_expert = weight_loaded.reshape(num_experts, 2 * expert_dim, in_features) + + # Split each expert into gate and up: (num_experts, expert_dim, in_features) each + gate = weight_per_expert[:, :expert_dim, :] + up = weight_per_expert[:, expert_dim:, :] + + # Transpose: (num_experts, in_features, expert_dim) + gate_t = gate.transpose(1, 2) + up_t = up.transpose(1, 2) + + # Interleave columns: stack and reshape + # (num_experts, in_features, expert_dim, 2) -> (num_experts, in_features, 2 * expert_dim) + weight_interleaved = torch.stack([gate_t, up_t], dim=-1).reshape(num_experts, in_features, 2 * expert_dim) + + return (weight_interleaved,) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (weight_tensor,) = weight + # HF: (num_experts, in_features, 2 * expert_dim) interleaved -> Fast-LLM: (num_experts * 2 * expert_dim, in_features) concatenated + weight_loaded = weight_tensor[:] + num_experts, in_features, total_out = weight_loaded.shape + expert_dim = total_out // 2 + + # De-interleave: columns [0,2,4,...] are gate, [1,3,5,...] 
are up + # Split into gate and up by selecting even/odd columns + gate = weight_loaded[:, :, 0::2] # (num_experts, in_features, expert_dim) - even columns + up = weight_loaded[:, :, 1::2] # (num_experts, in_features, expert_dim) - odd columns + + # Transpose each: (num_experts, expert_dim, in_features) + gate_t = gate.transpose(1, 2) + up_t = up.transpose(1, 2) + + # For each expert, concatenate gate and up + # Result: (num_experts, 2 * expert_dim, in_features) + weight_per_expert = torch.cat([gate_t, up_t], dim=1) + + # Reshape to (num_experts * 2 * expert_dim, in_features) + weight_reshaped = weight_per_expert.reshape(num_experts * 2 * expert_dim, in_features) + + return (weight_reshaped,) + + +class GptOssMoEBiasConverter(WeightConverter): + """ + Converter for GPT-OSS MoE biases (for down_proj). + + Both Fast-LLM and HF formats: (num_experts, out_features_per_expert) + + No transformation needed - just pass through. + """ + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + # Both Fast-LLM and HF use (num_experts, out_features_per_expert) + return weight + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + # Both HF and Fast-LLM use (num_experts, out_features_per_expert) + return weight + + +class GptOssMoEGateUpBiasConverter(WeightConverter): + """ + Converter for GPT-OSS MoE gate_up_proj biases. + + HF format: (num_experts, 2 * expert_dim) with interleaved gate/up - e.g. (32, 5760) + where gate and up are interleaved: [g0, u0, g1, u1, ...] + Fast-LLM format: (num_experts, 2 * expert_dim) with concatenated gate/up + where gate and up are concatenated: [g0, g1, ..., u0, u1, ...] + + This converter de-interleaves/re-interleaves the biases. + """ + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (bias_tensor,) = weight + # Fast-LLM: (num_experts, 2 * expert_dim) concatenated -> HF: (num_experts, 2 * expert_dim) interleaved + bias_loaded = bias_tensor[:] + num_experts, total_dim = bias_loaded.shape + expert_dim = total_dim // 2 + + # Split into gate and up: (num_experts, expert_dim) each + gate = bias_loaded[:, :expert_dim] + up = bias_loaded[:, expert_dim:] + + # Interleave: stack and reshape (num_experts, expert_dim, 2) -> (num_experts, 2 * expert_dim) + bias_interleaved = torch.stack([gate, up], dim=-1).reshape(num_experts, 2 * expert_dim) + + return (bias_interleaved,) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (bias_tensor,) = weight + # HF: (num_experts, 2 * expert_dim) interleaved -> Fast-LLM: (num_experts, 2 * expert_dim) concatenated + bias_loaded = bias_tensor[:] + num_experts, total_dim = bias_loaded.shape + total_dim // 2 + + # De-interleave: indices [0,2,4,...] are gate, [1,3,5,...] 
are up + gate = bias_loaded[:, 0::2] # (num_experts, expert_dim) - even indices + up = bias_loaded[:, 1::2] # (num_experts, expert_dim) - odd indices + + # Concatenate: (num_experts, 2 * expert_dim) + bias_concat = torch.cat([gate, up], dim=1) + + return (bias_concat,) + + +def get_gpt_oss_weight_and_bias_converters( + fast_llm_prefix: str, + hf_prefix: str, + use_bias: bool, + weight_cls=WeightConverter, + drop_on_export: bool = False, + bias_converter_cls=None, + config=None, +) -> list[WeightConverter]: + """ + Get weight and bias converters for GPT-OSS MoE format. + + GPT-OSS MoE parameters don't have .weight/.bias suffixes in the checkpoint. + Instead they use: + - experts.gate_up_proj (no .weight suffix) + - experts.gate_up_proj_bias (uses _bias not .bias) + """ + converters = [ + get_parameter_converter( + f"{fast_llm_prefix}.weight", + hf_prefix, # HF doesn't have .weight suffix for MoE experts + weight_cls, + config, + drop_on_export, + ) + ] + if use_bias: + # GPT-OSS uses "_bias" suffix for expert biases + # Use provided bias converter or default + if bias_converter_cls is None: + bias_converter_cls = GptOssMoEBiasConverter + converters.append( + get_parameter_converter( + f"{fast_llm_prefix}.bias", + f"{hf_prefix}_bias", # Note: _bias not .bias + bias_converter_cls, + config, + drop_on_export, + ) + ) + return converters + + +class GptOssMLPConverter(MixtralMLPConverter): + """ + GPT-OSS MoE MLP converter. + + Handles the dequantized GPT-OSS checkpoint format which uses: + - Router at .router (not .gate like Mixtral) + - Router has bias (unlike Mixtral) + - Concatenated gate_up_proj and down_proj (not separate w1/w2/w3 like Mixtral) + - Expert biases use "_bias" suffix (not ".bias") + """ + + @classmethod + def import_config(cls, config: dict) -> dict: + out = super().import_config(config) + out["router"] = { + "type": "affine_linear", + "bias": {"enabled": True}, + } + out["add_linear_biases"] = True + # GPT-OSS uses custom GLU activation + out["activation"] = "gpt_oss_glu" + # Use moe_affine_linear type for MoE expert layers to get per-expert biases + out["layer_1"] = { + "type": "moe_affine_linear", + "bias": {"enabled": True}, + } + out["layer_2"] = { + "type": "moe_affine_linear", + "bias": {"enabled": True}, + } + return out + + @classmethod + def export_config(cls, config: MoEMLPConfig) -> dict: + Assert.custom(isinstance, config, MoEMLPConfig) + # Unlike Mixtral, GPT-OSS supports biases on expert layers + return safe_merge_dicts( + # Skip MixtralMLPConverter.export_config to avoid the bias assertion + # Call grandparent (LlamaMLPConverter) instead + LlamaMLPConverter.export_config(config), + { + "num_local_experts": config.experts, + "num_experts_per_tok": config.experts_per_token, + }, + ) + + @classmethod + def get_converters( + cls, + config: MoEMLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + # Router: GPT-OSS uses .router instead of .gate + # Router has bias in GPT-OSS (unlike Mixtral which doesn't) + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.router", + f"{hf_prefix}.router", # Different from Mixtral which uses .gate + True, + drop_on_export=drop_on_export, + ), + # Experts use concatenated format like Llama (gate_up_proj, down_proj) + # not separate w1/w2/w3 like Mixtral + # GPT-OSS gate_up_proj has interleaved gate/up, needs special converter + *get_gpt_oss_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + f"{hf_prefix}.experts.gate_up_proj", + 
config.add_linear_biases, + GptOssMoEGateUpConverter, # Special converter for interleaved gate/up + drop_on_export=drop_on_export, + bias_converter_cls=GptOssMoEGateUpBiasConverter, # Special bias converter + config=config, + ), + # down_proj uses standard MoE converter (no interleaving) + *get_gpt_oss_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.experts.down_proj", + config.add_linear_biases, + GptOssMoEWeightConverter, + drop_on_export=drop_on_export, + config=config, + ), + ] + + +class GptOssBlockConverter(LlamaBlockConverter): + """ + GPT-OSS block converter. + + Uses dynamic MLP converter selection (Llama vs Mixtral) based on config type. + """ + + # Layout names for heterogeneous block patterns + layout_names = { + "sliding_attention": "sliding", + "full_attention": "full", + } + + # Dynamic converter selection like Apriel + _mixer_converter_classes = { + AttentionConfig: GptOssAttentionConverter, + } + _mlp_converter_classes = { + MLPConfig: LlamaMLPConverter, + MoEMLPConfig: GptOssMLPConverter, + } + + mixer_converter_class: typing.ClassVar[type[GptOssAttentionConverter]] = GptOssAttentionConverter + mlp_converter_class: typing.ClassVar[type] = None # Will be selected dynamically + + hf_mixer_name: typing.ClassVar[str] = "self_attn" + hf_mlp_name: typing.ClassVar[str] = "mlp" # GPT-OSS uses .mlp (after dequantization) + hf_norm_1_name: typing.ClassVar[str] = "input_layernorm" + hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" + + @classmethod + def import_config(cls, config: dict, layer_type: str = "full_attention") -> dict: + # Create attention config + attention_config = cls.mixer_converter_class.import_config(config) + + # Handle sliding window for this specific layer type + if layer_type == "sliding_attention": + if "window_size" not in attention_config: + attention_config["window_size"] = config.get("sliding_window", 128) + else: + # For full attention, remove window_size if present + attention_config.pop("window_size", None) + + # Determine MLP converter based on config + if "num_local_experts" in config: + mlp_converter = cls._mlp_converter_classes[MoEMLPConfig] + else: + mlp_converter = cls._mlp_converter_classes[MLPConfig] + + return { + "mixer": attention_config, + "mlp": mlp_converter.import_config(config), + "normalization": cls.normalization_converter_class.import_config(config), + } + + @classmethod + def export_config(cls, config: DecoderBlockConfig) -> dict: + Assert.custom(isinstance, config, DecoderBlockConfig) + + # Select MLP converter based on config type + mlp_converter = cls._mlp_converter_classes[type(config.mlp)] + + return safe_merge_dicts( + cls.mixer_converter_class.export_config(config.mixer), + mlp_converter.export_config(config.mlp), + cls.normalization_converter_class.export_config(config.normalization), + ) + + @classmethod + def get_converters( + cls, config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False + ) -> list[WeightConverter]: + # Select MLP converter based on config type + mlp_converter = cls._mlp_converter_classes[type(config.mlp)] + + return [ + *cls.mixer_converter_class.get_converters( + config.mixer, + f"{fast_llm_prefix}.mixer", + f"{hf_prefix}.{cls.hf_mixer_name}", + drop_on_export, + ), + *mlp_converter.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + f"{hf_prefix}.{cls.hf_mlp_name}", + drop_on_export, + ), + *cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_1", + 
f"{hf_prefix}.{cls.hf_norm_1_name}", + drop_on_export, + ), + *cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_2", + f"{hf_prefix}.{cls.hf_norm_2_name}", + drop_on_export, + ), + ] + + +class GptOssDecoderConverter: + """ + GPT-OSS decoder converter with heterogeneous block pattern support. + + Handles the `layer_types` field that specifies alternating attention patterns. + """ + + block_converter_class: typing.ClassVar[type[GptOssBlockConverter]] = GptOssBlockConverter + + @classmethod + def _get_layer_type(cls, config: DecoderBlockConfig) -> str: + """Determine layer type from block config.""" + match config.mixer: + case AttentionConfig(window_size=window_size) if window_size is not None: + return "sliding_attention" + case _: + return "full_attention" + + @classmethod + def _find_minimal_repeating_pattern(cls, layer_types: list[str]) -> list[str]: + """Find the minimal repeating pattern in layer_types. + + Uses the property that the period must divide the length. + Tries periods in increasing order to find the smallest one. + + Examples: + - ["A", "B", "A", "B"] -> ["A", "B"] + - ["A", "B", "C", "A", "B", "C"] -> ["A", "B", "C"] + - ["A", "B", "C"] -> ["A", "B", "C"] (no repetition) + """ + n = len(layer_types) + + # Try each possible period length from 1 to n + for period_len in range(1, n + 1): + # Period must divide the total length evenly + if n % period_len == 0: + candidate_pattern = layer_types[:period_len] + # Check if repeating this pattern reconstructs the full sequence + num_repeats = n // period_len + if candidate_pattern * num_repeats == layer_types: + return candidate_pattern + + # Fallback (should never reach here) + return layer_types + + @classmethod + def import_config(cls, config: dict) -> dict: + """Import decoder config, handling heterogeneous layer types.""" + layer_types = config.get("layer_types", ["full_attention"]) + + # Determine unique layer types + unique_types = list(dict.fromkeys(layer_types)) # Preserve order + + if len(unique_types) == 1: + # All layers are the same type - use FixedBlockSequenceConfig + return { + "block": cls.block_converter_class.import_config(config, unique_types[0]), + "num_blocks": config["num_hidden_layers"], + } + else: + # Multiple layer types - use PatternBlockSequenceConfig + # Find the minimal repeating pattern to enable compact representation + minimal_pattern = cls._find_minimal_repeating_pattern(layer_types) + + # Create a block config for each unique type in the minimal pattern + # Use dict.fromkeys to preserve order while removing duplicates + unique_in_pattern = list(dict.fromkeys(minimal_pattern)) + blocks = {} + for layer_type in unique_in_pattern: + layout_name = cls.block_converter_class.layout_names.get(layer_type, layer_type) + blocks[layout_name] = cls.block_converter_class.import_config(config, layer_type) + + # Create pattern using layout names + pattern = [cls.block_converter_class.layout_names.get(lt, lt) for lt in minimal_pattern] + + return { + "type": "pattern", + "blocks": blocks, + "pattern": pattern, + "num_blocks": config["num_hidden_layers"], + } + + @classmethod + def export_config(cls, config: BlockSequenceConfig) -> dict: + """Export decoder config, reconstructing layer_types.""" + match config: + case FixedBlockSequenceConfig(): + # All blocks are the same + block_configs = [config.block] + layer_type = cls._get_layer_type(config.block) + layer_types = [layer_type] * config.num_blocks + case PatternBlockSequenceConfig(): + # Multiple block types 
+ block_configs = list(config.blocks.values()) + # Reconstruct layer_types from expanded pattern + # HuggingFace requires layer_types length to match num_hidden_layers + layer_types = [] + for block_name in config.expanded_pattern: + block_config = config.blocks[block_name] + layer_type = cls._get_layer_type(block_config) + layer_types.append(layer_type) + case _: + raise NotImplementedError(f"Unsupported block sequence type: {type(config)}") + + # Export each block config and handle sliding_window conflicts + exported_configs = [cls.block_converter_class.export_config(block_config) for block_config in block_configs] + + # Extract sliding_window values to handle heterogeneous blocks + sliding_window = None + for exported_config in exported_configs: + window = exported_config.pop("sliding_window", None) + if window is not None: + sliding_window = window + + # Merge all block configs + result = safe_merge_dicts( + *exported_configs, + { + "num_hidden_layers": config.num_blocks, + "layer_types": layer_types, + }, + ) + + # Add sliding_window back if any block had it + if sliding_window is not None: + result["sliding_window"] = sliding_window + + return result + + @classmethod + def get_converters( + cls, + config: BlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + """Get weight converters for all blocks in the decoder.""" + converters = [] + + if type(config) is FixedBlockSequenceConfig: + # All blocks use the same config + for block_index in range(config.num_blocks): + converters += cls.block_converter_class.get_converters( + config.block, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + elif type(config) is PatternBlockSequenceConfig: + # Blocks follow a pattern + for block_index in range(config.num_blocks): + block_name = config.expanded_pattern[block_index] + block_config = config.blocks[block_name] + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + else: + raise NotImplementedError(f"Unsupported block sequence type: {type(config)}") + + return converters + + +class GptOssHeadConverter(LlamaHeadConverter): + block_converter_class: typing.ClassVar[type[GptOssBlockConverter]] = GptOssBlockConverter + + +class GptOssBaseModelConverter(LlamaBaseModelConverter): + """ + GPT-OSS base model converter. + + Handles: + - Vocab size ~201,088 (o200k_harmony tokenizer) + - Heterogeneous decoder with alternating attention patterns + - RMS normalization + - MoE layers + """ + + decoder_converter_class: typing.ClassVar[type[GptOssDecoderConverter]] = GptOssDecoderConverter + head_converter_class: typing.ClassVar[type[GptOssHeadConverter]] = GptOssHeadConverter + + +class GptOssHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + """ + Checkpoint handler for GPT-OSS models. + + Supports conversion between Fast-LLM and HuggingFace GPT-OSS format. + Handles both gpt-oss-120b (117B params) and gpt-oss-20b (21B params) variants. 
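+
+    Minimal conversion sketch (paths are illustrative and assume a dequantized HF
+    checkpoint on disk):
+
+        ConvertConfig(
+            input=CheckpointLoadConfig(
+                path="dequantized_hf", format=GptOssCheckpointFormat, load_config=ModelConfigType.model
+            ),
+            output=CheckpointSaveConfig(path="fast_llm", format=FastLLMCheckpointFormat),
+            model=GPTModelConfig,
+        ).run()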
+ + Key features: + - Mixture of Experts (32-128 experts, 4 active per token) + - Alternating sliding window and full attention patterns + - YARN RoPE scaling + - Grouped multi-query attention (8 KV heads) + """ + + format: typing.ClassVar[type[CheckpointFormat]] = GptOssCheckpointFormat + architecture: typing.ClassVar[str] = "GptOssForCausalLM" + base_model_converter_class: typing.ClassVar[type[GptOssBaseModelConverter]] = GptOssBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.GptOssConfig diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index a92492260..bbc4b82dc 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -122,7 +122,7 @@ class LlamaMLPConverter: def import_config(cls, config: dict) -> dict: return { "intermediate_size": config["intermediate_size"], - "add_linear_biases": config["mlp_bias"], + "add_linear_biases": config.get("mlp_bias", False), "activation": ActivationType.from_hf_name(config["hidden_act"]), "gated": True, } @@ -198,19 +198,20 @@ def import_config(cls, config: dict) -> dict: elif rope_type == "llama3": rotary_config.update( { - "scale_factor": config["factor"], - "low_frequency_factor": config["low_freq_factor"], - "high_frequency_factor": config["high_freq_factor"], - "original_context_length": config["original_max_position_embeddings"], + "scale_factor": config["rope_scaling"]["factor"], + "low_frequency_factor": config["rope_scaling"]["low_freq_factor"], + "high_frequency_factor": config["rope_scaling"]["high_freq_factor"], + "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], } ) elif rope_type == "yarn": rotary_config.update( { - "attention_factor": config["attention_factor"], - "beta_fast": config["beta_fast"], - "beta_slow": config["beta_slow"], - "original_context_length": config["original_max_position_embeddings"], + "scale_factor": config["rope_scaling"]["factor"], + "attention_factor": config["rope_scaling"].get("attention_factor"), + "beta_fast": config["rope_scaling"]["beta_fast"], + "beta_slow": config["rope_scaling"]["beta_slow"], + "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], } ) else: @@ -220,7 +221,7 @@ def import_config(cls, config: dict) -> dict: "heads": config["num_attention_heads"], "head_groups": config["num_key_value_heads"], "head_size": config.get("head_dim"), - "add_linear_biases": config["attention_bias"], + "add_linear_biases": config.get("attention_bias", False), "dropout": config["attention_dropout"], } if out["head_size"] is None: @@ -253,6 +254,7 @@ def export_config(cls, config: AttentionConfig) -> dict: elif type(config.rotary) is YarnRotaryConfig: out["rope_scaling"] = { "rope_type": "yarn", + "factor": config.rotary.scale_factor, "attention_factor": config.rotary.attention_factor, "beta_fast": config.rotary.beta_fast, "beta_slow": config.rotary.beta_slow, diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 94670057f..726d3ed25 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -1,5 +1,7 @@ import typing +import torch + from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter from fast_llm.layers.decoder.mlp.config import MoEMLPConfig @@ -12,9 +14,58 @@ MistralHeadConverter, 
MistralHuggingfaceCheckpointHandler, ) +from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, safe_merge_dicts +class MoEMLPLayer2Converter(WeightConverter): + """ + Converter for MoE layer 2 (down projection) weights. + + HuggingFace format: Per-expert weights, each of shape [hidden_size, intermediate_size] + Fast-LLM format: Weight of shape [num_experts * intermediate_size, hidden_size] + + Fast-LLM stores MoE layer 2 weights with input dimension (intermediate) flattened across experts. + The output dimension (hidden) is NOT multiplied by experts - each expert outputs to the same hidden size. + This matches the MoEAffineLinearConfig which extracts only the feature dimension for transposed weights. + """ + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + # Fast-LLM: [num_experts * intermediate_size, hidden_size] + # HF needs: per-expert weights of [hidden_size, intermediate_size] + (merged_weight,) = weight + num_experts = len(self.export_name) + hidden_size = merged_weight.shape[1] + intermediate_size = merged_weight.shape[0] // num_experts + + # Reshape to [num_experts, intermediate_size, hidden_size] + reshaped = merged_weight[:].reshape(num_experts, intermediate_size, hidden_size) + + # Transpose each expert to [hidden_size, intermediate_size] (HF format) + return tuple(reshaped[i].t().contiguous() for i in range(num_experts)) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + # HF: per-expert weights, each [hidden_size, intermediate_size] + # Need to create [num_experts * intermediate_size, hidden_size] + num_experts = len(weight) + + # Materialize first weight to get dtype, device, and shape + first_weight = weight[0][:] + hidden_size, intermediate_size = first_weight.shape # HF stores as [hidden, intermediate] + + # Transpose each expert's weights to [intermediate_size, hidden_size] and stack + expert_weights = [weight[i][:].t() for i in range(num_experts)] + + # Concatenate along first dimension: [num_experts * intermediate_size, hidden_size] + merged = torch.cat(expert_weights, dim=0) + + return (merged.contiguous(),) + + class MixtralMLPConverter(LlamaMLPConverter): @classmethod def import_config(cls, config: dict) -> dict: @@ -24,6 +75,13 @@ def import_config(cls, config: dict) -> dict: "type": "moe", "experts": config["num_local_experts"], "experts_per_token": config["num_experts_per_tok"], + # Use moe_affine_linear type for MoE expert layers to handle CompositeTensorDim correctly + "layer_1": { + "type": "moe_affine_linear", + }, + "layer_2": { + "type": "moe_affine_linear", + }, }, ) @@ -65,7 +123,7 @@ def get_converters( f"{fast_llm_prefix}.layer_2", tuple(f"{hf_prefix}.experts.{i}.w2" for i in range(config.experts)), False, - MLPLayer2Converter, + MoEMLPLayer2Converter, drop_on_export=drop_on_export, ), ] diff --git a/prepare_gpt_oss_checkpoint.py b/prepare_gpt_oss_checkpoint.py new file mode 100644 index 000000000..f7c649058 --- /dev/null +++ b/prepare_gpt_oss_checkpoint.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" +Step 1: Download, dequantize, and truncate GPT-OSS model. +Saves the prepared checkpoint to a static directory. 
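+
+The source model, output directory, and number of layers to keep are hard-coded in
+MODEL_PATH, OUTPUT_DIR, and NUM_LAYERS_TO_KEEP below; run the script directly:
+
+    python prepare_gpt_oss_checkpoint.py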
+""" + +import pathlib + +import torch +import transformers +from huggingface_hub import snapshot_download +from transformers import Mxfp4Config + +# Configuration +MODEL_PATH = "openai/gpt-oss-20b" +OUTPUT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +NUM_LAYERS_TO_KEEP = 2 + + +def main(): + print("=" * 80) + print(f"Preparing GPT-OSS {MODEL_PATH} ({NUM_LAYERS_TO_KEEP}-layer variant)") + print("=" * 80) + + # Create output directory + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + dequantized_path = OUTPUT_DIR / "dequantized_hf" + + print(f"\n1. Downloading HuggingFace model files...") + print(f" Source: {MODEL_PATH}") + + # Download the model files from HF Hub + hf_local_path = snapshot_download(repo_id=MODEL_PATH, local_dir_use_symlinks=False) + hf_local_path = pathlib.Path(hf_local_path) + print(f" Downloaded to: {hf_local_path}") + + print(f"\n2. Loading HuggingFace model with dequantization...") + # Load with dequantization to convert MXFP4 quantized weights to float + quantization_config = Mxfp4Config(dequantize=True) + + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + hf_local_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, + ).cuda() + + print(f"\n3. Trimming model to first {NUM_LAYERS_TO_KEEP} layers...") + # Keep only first N transformer blocks to reduce memory + original_num_layers = len(hf_model.model.layers) + print(f" Original layers: {original_num_layers}, keeping: {NUM_LAYERS_TO_KEEP}") + hf_model.model.layers = hf_model.model.layers[:NUM_LAYERS_TO_KEEP] + hf_model.config.num_hidden_layers = NUM_LAYERS_TO_KEEP + + # GPT-OSS has layer_types config that must match num_hidden_layers + if hasattr(hf_model.config, "layer_types"): + print(f" Original layer_types length: {len(hf_model.config.layer_types)}") + hf_model.config.layer_types = hf_model.config.layer_types[:NUM_LAYERS_TO_KEEP] + print(f" Trimmed layer_types: {hf_model.config.layer_types}") + + print(f"\n4. Saving trimmed dequantized model...") + print(f" Output: {dequantized_path}") + hf_model.save_pretrained(dequantized_path) + + print(f"\n✅ Checkpoint prepared successfully!") + print(f" Location: {dequantized_path}") + print(f" Vocab size: {hf_model.config.vocab_size}") + print(f" Hidden size: {hf_model.config.hidden_size}") + print(f" Num layers: {hf_model.config.num_hidden_layers}") + + # Free memory + del hf_model + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/test_activation.py b/test_activation.py new file mode 100644 index 000000000..4fa574627 --- /dev/null +++ b/test_activation.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +"""Test that the gpt_oss_glu activation matches HF implementation.""" + +import torch + + +# HF implementation (from the experts forward code) +def hf_activation(gate_up): + """ + HF GPT-OSS activation. + gate_up is interleaved [g0, u0, g1, u1, ...] + """ + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + alpha = 1.702 + limit = 7.0 + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + return (up + 1) * glu + + +# Fast-LLM implementation (from config.py) +def fast_llm_activation(x): + """ + Fast-LLM GPT-OSS activation. + x is concatenated [gate..., up...] 
+ """ + gate, up = x.chunk(2, dim=-1) + alpha = 1.702 + limit = 7.0 + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + return (up + 1.0) * glu + + +# Test +torch.manual_seed(42) +batch, seq, dim = 2, 4, 8 + +# Create random gate and up +gate = torch.randn(batch, seq, dim) +up = torch.randn(batch, seq, dim) + +# HF format: interleaved +hf_input = torch.stack([gate, up], dim=-1).reshape(batch, seq, 2 * dim) +print("HF input shape:", hf_input.shape) +print("HF input [0,0,:10]:", hf_input[0, 0, :10]) + +# Fast-LLM format: concatenated +fl_input = torch.cat([gate, up], dim=-1) +print("\nFL input shape:", fl_input.shape) +print("FL input [0,0,:10]:", fl_input[0, 0, :10]) + +# Run both activations +hf_output = hf_activation(hf_input) +fl_output = fast_llm_activation(fl_input) + +print("\nHF output shape:", hf_output.shape) +print("HF output [0,0,:5]:", hf_output[0, 0, :5]) + +print("\nFL output shape:", fl_output.shape) +print("FL output [0,0,:5]:", fl_output[0, 0, :5]) + +# Compare +print("\nOutputs match:", torch.allclose(hf_output, fl_output, atol=1e-6)) +print("Max diff:", (hf_output - fl_output).abs().max().item()) diff --git a/test_converter.py b/test_converter.py new file mode 100644 index 000000000..f7f3cefd3 --- /dev/null +++ b/test_converter.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +"""Test that the gate_up_proj converter works correctly.""" + +import torch + +# Simulate HF format: (num_experts, in_features, 2*expert_dim) interleaved +num_experts = 2 +in_features = 4 +expert_dim = 3 +hf_gate_up = torch.randn(num_experts, in_features, 2 * expert_dim) + +print("HF format shape:", hf_gate_up.shape) +print("HF gate_up[0, 0, :]:", hf_gate_up[0, 0, :]) + +# HF extraction +hf_gate = hf_gate_up[:, :, 0::2] # even indices +hf_up = hf_gate_up[:, :, 1::2] # odd indices + +print("\nHF extracts:") +print(" gate[0, 0, :]:", hf_gate[0, 0, :]) +print(" up[0, 0, :]:", hf_up[0, 0, :]) + +# My converter (import) +gate = hf_gate_up[:, :, 0::2] # (num_experts, in_features, expert_dim) - even columns +up = hf_gate_up[:, :, 1::2] # (num_experts, in_features, expert_dim) - odd columns + +# Transpose each: (num_experts, expert_dim, in_features) +gate_t = gate.transpose(1, 2) +up_t = up.transpose(1, 2) + +# For each expert, concatenate gate and up +# Result: (num_experts, 2 * expert_dim, in_features) +weight_per_expert = torch.cat([gate_t, up_t], dim=1) + +# Reshape to (num_experts * 2 * expert_dim, in_features) +fast_llm_weight = weight_per_expert.reshape(num_experts * 2 * expert_dim, in_features) + +print("\nFast-LLM format shape:", fast_llm_weight.shape) +print("First expert gate (transposed):", fast_llm_weight[:expert_dim, :]) +print("First expert up (transposed):", fast_llm_weight[expert_dim : 2 * expert_dim, :]) + +# Now simulate Fast-LLM forward pass +# Input: (batch, seq, in_features) @ weight -> (batch, seq, expert_dim * 2) [concatenated gate, up] +input_vec = torch.randn(1, 1, in_features) +print("\nInput:", input_vec) + +# Fast-LLM: matmul gives [gate, up] concatenated +fast_llm_output = input_vec @ fast_llm_weight[: 2 * expert_dim, :].t() # First expert only +print("Fast-LLM output shape:", fast_llm_output.shape) +print("Fast-LLM output:", fast_llm_output) + +# Split into gate and up +fl_gate, fl_up = fast_llm_output.chunk(2, dim=-1) +print("Fast-LLM gate:", fl_gate) +print("Fast-LLM up:", fl_up) + +# HF: matmul gives [g0, u0, g1, u1, ...] 
interleaved +hf_output = input_vec @ hf_gate_up[0, :, :] # First expert: (1, 1, in_features) @ (in_features, 2*expert_dim) +print("\nHF output shape:", hf_output.shape) +print("HF output:", hf_output) + +# HF extracts +hf_gate_out = hf_output[:, :, 0::2] +hf_up_out = hf_output[:, :, 1::2] +print("HF gate:", hf_gate_out) +print("HF up:", hf_up_out) + +# Compare +print("\nGate match:", torch.allclose(fl_gate, hf_gate_out, atol=1e-5)) +print("Up match:", torch.allclose(fl_up, hf_up_out, atol=1e-5)) diff --git a/test_gpt_oss_debug.py b/test_gpt_oss_debug.py new file mode 100644 index 000000000..d4bd07e4f --- /dev/null +++ b/test_gpt_oss_debug.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +""" +Debug GPT-OSS forward pass differences. +Compare a single token through both models to identify divergence point. +""" + +import torch +import transformers + +from fast_llm.models.gpt.huggingface import GPTHuggingfaceModel + +# Set seed for reproducibility +torch.manual_seed(42) + +print("Loading HF model...") +hf_model = ( + transformers.AutoModelForCausalLM.from_pretrained( + "/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoints_tywyhgh1/dequantized_hf", + torch_dtype=torch.bfloat16, + ) + .cuda() + .eval() +) + +print("Loading Fast-LLM model...") +fast_llm_model = ( + GPTHuggingfaceModel.from_pretrained( + "/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoints_tywyhgh1/fast_llm", + torch_dtype=torch.bfloat16, + ) + .cuda() + .eval() +) + +# Create a single token input +test_input = torch.tensor([[199635]], device="cuda") +print(f"\nTest input: {test_input}") + +# Run HF model with hooks to capture intermediate values +hf_intermediates = {} + + +def make_hf_hook(name): + def hook(module, input, output): + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + hf_intermediates[name] = output_tensor.detach().float() + + return hook + + +# Register hooks on first layer components +hf_model.model.embed_tokens.register_forward_hook(make_hf_hook("embeddings")) +hf_model.model.layers[0].input_layernorm.register_forward_hook(make_hf_hook("layer0_norm1")) +hf_model.model.layers[0].self_attn.register_forward_hook(make_hf_hook("layer0_attn")) +hf_model.model.layers[0].post_attention_layernorm.register_forward_hook(make_hf_hook("layer0_norm2")) +hf_model.model.layers[0].mlp.router.register_forward_hook(make_hf_hook("layer0_router")) +hf_model.model.layers[0].mlp.register_forward_hook(make_hf_hook("layer0_mlp")) + +print("\nRunning HF model...") +with torch.no_grad(): + hf_output = hf_model(test_input) +hf_logits = hf_output.logits.float() + +print("\n=== HF Intermediate Values ===") +for name, tensor in hf_intermediates.items(): + print(f"{name}: shape={tensor.shape}, mean={tensor.mean():.6f}, std={tensor.std():.6f}") + if tensor.numel() <= 20: + print(f" values={tensor.flatten()[:20]}") + +# Now check Fast-LLM embeddings manually +print("\n=== Fast-LLM Manual Check ===") +# Get embedding weight from Fast-LLM +fl_embed_weight = fast_llm_model._model._embedding.embedding.weight.data +print(f"Fast-LLM embedding weight shape: {fl_embed_weight.shape}") +print(f"Fast-LLM embedding for token {test_input[0,0]}: {fl_embed_weight[test_input[0,0], :10]}") + +# Get HF embedding weight +hf_embed_weight = hf_model.model.embed_tokens.weight.data +print(f"HF embedding weight shape: {hf_embed_weight.shape}") +print(f"HF embedding for token {test_input[0,0]}: {hf_embed_weight[test_input[0,0], :10]}") + +print(f"\nEmbedding weights match: {torch.allclose(fl_embed_weight.float(), hf_embed_weight.float(), 
atol=1e-3)}") + +# Run Fast-LLM model +print("\nRunning Fast-LLM model...") +with torch.no_grad(): + fl_output = fast_llm_model(test_input) +fl_logits = fl_output.logits.float() + +print(f"\n=== Output Comparison ===") +print(f"HF logits: shape={hf_logits.shape}, mean={hf_logits.mean():.6f}, std={hf_logits.std():.6f}") +print(f"FL logits: shape={fl_logits.shape}, mean={fl_logits.mean():.6f}, std={fl_logits.std():.6f}") +print(f"Logits match: {torch.allclose(hf_logits, fl_logits, atol=0.01)}") +print(f"Max diff: {(hf_logits - fl_logits).abs().max():.6f}") +print(f"RMS diff: {((hf_logits - fl_logits) ** 2).mean().sqrt():.6f}") diff --git a/test_gpt_oss_forward_compare.py b/test_gpt_oss_forward_compare.py new file mode 100644 index 000000000..9be398ddc --- /dev/null +++ b/test_gpt_oss_forward_compare.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +""" +Step 2: Convert checkpoint and compare forward passes between HF and Fast-LLM. +""" + +import os +import pathlib +import sys + +import torch +import transformers + +from fast_llm.engine.checkpoint.config import ( + CheckpointLoadConfig, + CheckpointSaveConfig, + FastLLMCheckpointFormat, + ModelConfigType, +) +from fast_llm.engine.checkpoint.convert import ConvertConfig +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import GptOssCheckpointFormat +from tests.utils.compare_tensor_logs import CompareConfig + +# Set PyTorch memory allocator to use expandable segments +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + +sys.path.insert(0, "/home/ubuntu/Fast-LLM") + +# Configuration +CHECKPOINT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +DEQUANTIZED_HF_PATH = CHECKPOINT_DIR / "dequantized_hf" +FAST_LLM_PATH = CHECKPOINT_DIR / "fast_llm" + + +def test_gpt_oss_forward_equivalence(): + """Test that HuggingFace and Fast-LLM produce equivalent outputs.""" + print("=" * 80) + print("Testing GPT-OSS Forward Pass Equivalence") + print("=" * 80) + + if not DEQUANTIZED_HF_PATH.exists(): + print(f"\n❌ Error: Checkpoint not found at {DEQUANTIZED_HF_PATH}") + print(f" Please run prepare_gpt_oss_checkpoint.py first!") + return False + + try: + # Load config to get vocab size + config = transformers.AutoConfig.from_pretrained(DEQUANTIZED_HF_PATH) + vocab_size = config.vocab_size + + print(f"\n1. Converting to Fast-LLM format...") + print(f" Source: {DEQUANTIZED_HF_PATH}") + print(f" Target: {FAST_LLM_PATH}") + + ConvertConfig( + input=CheckpointLoadConfig( + path=DEQUANTIZED_HF_PATH, + format=GptOssCheckpointFormat, + load_config=ModelConfigType.model, + ), + output=CheckpointSaveConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + ), + model=GPTModelConfig, + ).run() + + print(f"\n2. Creating test input...") + torch.manual_seed(42) + test_input = torch.randint( + 0, + vocab_size, + size=(2, 32), # Small batch and sequence length + dtype=torch.int64, + device="cuda", + ) + print(f" Input shape: {test_input.shape}") + print(f" Vocab size: {vocab_size}") + print(f" First 10 token IDs: {test_input[0, :10].tolist()}") + + print(f"\n3. 
Loading HuggingFace model and running forward pass...") + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + DEQUANTIZED_HF_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ).cuda() + + # Add forward hooks for debugging + hf_activations = {} + + def make_hf_hook(name): + def hook(module, input, output): + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + hf_activations[name] = output_tensor.detach() + print( + f" HF {name}: shape={output_tensor.shape}, mean={output_tensor.mean().item():.6f}, std={output_tensor.std().item():.6f}" + ) + + return hook + + hf_model.model.embed_tokens.register_forward_hook(make_hf_hook("embeddings")) + hf_model.model.layers[0].self_attn.register_forward_hook(make_hf_hook("layer0_attn")) + hf_model.model.layers[0].mlp.register_forward_hook(make_hf_hook("layer0_mlp")) + hf_model.model.layers[0].register_forward_hook(make_hf_hook("layer0_output")) + if len(hf_model.model.layers) > 1: + hf_model.model.layers[1].register_forward_hook(make_hf_hook("layer1_output")) + hf_model.model.norm.register_forward_hook(make_hf_hook("final_norm")) + hf_model.lm_head.register_forward_hook(make_hf_hook("lm_head")) + + print(f" Running HuggingFace model...") + with torch.no_grad(): + hf_output = hf_model(test_input) + + # Save the output and free the model + hf_logits = hf_output.logits.clone().cpu() + del hf_model, hf_output + torch.cuda.empty_cache() + + # Memory cleanup + import gc + + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + + print(f" GPU memory after cleanup: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated") + + print(f"\n4. Loading Fast-LLM model and running forward pass...") + from fast_llm.engine.config_utils.logging import TensorLogs, TensorLogsConfig + from fast_llm.logging import set_model_debug_level + from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM + from fast_llm.models.gpt.model import GPTModel + + # Initialize TensorLogs and enable debug mode + TensorLogs.reset(TensorLogsConfig(save=False, show=True, max_elements=8)) + set_model_debug_level(3) + + print(f" Debug level set to: 3") + + # Load the base GPT model first + gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) + ) + + # Then wrap it with the HuggingFace interface + fast_llm_model = HuggingfaceGPTModelForCausalLM(gpt_model) + + print(f" Running Fast-LLM model...") + with torch.no_grad(): + fast_llm_output = fast_llm_model(test_input) + + fast_llm_logits = fast_llm_output.logits.clone() + + print(f"\n5. 
Comparing outputs...") + hf_logits = hf_logits.cuda() + + print(f" HF output shape: {hf_logits.shape}, dtype: {hf_logits.dtype}") + print(f" Fast-LLM output shape: {fast_llm_logits.shape}, dtype: {fast_llm_logits.dtype}") + + # Compare using Fast-LLM's comparison utility + errors = [] + CompareConfig().compare_tensors( + {"samples": hf_logits, "shape": hf_logits.shape, "step": 0}, + {"samples": fast_llm_logits, "shape": fast_llm_logits.shape, "step": 0}, + errors, + "HuggingFace vs Fast-LLM", + "logits", + ) + + if errors: + print(f"\n❌ Comparison failed:") + for error in errors: + print(f" {error}") + return False + + print(f"\n✅ Forward pass equivalence test passed!") + return True + + except Exception as e: + print(f"\n❌ Test failed:") + print(f" Error: {type(e).__name__}: {e}") + import traceback + + traceback.print_exc() + return False + + +if __name__ == "__main__": + success = test_gpt_oss_forward_equivalence() + sys.exit(0 if success else 1) diff --git a/test_gpt_oss_looped.py b/test_gpt_oss_looped.py new file mode 100644 index 000000000..b6c40edca --- /dev/null +++ b/test_gpt_oss_looped.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +""" +Test GPT-OSS forward pass using LOOPED MoE (not dropless) to isolate implementation differences. +""" + +import os +import pathlib +import sys + +import torch +import transformers + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat, ModelConfigType +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from fast_llm.models.gpt.model import GPTModel +from tests.utils.compare_tensor_logs import CompareConfig + +# Set PyTorch memory allocator +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + +sys.path.insert(0, "/home/ubuntu/Fast-LLM") + +# Configuration +CHECKPOINT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +DEQUANTIZED_HF_PATH = CHECKPOINT_DIR / "dequantized_hf" +FAST_LLM_PATH = CHECKPOINT_DIR / "fast_llm" + +print("=" * 80) +print("Testing GPT-OSS Forward Pass with LOOPED MoE") +print("=" * 80) + +# Create test input +torch.manual_seed(42) +test_input = torch.randint(0, 201088, size=(1, 4), dtype=torch.int64, device="cuda") +print(f"\nTest input: {test_input}") + +# ============================================================================== +# Part 1: HuggingFace Model +# ============================================================================== +print("\n" + "=" * 80) +print("Part 1: HuggingFace Model") +print("=" * 80) + +hf_model = ( + transformers.AutoModelForCausalLM.from_pretrained( + DEQUANTIZED_HF_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + .cuda() + .eval() +) + +with torch.no_grad(): + hf_output = hf_model(test_input) + +hf_logits = hf_output.logits.clone().cpu() +print(f"HF logits shape: {hf_logits.shape}") +print(f"HF logits mean: {hf_logits.float().mean():.6f}, std: {hf_logits.float().std():.6f}") +print(f"HF logits [0, 0, :10]: {hf_logits[0, 0, :10].float()}") + +del hf_model +torch.cuda.empty_cache() + +# ============================================================================== +# Part 2: Fast-LLM Model with LOOPED MoE +# ============================================================================== +print("\n" + "=" * 80) +print("Part 2: Fast-LLM Model with LOOPED MoE (dropless=False)") +print("=" * 80) + +# Load model +gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) +) + +# 
Override dropless setting to force looped implementation +decoder_config = gpt_model.config.base_model.decoder +print(f"\nDecoder type: {type(decoder_config).__name__}") +print(f"Original dropless setting (full): {decoder_config.blocks['full'].mlp.dropless}") +print(f"Original dropless setting (sliding): {decoder_config.blocks['sliding'].mlp.dropless}") +decoder_config.blocks["full"].mlp.dropless = False +decoder_config.blocks["sliding"].mlp.dropless = False +print(f"Modified dropless setting: {decoder_config.blocks['full'].mlp.dropless}") + +# Re-initialize the MLP layers with the new config +# This is a bit hacky but necessary to apply the config change +for layer_idx, layer in enumerate(gpt_model.base_model.decoder): + mlp = layer.mlp + # Re-select the forward function based on updated config + dropless_moe = mlp._config.dropless + if dropless_moe and mlp._sequence_parallel: + import warnings + + warnings.warn( + "Dropless MoE not supported for sequence-tensor-parallel, falling back to looped implementation." + ) + dropless_moe = False + mlp._mlp_forward = mlp._forward_dropless if dropless_moe else mlp._forward_looped + print(f"Layer {layer_idx}: Using {'dropless' if dropless_moe else 'looped'} MoE") + +# Wrap with HuggingFace interface +fast_llm_model = HuggingfaceGPTModelForCausalLM(gpt_model) + +with torch.no_grad(): + fast_llm_output = fast_llm_model(test_input) + +fast_llm_logits = fast_llm_output.logits.clone() +print(f"\nFast-LLM logits shape: {fast_llm_logits.shape}") +print(f"Fast-LLM logits mean: {fast_llm_logits.float().mean():.6f}, std: {fast_llm_logits.float().std():.6f}") +print(f"Fast-LLM logits [0, 0, :10]: {fast_llm_logits[0, 0, :10].float()}") + +# ============================================================================== +# Part 3: Comparison +# ============================================================================== +print("\n" + "=" * 80) +print("Part 3: Comparison") +print("=" * 80) + +hf_logits_gpu = hf_logits.cuda() +errors = [] +CompareConfig().compare_tensors( + {"samples": hf_logits_gpu, "shape": hf_logits_gpu.shape, "step": 0}, + {"samples": fast_llm_logits, "shape": fast_llm_logits.shape, "step": 0}, + errors, + "HuggingFace vs Fast-LLM (looped)", + "logits", +) + +if errors: + print(f"\n❌ Comparison failed:") + for error in errors: + print(f" {error}") +else: + print(f"\n✅ Forward pass outputs match!") + +print("\n" + "=" * 80) diff --git a/test_sparse_map_debug.py b/test_sparse_map_debug.py new file mode 100644 index 000000000..1456b26ae --- /dev/null +++ b/test_sparse_map_debug.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +""" +Comprehensive test suite for sparse_map_kernel debugging. + +This test compares the Triton kernel output against the PyTorch reference +implementation across various configurations to identify the bug. +""" + +import torch +import sys + +sys.path.insert(0, '/home/ubuntu/Fast-LLM') +from fast_llm.functional.triton.sparse_copy import get_sparse_map, sparse_map_pytorch + + +def test_sparse_map_correctness(num_experts, num_rows_dense, num_experts_per_token, seed=42): + """ + Test that Triton kernel produces same results as PyTorch reference. 
+ + Args: + num_experts: Number of experts + num_rows_dense: Number of tokens (dense rows) + num_experts_per_token: Number of experts selected per token + seed: Random seed for reproducibility + """ + torch.manual_seed(seed) + top_experts = torch.randint(0, num_experts, (num_rows_dense, num_experts_per_token), device='cuda') + + # Get Triton result + sparse_map_triton = get_sparse_map(top_experts, num_experts=num_experts, use_triton=True) + + # Get PyTorch reference result + expert_ends_pt, expert_pad_begins_pt, sparse_rows_pt = sparse_map_pytorch( + top_experts.cpu(), num_experts=num_experts + ) + + # Compare results + expert_ends_match = torch.equal(sparse_map_triton.expert_ends.cpu(), expert_ends_pt) + expert_pad_begins_match = torch.equal(sparse_map_triton.expert_pad_begins.cpu(), expert_pad_begins_pt) + sparse_rows_match = torch.equal(sparse_map_triton.sparse_rows.cpu(), sparse_rows_pt) + + all_match = expert_ends_match and expert_pad_begins_match and sparse_rows_match + + if not all_match: + print(f"\n{'='*80}") + print(f"FAILED: experts={num_experts}, rows={num_rows_dense}, experts_per_token={num_experts_per_token}") + print(f"{'='*80}") + + if not expert_ends_match: + print(f"\n❌ expert_ends mismatch:") + print(f" Triton: {sparse_map_triton.expert_ends}") + print(f" PyTorch: {expert_ends_pt}") + + if not expert_pad_begins_match: + print(f"\n❌ expert_pad_begins mismatch:") + print(f" Triton: {sparse_map_triton.expert_pad_begins}") + print(f" PyTorch: {expert_pad_begins_pt}") + + if not sparse_rows_match: + print(f"\n❌ sparse_rows mismatch:") + print(f" Input top_experts:\n{top_experts}") + print(f"\n Triton sparse_rows:\n{sparse_map_triton.sparse_rows}") + print(f"\n PyTorch sparse_rows:\n{sparse_rows_pt}") + + # Find first mismatch + diff = (sparse_map_triton.sparse_rows.cpu() != sparse_rows_pt).nonzero() + if len(diff) > 0: + first_diff = diff[0] + print(f"\n First mismatch at position {first_diff.tolist()}:") + print(f" Triton: {sparse_map_triton.sparse_rows[first_diff[0], first_diff[1]].item()}") + print(f" PyTorch: {sparse_rows_pt[first_diff[0], first_diff[1]].item()}") + else: + print(f"✅ PASS: experts={num_experts}, rows={num_rows_dense}, experts_per_token={num_experts_per_token}") + + return all_match + + +def test_edge_cases(): + """Test various edge cases""" + print("\n" + "="*80) + print("Testing Edge Cases") + print("="*80) + + results = [] + + # Test 1: Minimal case + results.append(("Minimal (2 experts, 1 token)", test_sparse_map_correctness(2, 1, 1))) + + # Test 2: All tokens select same expert + print("\nTest: All tokens select same expert") + torch.manual_seed(100) + top_experts = torch.zeros((4, 2), dtype=torch.int64, device='cuda') # All select expert 0 + sparse_map_triton = get_sparse_map(top_experts, num_experts=4, use_triton=True) + _, _, sparse_rows_pt = sparse_map_pytorch(top_experts.cpu(), num_experts=4) + match = torch.equal(sparse_map_triton.sparse_rows.cpu(), sparse_rows_pt) + results.append(("All same expert", match)) + if not match: + print(f" Triton: {sparse_map_triton.sparse_rows}") + print(f" PyTorch: {sparse_rows_pt}") + else: + print(" ✅ PASS") + + # Test 3: Sequential experts + print("\nTest: Sequential expert selection") + top_experts = torch.arange(8, device='cuda').view(4, 2) % 4 + sparse_map_triton = get_sparse_map(top_experts, num_experts=4, use_triton=True) + _, _, sparse_rows_pt = sparse_map_pytorch(top_experts.cpu(), num_experts=4) + match = torch.equal(sparse_map_triton.sparse_rows.cpu(), sparse_rows_pt) + results.append(("Sequential 
experts", match)) + if not match: + print(f" Input: {top_experts}") + print(f" Triton: {sparse_map_triton.sparse_rows}") + print(f" PyTorch: {sparse_rows_pt}") + else: + print(" ✅ PASS") + + return results + + +def main(): + print("="*80) + print("SPARSE_MAP_KERNEL COMPREHENSIVE TEST SUITE") + print("="*80) + print(f"Device: CUDA") + print(f"Triton version: {__import__('triton').__version__}") + print(f"PyTorch version: {torch.__version__}") + import platform + print(f"Architecture: {platform.machine()}") + + results = [] + + # Test configurations from the actual failing test + print("\n" + "="*80) + print("Testing Actual Test Configuration") + print("="*80) + results.append(("Actual test config", test_sparse_map_correctness(4, 8, 4))) + + # Test various sizes + print("\n" + "="*80) + print("Testing Various Configurations") + print("="*80) + + test_configs = [ + # Small configs + (2, 4, 2, "Small: 2 experts, 4 tokens, 2 per token"), + (4, 4, 2, "Medium: 4 experts, 4 tokens, 2 per token"), + (4, 8, 2, "Medium: 4 experts, 8 tokens, 2 per token"), + + # Problematic config (experts_per_token=4) + (4, 16, 4, "Large: 4 experts, 16 tokens, 4 per token"), + (8, 8, 4, "Large: 8 experts, 8 tokens, 4 per token"), + + # Test with experts_per_token=1 + (4, 8, 1, "Simple: 4 experts, 8 tokens, 1 per token"), + (8, 16, 1, "Simple: 8 experts, 16 tokens, 1 per token"), + + # Test with experts_per_token=3 + (4, 8, 3, "Medium: 4 experts, 8 tokens, 3 per token"), + (8, 12, 3, "Medium: 8 experts, 12 tokens, 3 per token"), + + # Test different expert counts + (16, 32, 2, "Many experts: 16 experts, 32 tokens, 2 per token"), + (32, 64, 2, "Many experts: 32 experts, 64 tokens, 2 per token"), + + # Test with more tokens + (4, 32, 4, "More tokens: 4 experts, 32 tokens, 4 per token"), + (8, 64, 4, "More tokens: 8 experts, 64 tokens, 4 per token"), + + # Power of 2 variations + (4, 16, 2, "Power of 2: 4 experts, 16 tokens, 2 per token"), + (8, 16, 2, "Power of 2: 8 experts, 16 tokens, 2 per token"), + (16, 16, 2, "Power of 2: 16 experts, 16 tokens, 2 per token"), + + # Non-power of 2 + (5, 10, 2, "Non-pow2: 5 experts, 10 tokens, 2 per token"), + (7, 14, 3, "Non-pow2: 7 experts, 14 tokens, 3 per token"), + (12, 24, 4, "Non-pow2: 12 experts, 24 tokens, 4 per token"), + ] + + for num_experts, num_rows, experts_per_token, desc in test_configs: + results.append((desc, test_sparse_map_correctness(num_experts, num_rows, experts_per_token))) + + # Test edge cases + edge_results = test_edge_cases() + results.extend(edge_results) + + # Summary + print("\n" + "="*80) + print("TEST SUMMARY") + print("="*80) + passed = sum(1 for _, result in results if result) + total = len(results) + print(f"Passed: {passed}/{total}") + + if passed == total: + print("\n🎉 ALL TESTS PASSED!") + else: + print(f"\n❌ {total - passed} TESTS FAILED") + print("\nFailed tests:") + for name, result in results: + if not result: + print(f" - {name}") + + return passed == total + + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1) diff --git a/test_triton_glu.py b/test_triton_glu.py new file mode 100644 index 000000000..d546da5e7 --- /dev/null +++ b/test_triton_glu.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +"""Test that Triton and Torch GPT-OSS GLU implementations match.""" + +import torch + +from fast_llm.functional.config import ActivationType +from fast_llm.functional.triton.mlp import torch_mlp_activation, triton_mlp_activation_forward + +# Set seed +torch.manual_seed(42) + +# Create test input: concatenated [gate, up] 
+batch, seq, dim = 2, 4, 128 +gate = torch.randn(batch, seq, dim, device="cuda") +up = torch.randn(batch, seq, dim, device="cuda") +input_concat = torch.cat([gate, up], dim=-1) # shape: (batch, seq, 2*dim) + +print(f"Input shape: {input_concat.shape}") +print(f"Gate [:5]: {gate[0, 0, :5]}") +print(f"Up [:5]: {up[0, 0, :5]}") + +# Run torch implementation +torch_output = torch_mlp_activation(input_concat, gated=True, activation_type=ActivationType.gpt_oss_glu) + +print(f"\nTorch output shape: {torch_output.shape}") +print(f"Torch output [0,0,:5]: {torch_output[0, 0, :5]}") + +# Run triton implementation +# Make input contiguous for Triton +input_concat_contig = input_concat.contiguous() +triton_output, _ = triton_mlp_activation_forward( + input_concat_contig, gated=True, activation_type=ActivationType.gpt_oss_glu +) + +print(f"\nTriton output shape: {triton_output.shape}") +print(f"Triton output [0,0,:5]: {triton_output[0, 0, :5]}") + +# Compare +print(f"\nOutputs match (atol=1e-5): {torch.allclose(torch_output, triton_output, atol=1e-5)}") +print(f"Max diff: {(torch_output - triton_output).abs().max().item()}") +print(f"RMS diff: {((torch_output - triton_output) ** 2).mean().sqrt().item()}") + +# Also check individual values +print(f"\nDetailed comparison:") +for i in range(min(5, dim)): + print( + f" dim {i}: torch={torch_output[0,0,i]:.6f}, triton={triton_output[0,0,i]:.6f}, diff={abs(torch_output[0,0,i] - triton_output[0,0,i]):.6e}" + ) diff --git a/tests/integration/README.md b/tests/integration/README.md new file mode 100644 index 000000000..72a78fd67 --- /dev/null +++ b/tests/integration/README.md @@ -0,0 +1,172 @@ +# Integration Tests + +These tests verify that real production models from the HuggingFace Hub can be converted to Fast-LLM format and produce equivalent forward pass results. + +## Overview + +The integration tests (`tests/integration/test_hub_integration.py`) perform the following steps: + +1. **Download real models** from HuggingFace Hub +2. **Truncate to first N layers** to reduce memory requirements (default: 2 layers) +3. **Convert to Fast-LLM format** +4. **Verify forward pass equivalence** between HuggingFace and Fast-LLM implementations +5. **Test implementation variants** where applicable (e.g., different kernel paths) + +## Test Flow (with Dependencies) + +Tests are organized with pytest dependencies to ensure proper execution order: + +1. `test_download_and_truncate_{model}` - Downloads and truncates model +2. `test_conversion_{model}` - Converts to Fast-LLM (depends on step 1) +3. `test_forward_equivalence_{model}` - Compares outputs (depends on step 2) +4. `test_{variant}_implementation_{model}` - Tests implementation variants (depends on step 2) + +## Why Skip by Default? 
+
+These tests are marked with `@pytest.mark.extra_slow` and are **skipped by default** because they:
+- Download large models from the Hub (multi-GB downloads)
+- Require significant GPU memory
+- Take considerable time to run
+
+## Running the Tests
+
+### Run all integration tests:
+```bash
+pytest tests/integration --run-extra-slow
+```
+
+### Run a specific test:
+```bash
+pytest tests/integration/test_hub_integration.py::test_download_and_truncate --run-extra-slow
+```
+
+Note that `test_conversion`, `test_forward_equivalence`, and the variant tests depend on the earlier steps, so selecting one of them alone may result in skipped tests.
+
+### Run with a specific model:
+```bash
+pytest tests/integration -k mixtral --run-extra-slow
+```
+
+### Run only implementation variant tests:
+```bash
+pytest tests/integration -k "test_moe_implementation or test_implementation" --run-extra-slow
+```
+
+### Run with verbose output:
+```bash
+pytest tests/integration --run-extra-slow -v -s
+```
+
+## Test Structure
+
+### Test Functions
+
+1. **`test_download_and_truncate`**
+   - Downloads model from HuggingFace Hub
+   - Truncates to first N layers to reduce memory
+   - Verifies config is updated correctly
+
+2. **`test_conversion`**
+   - Converts truncated model to Fast-LLM format
+   - Verifies checkpoint files exist
+
+3. **`test_forward_equivalence`**
+   - Compares forward pass outputs between HF and Fast-LLM
+   - Uses CompareConfig with appropriate thresholds
+   - Scales thresholds per model as needed
+
+4. **`test_moe_implementation` (or other variants)**
+   - Parametrized tests for implementation variants
+   - Verifies all variants produce correct results
+   - Critical for ensuring correctness after code changes
+
+### Fixtures
+
+- **`hub_test_cache_dir`**: Temporary directory with automatic cleanup
+- **`model_name`**: Parametrized fixture for model names
+- **`model_config`**: Configuration for a specific model
+- **`truncated_hf_path`**: Downloads and truncates model from Hub
+- **`fast_llm_path`**: Converts to Fast-LLM with default settings
+- **`fast_llm_path_{variant}`**: Converts with specific variant settings
+
+## Supported Models
+
+Currently supported models in `HUB_TEST_CONFIGS`:
+
+- **Mixtral** (`mistralai/Mixtral-8x7B-v0.1`)
+  - Truncated to 2 layers
+  - Tests conversion and implementation variants
+  - Compare factor: 2.0
+
+## Adding New Models
+
+To add a new model to the integration tests:
+
+1. Add configuration to `HUB_TEST_CONFIGS`:
+
+```python
+HUB_TEST_CONFIGS["model_name"] = {
+    "model_id": "org/model-name",  # HuggingFace Hub ID
+    "checkpoint_format": ModelCheckpointFormat,  # Format class
+    "model_config": GPTModelConfig,  # Model config class
+    "num_layers_to_keep": 2,  # Number of layers after truncation
+    "test_params": {
+        "batch_size": 2,
+        "sequence_length": 32,
+        "compare_factor": 1.0,  # Increase for models with higher numerical error
+    },
+}
+```
+
+2. Add the model name to the `model_name` fixture parameters:
+
+```python
+@pytest.fixture(scope="module", params=["mixtral", "model_name"])
+def model_name(request):
+    return request.param
+```
+
+3. (Optional) Add variant-specific fixtures and tests if the model has multiple implementation paths.
+
+## Requirements
+
+- **GPU Memory**: Tests require sufficient GPU memory (varies by model)
+- **Disk Space**: Models are cached in a temporary directory during tests
+- **Network**: HuggingFace Hub access for model downloads
+
+## Troubleshooting
+
+### Out of Memory (OOM)
+Reduce batch size or sequence length in `test_params`, or use a larger GPU.
+
+### Download Failures
+Check HuggingFace Hub access and network connectivity; gated models may require authentication.
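+
+If the model repository is gated, authenticate before running the tests. A typical setup (the exact environment variable recognized depends on your `huggingface_hub` version, and the token value below is a placeholder):
+
+```bash
+# Interactive login stores the token locally:
+huggingface-cli login
+# Or export it for non-interactive/CI runs:
+export HF_TOKEN=<your-access-token>
+```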
+ +### Comparison Failures +- Check if recent code changes affected model implementations or conversion logic +- Verify compare_factor is appropriate for the model architecture +- Review error messages for specific tensor mismatches +- Compare against known baseline results if available + +## Development Workflow + +After making changes to model code or conversion logic: + +1. Run local unit tests first +2. Run integration tests to verify real models still work: + ```bash + pytest tests/integration -k model_name --run-extra-slow + ``` +3. If tests fail, investigate numerical differences or conversion issues +4. Update compare thresholds only if the differences are acceptable and understood + +## CI/CD Integration + +These tests are **not part of the regular CI pipeline** due to their resource requirements. They should be run: + +- **Manually** before major releases +- **After significant changes** to model implementations or conversion code +- **Periodically** to ensure compatibility with upstream models + +To run in CI (if infrastructure supports it): +```bash +pytest tests/integration --run-extra-slow --tb=short +``` diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_hub_integration.py b/tests/integration/test_hub_integration.py new file mode 100644 index 000000000..17e04b273 --- /dev/null +++ b/tests/integration/test_hub_integration.py @@ -0,0 +1,421 @@ +""" +Integration tests for HuggingFace Hub model conversion and forward pass equivalence. + +These tests download real production models from HuggingFace Hub, truncate them to a small +number of layers to reduce memory requirements, convert them to Fast-LLM format, and verify +that the forward passes produce equivalent results to the original HuggingFace implementation. + +Test flow (with pytest dependencies): +1. test_download_and_truncate_{model} - Downloads and truncates model from Hub +2. test_conversion_{model} - Converts to Fast-LLM format (depends on step 1) +3. test_forward_equivalence_{model} - Compares HF vs Fast-LLM outputs (depends on step 2) +4. test_{variant}_implementation_{model} - Tests implementation variants (depends on step 2) + +These tests are marked as @pytest.mark.extra_slow and are skipped by default. 
+Run with: pytest tests/integration --run-extra-slow +""" + +import logging +import pathlib +import shutil + +import pytest +import torch +import transformers +from huggingface_hub import snapshot_download + +from fast_llm.engine.checkpoint.config import ( + CheckpointLoadConfig, + CheckpointSaveConfig, + FastLLMCheckpointFormat, + ModelConfigType, +) +from fast_llm.engine.checkpoint.convert import ConvertConfig +from fast_llm.engine.config_utils.logging import TensorLogs, TensorLogsConfig +from fast_llm.logging import set_model_debug_level +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from fast_llm.models.gpt.model import GPTModel +from tests.utils.compare_tensor_logs import CompareConfig +from tests.utils.utils import requires_cuda + +logger = logging.getLogger(__name__) + + +# Model configurations for hub integration tests +HUB_TEST_CONFIGS = { + "mixtral": { + "model_id": "mistralai/Mixtral-8x7B-v0.1", + "checkpoint_format": MixtralCheckpointFormat, + "model_config": GPTModelConfig, + "num_layers_to_keep": 2, # Truncate to 2 layers to reduce memory + "test_params": { + "batch_size": 2, + "sequence_length": 32, + "compare_factor": 2.0, # MoE models have higher numerical error + }, + }, +} + + +@pytest.fixture(scope="module", autouse=True) +def reset_gpu_memory_limit(): + """Reset GPU memory limit for integration tests (they need the full model).""" + if torch.cuda.is_available(): + # Reset to allow full GPU memory (tests/conftest.py limits to 5GB by default) + torch.cuda.set_per_process_memory_fraction(1.0, 0) + yield + + +@pytest.fixture(scope="module") +def hub_test_cache_dir(tmp_path_factory): + """Create a cache directory for hub integration tests.""" + cache_dir = tmp_path_factory.mktemp("hub_integration_cache") + yield cache_dir + # Cleanup after all tests complete + if cache_dir.exists(): + logger.info(f"Cleaning up cache directory: {cache_dir}") + shutil.rmtree(cache_dir, ignore_errors=True) + + +@pytest.fixture(scope="module", params=["mixtral"]) +def model_name(request): + """Parametrized fixture for model names.""" + return request.param + + +@pytest.fixture(scope="module") +def model_config(model_name): + """Get configuration for a specific model.""" + if model_name not in HUB_TEST_CONFIGS: + pytest.skip(f"Unknown model: {model_name}") + return HUB_TEST_CONFIGS[model_name] + + +@pytest.fixture(scope="module") +def truncated_hf_path(hub_test_cache_dir, model_name, model_config): + """ + Download model from HF Hub and truncate to first N layers to reduce memory. + + Steps: + 1. Download from HuggingFace Hub + 2. Load model (with any necessary dequantization) + 3. Truncate to num_layers_to_keep + 4. Update config (including model-specific fields) + 5. 
Save truncated model + """ + model_id = model_config["model_id"] + num_layers = model_config["num_layers_to_keep"] + truncated_path = hub_test_cache_dir / f"{model_name}_truncated" + + if truncated_path.exists(): + logger.info(f"Truncated model already exists at {truncated_path}") + return truncated_path + + logger.info(f"Downloading and truncating {model_id} to {num_layers} layers...") + + # Download from HF Hub + logger.info(f" Downloading from Hub: {model_id}") + hf_local_path = snapshot_download(repo_id=model_id, local_dir_use_symlinks=False) + hf_local_path = pathlib.Path(hf_local_path) + + # Load model on CPU to avoid OOM when loading full model + logger.info(f" Loading model on CPU...") + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + hf_local_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + device_map="cpu", + ) + + # Truncate to first N layers + logger.info(f" Truncating to {num_layers} layers...") + original_num_layers = len(hf_model.model.layers) + logger.info(f" Original layers: {original_num_layers}, keeping: {num_layers}") + hf_model.model.layers = hf_model.model.layers[:num_layers] + hf_model.config.num_hidden_layers = num_layers + + # Handle model-specific config updates (e.g., layer_types for GPT-OSS) + if hasattr(hf_model.config, "layer_types"): + hf_model.config.layer_types = hf_model.config.layer_types[:num_layers] + logger.info(f" Updated layer_types: {hf_model.config.layer_types}") + + # Save truncated model + logger.info(f" Saving truncated model to {truncated_path}") + hf_model.save_pretrained(truncated_path) + + # Also save tokenizer if available + try: + tokenizer = transformers.AutoTokenizer.from_pretrained(hf_local_path, trust_remote_code=True) + tokenizer.save_pretrained(truncated_path) + except Exception as e: + logger.warning(f" Failed to save tokenizer: {e}") + + logger.info(f"✓ Truncated model saved to {truncated_path}") + logger.info(f" Vocab size: {hf_model.config.vocab_size}") + logger.info(f" Hidden size: {hf_model.config.hidden_size}") + logger.info(f" Num layers: {hf_model.config.num_hidden_layers}") + + # Free CPU memory + del hf_model + + return truncated_path + + +@pytest.fixture(scope="module") +def fast_llm_path(hub_test_cache_dir, model_name, model_config, truncated_hf_path): + """Convert truncated HF model to Fast-LLM format (default MoE settings).""" + fast_llm_path = hub_test_cache_dir / f"{model_name}_fast_llm" + + if fast_llm_path.exists(): + logger.info(f"Fast-LLM checkpoint already exists at {fast_llm_path}") + return fast_llm_path + + logger.info(f"Converting {model_name} to Fast-LLM format (on CPU)...") + + ConvertConfig( + input=CheckpointLoadConfig( + path=truncated_hf_path, + format=model_config["checkpoint_format"], + load_config=ModelConfigType.model, + ), + output=CheckpointSaveConfig( + path=fast_llm_path, + format=FastLLMCheckpointFormat, + ), + model=model_config["model_config"], + use_cpu=True, # Convert on CPU to avoid OOM + ).run() + + logger.info(f"✓ Converted to {fast_llm_path}") + return fast_llm_path + + + + +# ============================================================================ +# Test 1: Download and Truncate +# ============================================================================ + + +@requires_cuda +@pytest.mark.extra_slow +def test_download_and_truncate(model_name, model_config, truncated_hf_path): + """Test that model can be downloaded and truncated.""" + assert truncated_hf_path.exists(), f"Truncated model not found at {truncated_hf_path}" + assert (truncated_hf_path 
/ "config.json").exists(), "config.json not found" + + # Verify the truncation worked + config = transformers.AutoConfig.from_pretrained(truncated_hf_path, trust_remote_code=True) + expected_layers = model_config["num_layers_to_keep"] + assert config.num_hidden_layers == expected_layers, ( + f"Expected {expected_layers} layers, got {config.num_hidden_layers}" + ) + logger.info(f"✓ Model truncated to {config.num_hidden_layers} layers") + + +# ============================================================================ +# Test 2: Conversion +# ============================================================================ + + +@requires_cuda +@pytest.mark.extra_slow +@pytest.mark.depends_on(on=["test_download_and_truncate[{model_name}]"]) +def test_conversion(model_name, fast_llm_path): + """Test that truncated model can be converted to Fast-LLM format.""" + assert fast_llm_path.exists(), f"Fast-LLM checkpoint not found at {fast_llm_path}" + assert (fast_llm_path / "metadata.yaml").exists(), "metadata.yaml not found" + logger.info(f"✓ Conversion successful: {fast_llm_path}") + + +# ============================================================================ +# Test 3: Forward Pass Equivalence +# ============================================================================ + + +@requires_cuda +@pytest.mark.extra_slow +@pytest.mark.depends_on(on=["test_conversion[{model_name}]"]) +def test_forward_equivalence(model_name, model_config, truncated_hf_path, fast_llm_path): + """Test that HuggingFace and Fast-LLM produce equivalent forward pass results.""" + test_params = model_config["test_params"] + batch_size = test_params["batch_size"] + sequence_length = test_params["sequence_length"] + compare_factor = test_params.get("compare_factor", 1.0) + + # Load HF config to get vocab size + hf_config = transformers.AutoConfig.from_pretrained(truncated_hf_path, trust_remote_code=True) + vocab_size = hf_config.vocab_size + + # Create test input + torch.manual_seed(42) + test_input = torch.randint( + 0, + vocab_size, + size=(batch_size, sequence_length), + dtype=torch.int64, + device="cuda", + ) + + # Run HuggingFace model + logger.info("Loading HuggingFace model...") + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + truncated_hf_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ).cuda() + + with torch.no_grad(): + hf_output = hf_model(test_input) + + hf_logits = hf_output.logits.clone().cpu() + + # Cleanup HF model + del hf_model, hf_output + torch.cuda.empty_cache() + + # Run Fast-LLM model + logger.info("Loading Fast-LLM model...") + TensorLogs.reset(TensorLogsConfig(save=False, show=False)) + set_model_debug_level(0) + + gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=fast_llm_path, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) + ) + fast_llm_model = HuggingfaceGPTModelForCausalLM(gpt_model) + + with torch.no_grad(): + fast_llm_output = fast_llm_model(test_input) + + fast_llm_logits = fast_llm_output.logits.clone() + + # Compare outputs + logger.info("Comparing outputs...") + hf_logits = hf_logits.cuda() + + errors = [] + compare_config = CompareConfig() + if compare_factor != 1.0: + # Scale thresholds for models with higher numerical error (e.g., MoE) + compare_config = CompareConfig( + max_rms_diff_abs=compare_config.max_rms_diff_abs * compare_factor, + max_rms_diff_scaled=compare_config.max_rms_diff_scaled * compare_factor, + max_max_diff_abs=compare_config.max_max_diff_abs * compare_factor, + 
max_max_diff_scaled=compare_config.max_max_diff_scaled * compare_factor, + ) + + compare_config.compare_tensors( + {"samples": hf_logits, "shape": hf_logits.shape, "step": 0}, + {"samples": fast_llm_logits, "shape": fast_llm_logits.shape, "step": 0}, + errors, + f"{model_name}_HF_vs_FastLLM", + "logits", + ) + + if errors: + for error in errors: + logger.error(error) + pytest.fail(f"Forward pass comparison failed with {len(errors)} errors") + + logger.info(f"✓ Forward pass equivalence test passed for {model_name}") + + +# ============================================================================ +# Test 4: MoE Implementation Variants (Dropless vs Looped) +# ============================================================================ + + +@requires_cuda +@pytest.mark.extra_slow +@pytest.mark.depends_on(on=["test_conversion[{model_name}]"]) +def test_moe_implementation(model_name, model_config, fast_llm_path): + """Test that dropless and looped MoE implementations produce equivalent results.""" + # Only run for MoE models + if model_name not in ["mixtral"]: + pytest.skip(f"MoE implementation test not applicable for {model_name}") + + test_params = model_config["test_params"] + batch_size = test_params["batch_size"] + sequence_length = test_params["sequence_length"] + compare_factor = test_params.get("compare_factor", 1.0) + + # Load config to get vocab size + import yaml + with open(fast_llm_path / "metadata.yaml") as f: + metadata = yaml.safe_load(f) + vocab_size = metadata["config"]["base_model"]["embeddings"]["vocab_size"] + + # Create test input + torch.manual_seed(42) + test_input = torch.randint( + 0, + vocab_size, + size=(batch_size, sequence_length), + dtype=torch.int64, + device="cuda", + ) + + # Test both implementations + outputs = {} + for variant_name, dropless_value in [("dropless", True), ("looped", False)]: + logger.info(f"Testing {variant_name} MoE implementation (dropless={dropless_value})...") + TensorLogs.reset(TensorLogsConfig(save=False, show=False)) + set_model_debug_level(0) + + # Load model with config override + gpt_model = GPTModel.from_pretrained( + CheckpointLoadConfig( + path=fast_llm_path, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ), + {("base_model", "decoder", "block", "mlp", "dropless"): dropless_value}, + ) + fast_llm_model = HuggingfaceGPTModelForCausalLM(gpt_model) + + with torch.no_grad(): + output = fast_llm_model(test_input) + + outputs[variant_name] = output.logits.clone() + + # Cleanup + del gpt_model, fast_llm_model, output + torch.cuda.empty_cache() + + logger.info(f"✓ {variant_name} implementation forward pass complete") + + # Compare dropless vs looped implementations + logger.info("Comparing dropless vs looped implementations...") + errors = [] + compare_config = CompareConfig() + if compare_factor != 1.0: + # Scale thresholds for models with higher numerical error + compare_config = CompareConfig( + max_rms_diff_abs=compare_config.max_rms_diff_abs * compare_factor, + max_rms_diff_scaled=compare_config.max_rms_diff_scaled * compare_factor, + max_max_diff_abs=compare_config.max_max_diff_abs * compare_factor, + max_max_diff_scaled=compare_config.max_max_diff_scaled * compare_factor, + ) + + compare_config.compare_tensors( + {"samples": outputs["dropless"], "shape": outputs["dropless"].shape, "step": 0}, + {"samples": outputs["looped"], "shape": outputs["looped"].shape, "step": 0}, + errors, + f"{model_name}_dropless_vs_looped", + "logits", + ) + + if errors: + for error in errors: + logger.error(error) + 
pytest.fail(f"MoE implementation comparison failed with {len(errors)} errors") + + logger.info(f"✓ MoE implementation variant test passed for {model_name}") + + diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c02521d7b..3f358c4f7 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -15,6 +15,7 @@ AprielHybridSSMCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, + GptOssCheckpointFormat, LlamaCheckpointFormat, MistralCheckpointFormat, MixtralCheckpointFormat, @@ -544,6 +545,8 @@ def _update_and_add_testing_config( updates={ ("model", "base_model", "decoder", "block", "mlp", "type"): "moe", ("model", "base_model", "decoder", "block", "mlp", "router", "weight"): init_1, + ("model", "base_model", "decoder", "block", "mlp", "layer_1", "type"): "moe_affine_linear", + ("model", "base_model", "decoder", "block", "mlp", "layer_2", "type"): "moe_affine_linear", ("model", "base_model", "decoder", "block", "mlp", "experts"): 4, ("model", "base_model", "decoder", "block", "mlp", "experts_per_token"): 4, }, @@ -694,6 +697,78 @@ def _update_and_add_testing_config( ) +_update_and_add_testing_config( + # Tests GPT-OSS: heterogeneous blocks (alternating sliding/full attention), MoE, YARN RoPE, attention biases. + "llama", + "gpt_oss", + updates={ + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "sliding": { + **copy.deepcopy(_llama_block), + "mixer": { + **copy.deepcopy(_llama_block["mixer"]), + "add_linear_biases": True, + "window_size": 128, + "rotary": {"type": "yarn"}, + "sinks": {"enabled": True, **init_1}, + }, + "mlp": { + "type": "moe", + "router": {"type": "affine_linear", "weight": init_1, "bias": {"enabled": True}}, + "layer_1": {"type": "moe_affine_linear", "weight": init_1, "bias": {"enabled": True}}, + "layer_2": {"type": "moe_affine_linear", "weight": init_2, "bias": {"enabled": True}}, + "experts": 4, + "experts_per_token": 4, + "intermediate_size": 1024, + "gated": True, + "activation": "silu", + "add_linear_biases": True, + }, + }, + "full": { + **copy.deepcopy(_llama_block), + "mixer": { + **copy.deepcopy(_llama_block["mixer"]), + "add_linear_biases": True, + "rotary": {"type": "yarn"}, + "sinks": {"enabled": True, **init_1}, + }, + "mlp": { + "type": "moe", + "router": {"type": "affine_linear", "weight": init_1, "bias": {"enabled": True}}, + "layer_1": {"type": "moe_affine_linear", "weight": init_1, "bias": {"enabled": True}}, + "layer_2": {"type": "moe_affine_linear", "weight": init_2, "bias": {"enabled": True}}, + "experts": 4, + "experts_per_token": 4, + "intermediate_size": 1024, + "gated": True, + "activation": "silu", + "add_linear_biases": True, + }, + }, + }, + "num_blocks": 4, + "pattern": ["sliding", "full"], + }, + }, + megatron_args=None, + checkpoint_format=GptOssCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=2.0, + # Micro-sequence split not supported (due to MoE). 
+ skip_tests=("ms",), +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models") diff --git a/trace_mlp_detailed.py b/trace_mlp_detailed.py new file mode 100644 index 000000000..7fa1533c6 --- /dev/null +++ b/trace_mlp_detailed.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +""" +Add hooks to both HF and Fast-LLM MLP to trace intermediate values. +""" + +import pathlib + +import torch +import transformers + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat, ModelConfigType + +# Monkey-patch the mlp_autograd_looped to add tracing +from fast_llm.functional.triton import mlp as mlp_module +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM + +CHECKPOINT_DIR = pathlib.Path("/home/ubuntu/Fast-LLM/test_gpt_oss_checkpoint") +DEQUANTIZED_HF_PATH = CHECKPOINT_DIR / "dequantized_hf" +FAST_LLM_PATH = CHECKPOINT_DIR / "fast_llm" + +print("=" * 80) +print("Tracing MLP Components with Hooks") +print("=" * 80) + +# Create test input +torch.manual_seed(42) +test_input = torch.randint(0, 201088, size=(1, 4), dtype=torch.int64, device="cuda") # Smaller for detailed tracing +print(f"\nTest input shape: {test_input.shape}") +print(f"Test input: {test_input}") + +# ============================================================================== +# Part 1: Trace HuggingFace Model +# ============================================================================== +print("\n" + "=" * 80) +print("Part 1: HuggingFace Model") +print("=" * 80) + +hf_model = ( + transformers.AutoModelForCausalLM.from_pretrained( + DEQUANTIZED_HF_PATH, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + .cuda() + .eval() +) + +hf_traces = {} + + +# Hook the MLP experts to trace gate_up and activation +def make_hf_experts_hook(): + def hook(module, input, output): + # Save the input to experts + hf_traces["experts_input"] = input[0].detach().float() + hf_traces["experts_output"] = output.detach().float() + + return hook + + +hf_model.model.layers[0].mlp.experts.register_forward_hook(make_hf_experts_hook()) + +print("\nRunning HF model...") +with torch.no_grad(): + hf_output = hf_model(test_input) + +print( + f"HF experts input: shape={hf_traces['experts_input'].shape}, mean={hf_traces['experts_input'].mean():.6f}, std={hf_traces['experts_input'].std():.6f}" +) +print( + f"HF experts output: shape={hf_traces['experts_output'].shape}, mean={hf_traces['experts_output'].mean():.6f}, std={hf_traces['experts_output'].std():.6f}" +) +print( + f"HF final logits: shape={hf_output.logits.shape}, mean={hf_output.logits.mean():.6f}, std={hf_output.logits.std():.6f}" +) + +# Save for comparison +hf_logits = hf_output.logits.clone().cpu() + +del hf_model +torch.cuda.empty_cache() + +# ============================================================================== +# Part 2: Trace Fast-LLM Model +# ============================================================================== +print("\n" + "=" * 80) +print("Part 2: Fast-LLM Model") +print("=" * 80) + + +original_mlp_autograd_looped = mlp_module.mlp_autograd_looped +fl_traces = {} + + +def traced_mlp_autograd_looped( + hidden_states, + scores, + top_experts, + weight_1, + weight_2, + num_experts, + gated, + activation_type, + group, + sequence_parallel, + training, + recompute_level, + bias_1=None, + bias_2=None, +): + # Save inputs + fl_traces["mlp_input"] = hidden_states.detach().clone().cpu() + fl_traces["scores"] = 
scores.detach().clone().cpu() + fl_traces["top_experts"] = top_experts.detach().clone().cpu() + + # Call original + result = original_mlp_autograd_looped( + hidden_states, + scores, + top_experts, + weight_1, + weight_2, + num_experts, + gated, + activation_type, + group, + sequence_parallel, + training, + recompute_level, + bias_1, + bias_2, + ) + + # Save output + fl_traces["mlp_output"] = result.detach().clone().cpu() + + return result + + +mlp_module.mlp_autograd_looped = traced_mlp_autograd_looped + +fast_llm_model = HuggingfaceGPTModelForCausalLM.from_pretrained( + CheckpointLoadConfig( + path=FAST_LLM_PATH, + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) +) + +print("\nRunning Fast-LLM model...") +with torch.no_grad(): + fl_output = fast_llm_model(test_input) + +print( + f"FL MLP input: shape={fl_traces['mlp_input'].shape}, mean={fl_traces['mlp_input'].mean():.6f}, std={fl_traces['mlp_input'].std():.6f}" +) +print( + f"FL scores: shape={fl_traces['scores'].shape}, mean={fl_traces['scores'].mean():.6f}, std={fl_traces['scores'].std():.6f}" +) +print(f"FL top_experts: shape={fl_traces['top_experts'].shape}") +print(f"FL top_experts: {fl_traces['top_experts']}") +print( + f"FL MLP output: shape={fl_traces['mlp_output'].shape}, mean={fl_traces['mlp_output'].mean():.6f}, std={fl_traces['mlp_output'].std():.6f}" +) +print( + f"FL final logits: shape={fl_output.logits.shape}, mean={fl_output.logits.mean():.6f}, std={fl_output.logits.std():.6f}" +) + +# Compare +print("\n" + "=" * 80) +print("Comparison") +print("=" * 80) + +print(f"\nMLP input mean: HF={hf_traces['experts_input'].mean():.6f}, FL={fl_traces['mlp_input'].mean():.6f}") +print(f"MLP input std: HF={hf_traces['experts_input'].std():.6f}, FL={fl_traces['mlp_input'].std():.6f}") +print(f"MLP output mean: HF={hf_traces['experts_output'].mean():.6f}, FL={fl_traces['mlp_output'].mean():.6f}") +print(f"MLP output std: HF={hf_traces['experts_output'].std():.6f}, FL={fl_traces['mlp_output'].std():.6f}") + +fl_logits = fl_output.logits.cpu() +hf_logits = hf_logits.cuda() +fl_logits_gpu = fl_output.logits + +print(f"\nFinal logits mean: HF={hf_logits.float().mean():.6f}, FL={fl_logits.mean():.6f}") +print(f"Final logits std: HF={hf_logits.float().std():.6f}, FL={fl_logits.std():.6f}") +print(f"Logits max diff: {(hf_logits.float() - fl_logits_gpu.float()).abs().max():.6f}") +print(f"Logits RMS diff: {((hf_logits.float() - fl_logits_gpu.float()) ** 2).mean().sqrt():.6f}")
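+
+# Rough pass/fail summary mirroring check_expert_weights.py; atol=0.01 is an arbitrary
+# bf16-scale threshold chosen for this trace script, not a project-wide tolerance.
+print(f"Logits allclose (atol=0.01): {torch.allclose(hf_logits.float(), fl_logits_gpu.float(), atol=0.01)}")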