diff --git a/sharktank/sharktank/models/gpt_oss/orig_pytorch_model.py b/sharktank/sharktank/models/gpt_oss/orig_pytorch_model.py new file mode 100644 index 00000000000..dbbb6ffcbfc --- /dev/null +++ b/sharktank/sharktank/models/gpt_oss/orig_pytorch_model.py @@ -0,0 +1,386 @@ +import json +import math +import os +from dataclasses import dataclass + +import torch +import torch.distributed as dist + + +@dataclass +class ModelConfig: + num_hidden_layers: int = 36 + num_experts: int = 128 + experts_per_token: int = 4 + vocab_size: int = 201088 + hidden_size: int = 2880 + intermediate_size: int = 2880 + swiglu_limit: float = 7.0 + head_dim: int = 64 + num_attention_heads: int = 64 + num_key_value_heads: int = 8 + sliding_window: int = 128 + initial_context_length: int = 4096 + rope_theta: float = 150000.0 + rope_scaling_factor: float = 32.0 + rope_ntk_alpha: float = 1.0 + rope_ntk_beta: float = 32.0 + + +class RMSNorm(torch.nn.Module): + def __init__( + self, num_features: int, eps: float = 1e-05, device: torch.device | None = None + ): + super().__init__() + self.num_features = num_features + self.eps = eps + self.scale = torch.nn.Parameter( + torch.ones(num_features, device=device, dtype=torch.float32) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert x.shape[-1] == self.num_features + t, dtype = x.float(), x.dtype + t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps) + return (t * self.scale).to(dtype) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + x1, x2 = torch.chunk(x, 2, dim=-1) + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + return torch.cat((o1, o2), dim=-1) + + +class RotaryEmbedding(torch.nn.Module): + def __init__( + self, + head_dim: int, + base: int, + dtype: torch.dtype, + initial_context_length: int = 4096, + scaling_factor: float = 1.0, + ntk_alpha: float = 1.0, + ntk_beta: float = 32.0, + device: torch.device | None = None, + ) -> None: + super().__init__() + self.head_dim = head_dim + self.base = base + self.dtype = dtype + self.initial_context_length = initial_context_length + self.scaling_factor = scaling_factor + self.ntk_alpha = ntk_alpha + self.ntk_beta = ntk_beta + self.device = device + + def _compute_concentration_and_inv_freq(self) -> torch.Tensor: + """See YaRN paper: https://arxiv.org/abs/2309.00071""" + freq = self.base ** ( + torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device) + / self.head_dim + ) + if self.scaling_factor > 1.0: + concentration = ( + 0.1 * math.log(self.scaling_factor) + 1.0 + ) # YaRN concentration + + d_half = self.head_dim / 2 + # NTK by parts + low = ( + d_half + * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi)) + / math.log(self.base) + ) + high = ( + d_half + * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi)) + / math.log(self.base) + ) + assert 0 < low < high < d_half - 1 + + interpolation = 1.0 / (self.scaling_factor * freq) + extrapolation = 1.0 / freq + + ramp = ( + torch.arange(d_half, dtype=torch.float32, device=freq.device) - low + ) / (high - low) + mask = 1 - ramp.clamp(0, 1) + + inv_freq = interpolation * (1 - mask) + extrapolation * mask + else: + concentration = 1.0 + inv_freq = 1.0 / freq + + return concentration, inv_freq + + def _compute_cos_sin(self, num_tokens: int): + concentration, inv_freq = self._compute_concentration_and_inv_freq() + t = torch.arange(num_tokens, 
dtype=torch.float32, device=self.device) + freqs = torch.einsum("i,j->ij", t, inv_freq) + cos = freqs.cos() * concentration + sin = freqs.sin() * concentration + return cos, sin + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens = query.shape[0] + cos, sin = self._compute_cos_sin(num_tokens) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_dim) + query = _apply_rotary_emb(query, cos, sin) + query = query.reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_dim) + key = _apply_rotary_emb(key, cos, sin) + key = key.reshape(key_shape) + return query, key + + +def sdpa(Q, K, V, S, sm_scale, sliding_window=0): + # sliding_window == 0 means no sliding window + n_tokens, n_heads, q_mult, d_head = Q.shape + assert K.shape == (n_tokens, n_heads, d_head) + assert V.shape == (n_tokens, n_heads, d_head) + K = K[:, :, None, :].expand(-1, -1, q_mult, -1) + V = V[:, :, None, :].expand(-1, -1, q_mult, -1) + S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1) + mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1) + if sliding_window > 0: + mask += torch.tril( + mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window + ) + QK = torch.einsum("qhmd,khmd->hmqk", Q, K) + QK *= sm_scale + QK += mask[None, None, :, :] + QK = torch.cat([QK, S], dim=-1) + W = torch.softmax(QK, dim=-1) + W = W[..., :-1] + attn = torch.einsum("hmqk,khmd->qhmd", W, V) + return attn.reshape(n_tokens, -1) + + +class AttentionBlock(torch.nn.Module): + def __init__( + self, + config: ModelConfig, + layer_idx: int = 0, + device: torch.device | None = None, + ): + super().__init__() + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + # Only apply sliding window to every other layer + self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0 + self.sinks = torch.nn.Parameter( + torch.empty(config.num_attention_heads, device=device, dtype=torch.bfloat16) + ) + self.norm = RMSNorm(config.hidden_size, device=device) + qkv_dim = config.head_dim * ( + config.num_attention_heads + 2 * config.num_key_value_heads + ) + self.qkv = torch.nn.Linear( + config.hidden_size, qkv_dim, device=device, dtype=torch.bfloat16 + ) + self.out = torch.nn.Linear( + config.head_dim * config.num_attention_heads, + config.hidden_size, + device=device, + dtype=torch.bfloat16, + ) + self.sm_scale = 1 / math.sqrt(config.head_dim) + self.rope = RotaryEmbedding( + config.head_dim, + config.rope_theta, + torch.float32, + initial_context_length=config.initial_context_length, + scaling_factor=config.rope_scaling_factor, + ntk_alpha=config.rope_ntk_alpha, + ntk_beta=config.rope_ntk_beta, + device=device, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = self.norm(x) + qkv = self.qkv(t) + q = qkv[:, : self.num_attention_heads * self.head_dim].contiguous() + k = qkv[ + :, + self.num_attention_heads + * self.head_dim : (self.num_attention_heads + self.num_key_value_heads) + * self.head_dim, + ].contiguous() + v = qkv[ + :, + (self.num_attention_heads + self.num_key_value_heads) + * self.head_dim : (self.num_attention_heads + 2 * self.num_key_value_heads) + * self.head_dim, + ].contiguous() + + q = q.view( + -1, + self.num_key_value_heads, + self.num_attention_heads // self.num_key_value_heads, + self.head_dim, + ) + k = k.view(-1, self.num_key_value_heads, 
self.head_dim) + v = v.view(-1, self.num_key_value_heads, self.head_dim) + q, k = self.rope(q, k) + t = sdpa(q, k, v, self.sinks, self.sm_scale, self.sliding_window) + t = self.out(t) + t = x + t + return t + + +def swiglu(x, alpha: float = 1.702, limit: float = 7.0): + x_glu, x_linear = x[..., ::2], x[..., 1::2] + # Clamp the input values + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + # Note we add an extra bias of 1 to the linear layer + return out_glu * (x_linear + 1) + + +class MLPBlock(torch.nn.Module): + def __init__( + self, + config: ModelConfig, + device: torch.device | None = None, + ): + super().__init__() + self.num_experts = config.num_experts + self.experts_per_token = config.experts_per_token + self.swiglu_limit = config.swiglu_limit + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.norm = RMSNorm(config.hidden_size, device=device) + self.gate = torch.nn.Linear( + config.hidden_size, config.num_experts, device=device, dtype=torch.bfloat16 + ) + assert config.intermediate_size % self.world_size == 0 + self.mlp1_weight = torch.nn.Parameter( + torch.empty( + ( + config.num_experts, + config.intermediate_size * 2 // self.world_size, + config.hidden_size, + ), + device=device, + dtype=torch.bfloat16, + ) + ) + self.mlp1_bias = torch.nn.Parameter( + torch.empty( + (config.num_experts, config.intermediate_size * 2 // self.world_size), + device=device, + dtype=torch.bfloat16, + ) + ) + self.mlp2_weight = torch.nn.Parameter( + torch.empty( + ( + config.num_experts, + config.hidden_size, + config.intermediate_size // self.world_size, + ), + device=device, + dtype=torch.bfloat16, + ) + ) + self.mlp2_bias = torch.nn.Parameter( + torch.empty( + (config.num_experts, config.hidden_size), + device=device, + dtype=torch.bfloat16, + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = self.norm(x) + g = self.gate(t) + experts = torch.topk(g, k=self.experts_per_token, dim=-1, sorted=True) + expert_weights = torch.nn.functional.softmax(experts.values, dim=1) + expert_indices = experts.indices + + # MLP #1 + mlp1_weight = self.mlp1_weight[expert_indices, ...] + mlp1_bias = self.mlp1_bias[expert_indices, ...] + t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias + t = swiglu(t, limit=self.swiglu_limit) + + # MLP #2 + mlp2_weight = self.mlp2_weight[expert_indices, ...] + mlp2_bias = self.mlp2_bias[expert_indices, ...] 
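+        # Down-projection per selected expert: contract the expert-local intermediate
+        # dim (k) of t against mlp2_weight, yielding (tokens, experts_per_token, hidden).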
+ t = torch.einsum("beck,bek->bec", mlp2_weight, t) + if self.world_size > 1: + dist.all_reduce(t, op=dist.ReduceOp.SUM) + t += mlp2_bias + + # Weighted sum of experts + t = torch.einsum("bec,be->bc", t, expert_weights) + + return x + t + + +class TransformerBlock(torch.nn.Module): + def __init__( + self, + config: ModelConfig, + layer_idx: int, + device: torch.device | None = None, + ): + super().__init__() + self.layer_idx = layer_idx + self.attn = AttentionBlock(config, layer_idx, device) + self.mlp = MLPBlock(config, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.attn(x) + x = self.mlp(x) + return x + + +class Transformer(torch.nn.Module): + def __init__( + self, + config: ModelConfig, + device: torch.device | None = None, + ): + super().__init__() + self.embedding = torch.nn.Embedding( + config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16 + ) + self.block = torch.nn.ModuleList( + [ + TransformerBlock(config, layer_idx, device) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, device=device) + self.unembedding = torch.nn.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + device=device, + dtype=torch.bfloat16, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.embedding(x) + for block in self.block: + x = block(x) + x = self.norm(x) + x = self.unembedding(x) + return x diff --git a/sharktank/sharktank/models/gpt_oss/testing.py b/sharktank/sharktank/models/gpt_oss/testing.py new file mode 100644 index 00000000000..68d29dc6118 --- /dev/null +++ b/sharktank/sharktank/models/gpt_oss/testing.py @@ -0,0 +1,230 @@ +from typing import Optional, Callable +import torch + + +from sharktank.types.tensors import DefaultPrimitiveTensor +from sharktank.types.theta import Theta +from sharktank.layers.configs import LlamaModelConfig +from sharktank.utils.random import make_rand_torch +from sharktank.layers.testing import make_random_moe_block_theta +from sharktank.models.llama.testing import ( + make_wide_range_weights, + make_simple_calculable_weight_torch, +) + + +def make_gpt_oss_attention_block_theta( + *, + block_idx: int, + head_count: int, + head_count_kv: int, + head_dim: int, + embedding_length: int, + dtype: torch.dtype, + dtype_norm: torch.dtype, + weight_generator: Callable[ + [list[int], torch.dtype], torch.Tensor + ] = make_wide_range_weights, +) -> Theta: + """Create theta for GPT-OSS attention block with fused QKV weights.""" + + q_size = head_count * head_dim + k_size = head_count_kv * head_dim + v_size = head_count_kv * head_dim + qkv_size = q_size + k_size + v_size + + return Theta( + { + "attn_norm.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.attn_norm.weight", + data=weight_generator((embedding_length,), dtype_norm), + ), + "attn.wqkv.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.attn.wqkv.weight", + data=weight_generator((qkv_size, embedding_length), dtype), + ), + "attn.wqkv.bias": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.attn.wqkv.bias", + data=weight_generator((qkv_size,), dtype), + ), + "attn_output.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.attn_output.weight", + data=weight_generator((embedding_length, head_count * head_dim), dtype), + ), + "attn_output.bias": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.attn_output.bias", + data=weight_generator((embedding_length,), dtype), + ), + "attn_sinks": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.attn_sinks", + 
data=weight_generator((head_count,), dtype), + ), + } + ) + + +def make_gpt_oss_moe_block_theta( + *, + block_idx: int, + embedding_length: int, + expert_feed_forward_length: int, + expert_count: int, + dtype: torch.dtype, + dtype_norm: torch.dtype, + weight_generator: Callable[ + [list[int], torch.dtype], torch.Tensor + ] = make_wide_range_weights, +) -> Theta: + """Create theta for GPT-OSS MoE block.""" + + return Theta( + { + "ffn_gate_inp.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_gate_inp.weight", + data=weight_generator((expert_count, embedding_length), dtype), + ), + "ffn_gate_inp.bias": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_gate_inp.bias", + data=weight_generator((expert_count,), dtype), + ), + "ffn_gate_exps.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_gate_exps.weight", + data=weight_generator( + (expert_count, expert_feed_forward_length, embedding_length), dtype + ), + ), + "ffn_gate_exps.bias": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_gate_exps.bias", + data=weight_generator( + (expert_count, expert_feed_forward_length), dtype + ), + ), + "ffn_up_exps.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_up_exps.weight", + data=weight_generator( + (expert_count, expert_feed_forward_length, embedding_length), dtype + ), + ), + "ffn_up_exps.bias": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_up_exps.bias", + data=weight_generator( + (expert_count, expert_feed_forward_length), dtype + ), + ), + "ffn_down_exps.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_down_exps.weight", + data=weight_generator( + (expert_count, embedding_length, expert_feed_forward_length), dtype + ), + ), + "ffn_down_exps.bias": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_down_exps.bias", + data=weight_generator((expert_count, embedding_length), dtype), + ), + # Add norm scale weight for gpt-oss + "ffn_norm_scale.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_norm_scale.weight", + data=weight_generator((embedding_length,), dtype_norm), + ), + } + ) + + +def make_gpt_oss_attention_moe_block_theta( + block_idx: int, + config: LlamaModelConfig, + dtype_rest: torch.dtype, + dtype_norm: torch.dtype, + weight_generator: Callable[ + [list[int], torch.dtype], torch.Tensor + ] = make_wide_range_weights, +) -> Theta: + """Create combined attention + MoE block theta for GPT-OSS.""" + res_dict = {} + + # Attention part with fused QKV + attention_theta = make_gpt_oss_attention_block_theta( + block_idx=block_idx, + head_count=config.hp.attention_head_count, + head_count_kv=config.hp.attention_head_count_kv, + head_dim=config.hp.attn_head_dim, + embedding_length=config.hp.embedding_length, + dtype=dtype_rest, + dtype_norm=dtype_norm, + weight_generator=weight_generator, + ) + res_dict.update(attention_theta.tree) + + # MoE part + moe_theta = make_gpt_oss_moe_block_theta( + block_idx=block_idx, + embedding_length=config.hp.embedding_length, + expert_feed_forward_length=config.hp.expert_feed_forward_length, + expert_count=config.hp.expert_count, + dtype=dtype_rest, + dtype_norm=dtype_norm, + weight_generator=weight_generator, + ) + res_dict.update(moe_theta.tree) + + return Theta(res_dict) + + +def make_random_gpt_oss_theta( + config: LlamaModelConfig, + vocab_size: Optional[int] = None, + dtype_rest: torch.dtype = torch.bfloat16, + dtype_norm: torch.dtype = torch.bfloat16, + weight_generator: Callable[ + [list[int], torch.dtype], torch.Tensor + ] = make_wide_range_weights, +) -> Theta: + """Generate a GPT-OSS 
theta with configurable weight generation.""" + if vocab_size is None: + vocab_size = config.hp.vocab_size + + res = { + "token_embd.weight": DefaultPrimitiveTensor( + name="token_embd.weight", + data=weight_generator((vocab_size, config.hp.embedding_length), dtype_rest), + ) + } + + # Create blocks - all are MoE blocks for GPT-OSS + for i in range(config.hp.block_count): + block = make_gpt_oss_attention_moe_block_theta( + config=config, + block_idx=i, + dtype_rest=dtype_rest, + dtype_norm=dtype_norm, + weight_generator=weight_generator, + ).tree + res[f"blk.{i}"] = block + + # Output layers + res["output.weight"] = DefaultPrimitiveTensor( + name="output.weight", + data=weight_generator((vocab_size, config.hp.embedding_length), dtype_rest), + ) + res["output_norm.weight"] = DefaultPrimitiveTensor( + name="output_norm.weight", + data=weight_generator((config.hp.embedding_length,), dtype_norm), + ) + + return Theta(res) + + +def make_simple_analytical_gpt_oss_theta( + config: LlamaModelConfig, + vocab_size: Optional[int] = None, + dtype_rest: torch.dtype = torch.bfloat16, + dtype_norm: torch.dtype = torch.bfloat16, +) -> Theta: + """Generate a GPT-OSS theta with simple analytical weights for hand calculation.""" + return make_random_gpt_oss_theta( + config=config, + vocab_size=vocab_size, + dtype_rest=dtype_rest, + dtype_norm=dtype_norm, + weight_generator=make_simple_calculable_weight_torch, + ) diff --git a/sharktank/sharktank/models/gpt_oss/toy_gpt_oss.py b/sharktank/sharktank/models/gpt_oss/toy_gpt_oss.py new file mode 100644 index 00000000000..b2ff92e1f17 --- /dev/null +++ b/sharktank/sharktank/models/gpt_oss/toy_gpt_oss.py @@ -0,0 +1,301 @@ +"""Toy GPT-OSS model generator for testing and development.""" + +from typing import Callable +from sharktank.layers.configs import LlamaHParams, LlamaModelConfig +from sharktank.models.gpt_oss.testing import ( + make_random_gpt_oss_theta, + make_simple_analytical_gpt_oss_theta, + make_wide_range_weights, +) +from sharktank.types import Dataset + +import argparse +import torch + +parser = argparse.ArgumentParser() +parser.add_argument( + "-s", "--seed", default=12345, help="Random seed for deterministic generation" +) +parser.add_argument( + "-o", "--output", default="/tmp/toy_gpt_oss.irpa", help="Output file path" +) +parser.add_argument( + "--analytical", + action="store_true", + help="Use analytical model with simple weights for hand calculation", +) + + +def copy_weights_to_reference(shark_theta, ref_model, hp): + if ref_model is None: + return + + # Copy token embeddings + ref_model.embedding.weight.data = shark_theta("token_embd", "weight").as_torch() + + # Copy transformer blocks + for block_idx in range(hp.block_count): + ref_block = ref_model.block[block_idx] + ref_block.attn.norm.scale.data = ( + shark_theta("blk", block_idx, "attn_norm", "weight").as_torch().float() + ) + ref_block.attn.qkv.weight.data = shark_theta( + "blk", block_idx, "attn", "wqkv", "weight" + ).as_torch() + ref_block.attn.qkv.bias.data = shark_theta( + "blk", block_idx, "attn", "wqkv", "bias" + ).as_torch() + ref_block.attn.out.weight.data = shark_theta( + "blk", block_idx, "attn_output", "weight" + ).as_torch() + ref_block.attn.out.bias.data = shark_theta( + "blk", block_idx, "attn_output", "bias" + ).as_torch() + ref_block.attn.sinks.data = shark_theta( + "blk", block_idx, "attn_sinks" + ).as_torch() + + ref_block.mlp.norm.scale.data = ( + shark_theta("blk", block_idx, "ffn_norm_scale", "weight").as_torch().float() + ) + ref_block.mlp.gate.weight.data = 
shark_theta( + "blk", block_idx, "ffn_gate_inp", "weight" + ).as_torch() + ref_block.mlp.gate.bias.data = shark_theta( + "blk", block_idx, "ffn_gate_inp", "bias" + ).as_torch() + + # Concatenate gate and up weights for SwiGLU + gate_exps_weight = shark_theta( + "blk", block_idx, "ffn_gate_exps", "weight" + ).as_torch() + gate_exps_bias = shark_theta( + "blk", block_idx, "ffn_gate_exps", "bias" + ).as_torch() + up_exps_weight = shark_theta( + "blk", block_idx, "ffn_up_exps", "weight" + ).as_torch() + up_exps_bias = shark_theta("blk", block_idx, "ffn_up_exps", "bias").as_torch() + + num_experts = gate_exps_weight.shape[0] + intermediate_size = gate_exps_weight.shape[1] + hidden_size = gate_exps_weight.shape[2] + + mlp1_weight = torch.zeros( + num_experts, + intermediate_size * 2, + hidden_size, + dtype=gate_exps_weight.dtype, + device=gate_exps_weight.device, + ) + mlp1_bias = torch.zeros( + num_experts, + intermediate_size * 2, + dtype=gate_exps_bias.dtype, + device=gate_exps_bias.device, + ) + + mlp1_weight[:, :intermediate_size, :] = gate_exps_weight + mlp1_weight[:, intermediate_size:, :] = up_exps_weight + mlp1_bias[:, :intermediate_size] = gate_exps_bias + mlp1_bias[:, intermediate_size:] = up_exps_bias + + ref_block.mlp.mlp1_weight.data = mlp1_weight + ref_block.mlp.mlp1_bias.data = mlp1_bias + + ref_block.mlp.mlp2_weight.data = shark_theta( + "blk", block_idx, "ffn_down_exps", "weight" + ).as_torch() + ref_block.mlp.mlp2_bias.data = shark_theta( + "blk", block_idx, "ffn_down_exps", "bias" + ).as_torch() + + # Copy output layers + ref_model.norm.scale.data = shark_theta("output_norm", "weight").as_torch().float() + ref_model.unembedding.weight.data = shark_theta("output", "weight").as_torch() + + +def calculate_cross_entropy_manual( + model_instance, sequence: list[int], use_prefill: bool = True +) -> tuple[float, float]: + """Calculate cross entropy and perplexity manually for debugging.""" + evaluator = model_instance.make_perplexity_eval() + if use_prefill: + res = evaluator.prefill_cross_entropy([sequence])[0] + else: + res = evaluator.decode_cross_entropy([sequence])[0] + + assert res.valid + ce = res.score + ppl = float(torch.exp(torch.tensor(ce))) + + print("cross_entropy_nats:", ce) + print("perplexity:", ppl) + return ce, ppl + + +def generate( + seed: int, + dtype_rest: torch.dtype = torch.bfloat16, + dtype_norm: torch.dtype = torch.bfloat16, + weight_generator: Callable[ + [list[int], torch.dtype], torch.Tensor + ] = make_wide_range_weights, +): + """Generate a minimal deterministic GPT-OSS model for testing.""" + torch.manual_seed(seed) + + # Model architecture parameters + block_seq_stride = 16 + max_blocks = 8 + attention_head_count = 8 + attn_head_dim = 32 + attention_head_count_kv = 4 + rope_dimension_count = 32 + vocabulary_size = 128 + block_count = 3 + feed_forward_length = 64 + + # MoE parameters + expert_count = 4 + expert_used_count = 2 + expert_feed_forward_length = 32 + + config = LlamaModelConfig( + hp=LlamaHParams( + model_arch="gpt-oss", + vocab_size=vocabulary_size, + context_length=block_seq_stride * max_blocks, + embedding_length=attention_head_count * attn_head_dim, + block_count=block_count, + feed_forward_length=feed_forward_length, + attention_head_count=attention_head_count, + attn_head_dim=attn_head_dim, + attention_layer_norm_rms_epsilon=1e-5, + attention_head_count_kv=attention_head_count_kv, + rope_dimension_count=rope_dimension_count, + rope_freq_base=150000.0, + rope_interleave_emb=False, + yarn_factor=32.0, + yarn_beta_slow=1.0, + 
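+            # yarn_beta_slow / yarn_beta_fast play the role of ntk_alpha / ntk_beta in the
+            # reference RotaryEmbedding (see the RoPE comparison in forward_pass_comparison.py).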
yarn_beta_fast=32.0, + yarn_original_context_len=4096, + expert_count=expert_count, + expert_used_count=expert_used_count, + expert_feed_forward_length=expert_feed_forward_length, + sliding_window=128, + swiglu_limit=7.0, + use_base_frequency_scaling=True, + use_fused_qkv=True, + topk_then_softmax=True, + use_residual_moe=True, + moe_block_type="PreGatherFFNMOE", + use_moe_swiglu=True, + ), + block_seq_stride=block_seq_stride, + activation_dtype=dtype_rest, + attention_dtype=dtype_rest, + ) + + theta = make_random_gpt_oss_theta( + config=config, + vocab_size=vocabulary_size, + dtype_rest=dtype_rest, + dtype_norm=dtype_norm, + weight_generator=weight_generator, + ) + return theta, config + + +def generate_analytical( + seed: int, + dtype_rest: torch.dtype = torch.bfloat16, + dtype_norm: torch.dtype = torch.bfloat16, +): + """Generate a minimal analytical GPT-OSS model with simple weights for hand calculation. + + This creates a tiny GPT-OSS model with all key features enabled but scaled down: + - Simple weights (0, 1, -1, 0.5, 2) for hand calculation + """ + torch.manual_seed(seed) + + # Minimal model architecture for analytical testing + block_seq_stride = 4 + max_blocks = 2 + attention_head_count = 2 + attn_head_dim = 2 + attention_head_count_kv = 1 + rope_dimension_count = 2 + vocabulary_size = 8 + block_count = 2 + feed_forward_length = 4 + + # Minimal MoE configuration + expert_count = 2 + expert_used_count = 1 + expert_feed_forward_length = 2 + + config = LlamaModelConfig( + hp=LlamaHParams( + model_arch="gpt-oss", + vocab_size=vocabulary_size, + context_length=block_seq_stride * max_blocks, + embedding_length=attention_head_count * attn_head_dim, + block_count=block_count, + feed_forward_length=feed_forward_length, + attention_head_count=attention_head_count, + attn_head_dim=attn_head_dim, + attention_layer_norm_rms_epsilon=1e-5, + attention_head_count_kv=attention_head_count_kv, + rope_dimension_count=rope_dimension_count, + rope_freq_base=10000.0, + rope_interleave_emb=False, + yarn_factor=1.0, + yarn_beta_slow=1.0, + yarn_beta_fast=1.0, + yarn_original_context_len=block_seq_stride * max_blocks, + expert_count=expert_count, + expert_used_count=expert_used_count, + expert_feed_forward_length=expert_feed_forward_length, + sliding_window=4, + swiglu_limit=7.0, + use_base_frequency_scaling=True, + use_fused_qkv=True, + topk_then_softmax=True, + use_residual_moe=True, + moe_block_type="PreGatherFFNMOE", + use_moe_swiglu=True, + ), + block_seq_stride=block_seq_stride, + activation_dtype=dtype_rest, + attention_dtype=dtype_rest, + ) + + theta = make_simple_analytical_gpt_oss_theta( + config=config, + vocab_size=vocabulary_size, + dtype_rest=dtype_rest, + dtype_norm=dtype_norm, + ) + return theta, config + + +def main(): + args = parser.parse_args() + + if args.analytical: + print("Generating analytical GPT-OSS model with simple weights...") + theta, config = generate_analytical(args.seed) + else: + print("Generating standard GPT-OSS model with wide-range weights...") + theta, config = generate(args.seed) + + # Convert to GGUF format and save + config_dict = config.hp.to_gguf_props() + dataset = Dataset(config_dict, theta) + dataset.save(args.output) + print(f"Model saved to: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py index c3cea56beed..5da1a32e659 100644 --- a/sharktank/sharktank/models/llama/testing.py +++ b/sharktank/sharktank/models/llama/testing.py @@ -8,6 +8,7 @@ 
import functools import torch import re +import math from sharktank.types.tensors import * from sharktank.types import DynamicFp4BlockQuantizer, StaticScaledQuantizer from sharktank.types.theta import Theta @@ -20,6 +21,44 @@ ) +def make_wide_range_weights( + shape: list[int], dtype: torch.dtype = torch.bfloat16 +) -> torch.Tensor: + """Generate weights with proper variance scaling to prevent numerical explosions. + + Uses Xavier-like initialization: scale by 1/sqrt(fan_in) to keep output variance + stable regardless of layer dimensions. The 0.8 factor provides diversity while + maintaining numerical stability. + + """ + seed = 12345 + generator = torch.Generator() + generator.manual_seed(seed) + fan_in = shape[-1] if len(shape) > 1 else shape[0] + std = 0.8 / math.sqrt(fan_in) + weights = torch.randn(shape, dtype=dtype, generator=generator) * std + + return weights + + +def make_simple_calculable_weight_torch( + shape: list[int], dtype: torch.dtype = torch.bfloat16 +) -> torch.Tensor: + """ + Create simple weights that can be calculated by hand for analytical testing. + """ + weights = torch.zeros(shape, dtype=dtype) + flat_weights = weights.view(-1) + + # Simple pattern: 0, 1, -1, 0.5, 2, repeat... + simple_values = [0.0, 1.0, -1.0, 0.5, 2.0] + + for i in range(flat_weights.numel()): + flat_weights[i] = simple_values[i % len(simple_values)] + + return weights + + def make_attention_block_theta( feature_dim: int, ffn_dim: int, diff --git a/sharktank/tests/models/gpt_oss/forward_pass_comparison.py b/sharktank/tests/models/gpt_oss/forward_pass_comparison.py new file mode 100644 index 00000000000..5ab1f55f458 --- /dev/null +++ b/sharktank/tests/models/gpt_oss/forward_pass_comparison.py @@ -0,0 +1,500 @@ +"""Forward pass comparison tests between sharktank and reference GPT-OSS implementations.""" + +import unittest +import torch +import torch.nn.functional as F +import logging + +from sharktank.models.gpt_oss.toy_gpt_oss import ( + generate_analytical, + copy_weights_to_reference, +) +from sharktank.models.gpt_oss.orig_pytorch_model import ( + ModelConfig, + Transformer, + RotaryEmbedding, + sdpa, +) +from sharktank.utils.llm_utils import TorchInstance + +from sharktank.layers.paged_attention import PagedGQAttention, build_cache +import math + + +class ForwardPassComparisonTest(unittest.TestCase): + def setUp(self): + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger(__name__) + + torch.set_default_dtype(torch.bfloat16) + self.seed = 12345 + torch.manual_seed(self.seed) + self.test_sequence = [0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0] + self.expected_pattern = [0.0, 1.0, -1.0, 0.5, 2.0] + self.input_tokens = torch.tensor([[0, 1, 2]], dtype=torch.long) + self.initialized_model() + + def initialized_model(self): + # Initialize sharktank model + self.shark_theta, self.shark_config = generate_analytical(self.seed) + self.hp = self.shark_config.hp + self.shark_model = TorchInstance( + theta=self.shark_theta, config=self.shark_config + )._model + self.shark_model.eval() + # Initialize reference model + ref_config = ModelConfig( + num_hidden_layers=self.hp.block_count, + num_experts=self.hp.expert_count, + experts_per_token=self.hp.expert_used_count, + vocab_size=self.hp.vocab_size, + hidden_size=self.hp.embedding_length, + intermediate_size=self.hp.feed_forward_length, + head_dim=self.hp.attn_head_dim, + num_attention_heads=self.hp.attention_head_count, + num_key_value_heads=self.hp.attention_head_count_kv, + sliding_window=self.hp.sliding_window if 
self.hp.sliding_window else 128, + initial_context_length=self.hp.context_length, + rope_theta=self.hp.rope_freq_base, + rope_scaling_factor=1.0, # Disable YARN for tiny head_dim compatibility + ) + + self.ref_model = Transformer(ref_config, device=torch.device("cpu")) + self.ref_model.eval() + + self.ref_model.vocab_size = ref_config.vocab_size + self.ref_model.hidden_size = ref_config.hidden_size + self.ref_model.num_hidden_layers = ref_config.num_hidden_layers + + copy_weights_to_reference(self.shark_theta, self.ref_model, self.hp) + + def test_token_embeddings(self): + """Test token embeddings match expected analytical pattern.""" + with torch.no_grad(): + shark_emb = self.shark_model.token_embedding(self.input_tokens) + ref_emb = self.ref_model.embedding(self.input_tokens) + + expected_embeddings = { + 0: [0.0, 1.0, -1.0, 0.5], + 1: [2.0, 0.0, 1.0, -1.0], + 2: [0.5, 2.0, 0.0, 1.0], + } + + for token_idx in range(3): + shark_values = shark_emb[0, token_idx, :].tolist() + ref_values = ref_emb[0, token_idx, :].tolist() + expected = expected_embeddings[token_idx] + + for i, (actual, exp) in enumerate(zip(shark_values, expected)): + self.assertAlmostEqual( + actual, + exp, + places=3, + msg=f"Token {token_idx} position {i} mismatch", + ) + + match = torch.allclose( + shark_emb[0, token_idx, :], + ref_emb[0, token_idx, :], + rtol=1e-4, + atol=1e-4, + ) + self.assertTrue(match) + + self.logger.debug( + f"Token {token_idx} - Sharktank: {shark_values}, Reference: {ref_values}, Expected: {expected}" + ) + + def test_rmsnorm(self): + """Test RMSNorm computation.""" + attn_norm_weight = self.shark_theta("blk", 0, "attn_norm", "weight").as_torch() + shark_emb = self.shark_model.token_embedding(self.input_tokens) + x = shark_emb[0, 0, :] # Token 0: [0.0, 1.0, -1.0, 0.5] + + # Manual RMSNorm calculation + x_float = x.float() + mean_sq = torch.mean(x_float**2) + eps = 1e-5 + rms = torch.sqrt(mean_sq + eps) + normalized = x_float / rms + normed_scaled = normalized * attn_norm_weight.float() + with torch.no_grad(): + shark_block = self.shark_model.attn_blocks[0] + shark_norm = shark_block.attn.attn_norm(shark_emb) + shark_result = shark_norm[0, 0, :] + + ref_emb = self.ref_model.embedding(self.input_tokens) + ref_block = self.ref_model.block[0] + ref_norm = ref_block.attn.norm(ref_emb) + ref_result = ref_norm[0, 0, :] + + torch.testing.assert_close( + shark_result.float(), normed_scaled, rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + ref_result.float(), normed_scaled, rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + shark_result.float(), ref_result.float(), rtol=1e-4, atol=1e-4 + ) + + self.logger.debug( + f"RMSNorm - Sharktank: {shark_result.tolist()}, Reference: {ref_result.tolist()}, Hand: {normed_scaled.tolist()}" + ) + + def test_qkv_projection(self): + """Test fused QKV projection.""" + qkv_weight = self.shark_theta("blk", 0, "attn", "wqkv", "weight").as_torch() + qkv_bias = self.shark_theta("blk", 0, "attn", "wqkv", "bias").as_torch() + + with torch.no_grad(): + shark_emb = self.shark_model.token_embedding(self.input_tokens) + shark_block = self.shark_model.attn_blocks[0] + h_norm = shark_block.attn.attn_norm(shark_emb)[0, 0, :] + + # Manual QKV calculation + hand_qkv_output = torch.matmul(h_norm, qkv_weight.t()) + qkv_bias + q_size = self.hp.attention_head_count * self.hp.attn_head_dim + kv_size = self.hp.attention_head_count_kv * self.hp.attn_head_dim + + hand_q = hand_qkv_output[:q_size] + hand_k = hand_qkv_output[q_size : q_size + kv_size] + hand_v = hand_qkv_output[q_size 
+ kv_size : q_size + 2 * kv_size] + + with torch.no_grad(): + shark_qkv = shark_block.attn.attn_qkv(h_norm) + shark_q = shark_qkv[:q_size] + shark_k = shark_qkv[q_size : q_size + kv_size] + shark_v = shark_qkv[q_size + kv_size : q_size + 2 * kv_size] + + ref_emb = self.ref_model.embedding(self.input_tokens) + ref_block = self.ref_model.block[0] + ref_norm = ref_block.attn.norm(ref_emb)[0, 0, :] + ref_qkv = ref_block.attn.qkv(ref_norm) + + ref_q = ref_qkv[:q_size] + ref_k = ref_qkv[q_size : q_size + kv_size] + ref_v = ref_qkv[q_size + kv_size : q_size + 2 * kv_size] + torch.testing.assert_close(shark_q, hand_q, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(shark_k, hand_k, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(shark_v, hand_v, rtol=1e-4, atol=1e-4) + + torch.testing.assert_close(ref_q, hand_q, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(ref_k, hand_k, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(ref_v, hand_v, rtol=1e-4, atol=1e-4) + + self.logger.debug( + f"QKV - Q: Hand={hand_q.tolist()}, Shark={shark_q.tolist()}, Ref={ref_q.tolist()}" + ) + + def test_rotary_embedding(self): + """Test RoPE implementation with simple 0/1 values.""" + bs = 1 + seq_len = 3 + n_heads = self.hp.attention_head_count + head_dim = self.hp.attn_head_dim + + q_ref = torch.tensor( + [ + [[1.0, 0.0], [1.0, 0.0]], # token 0: [1, 0] for both heads + [[0.0, 1.0], [0.0, 1.0]], + [[1.0, 0.0], [1.0, 0.0]], + ], + dtype=torch.bfloat16, + ) + + k_ref = torch.tensor( + [ + [[1.0, 0.0], [1.0, 0.0]], + [[0.0, 1.0], [0.0, 1.0]], + [[1.0, 0.0], [1.0, 0.0]], + ], + dtype=torch.bfloat16, + ) + + q_shark = q_ref.unsqueeze(0) + k_shark = k_ref.unsqueeze(0) + position_ids = torch.arange(0, seq_len, device=q_shark.device)[None, :].repeat( + bs, 1 + ) + + # YARN disabled (factor=1.0) to avoid assertion with head_dim=2 + with torch.no_grad(): + from sharktank.layers.rotary_embedding_hf import RotaryEmbeddingLayer + + shark_rope = RotaryEmbeddingLayer( + head_dim=head_dim, + rope_theta=self.hp.rope_freq_base, + use_base_frequency_scaling=False, + interleaved=False, + yarn_beta_slow=self.hp.yarn_beta_slow, + yarn_beta_fast=self.hp.yarn_beta_fast, + yarn_factor=1.0, + yarn_original_context_len=self.hp.yarn_original_context_len, + ) + cossin_cache = shark_rope.compute_sincos_cache(position_ids, q_shark.dtype) + shark_q = shark_rope(q_shark, cossin_cache) + shark_k = shark_rope(k_shark, cossin_cache) + + ref_rope = RotaryEmbedding( + head_dim=head_dim, + base=int(self.hp.rope_freq_base), + dtype=torch.bfloat16, + initial_context_length=self.hp.yarn_original_context_len, + scaling_factor=1.0, + ntk_alpha=self.hp.yarn_beta_slow, + ntk_beta=self.hp.yarn_beta_fast, + ) + ref_q, ref_k = ref_rope(q_ref, k_ref) + torch.testing.assert_close(shark_q.squeeze(0), ref_q, rtol=2e-2, atol=1e-2) + torch.testing.assert_close(shark_k.squeeze(0), ref_k, rtol=2e-2, atol=1e-2) + + self.logger.debug( + f"RoPE - Shark Q[0,0,0,:]: {shark_q[0, 0, 0, :].tolist()}, Ref Q[0,0,:]: {ref_q[0, 0, :].tolist()}" + ) + + def test_sdpa_vs_paged_attention_prefill(self): + """Compare reference sdpa with sharktank paged attention.""" + + seq_len = 6 + bs = 1 + n_kv_heads = self.hp.attention_head_count_kv + n_heads = self.hp.attention_head_count + head_dim = self.hp.attn_head_dim + q_mult = n_heads // n_kv_heads + dtype = torch.bfloat16 + + # Deterministic test tensors + q_ref = torch.tensor( + [ + [[[1.0, 0.0], [1.0, 0.0]]], + [[[0.0, 1.0], [0.0, 1.0]]], + [[[1.0, 1.0], [1.0, 1.0]]], + [[[0.0, 0.0], [0.0, 0.0]]], + [[[1.0, 0.0], [1.0, 0.0]]], + 
[[[0.0, 1.0], [0.0, 1.0]]], + ], + dtype=dtype, + ) + + k_ref = torch.tensor( + [ + [[1.0, 0.0]], + [[0.0, 1.0]], + [[1.0, 1.0]], + [[0.0, 0.0]], + [[1.0, 0.0]], + [[0.0, 1.0]], + ], + dtype=dtype, + ) + + v_ref = torch.tensor( + [ + [[1.0, 0.0]], + [[0.0, 1.0]], + [[1.0, 1.0]], + [[0.0, 0.0]], + [[1.0, 0.0]], + [[0.0, 1.0]], + ], + dtype=dtype, + ) + + ref_sinks = self.ref_model.block[0].attn.sinks + sm_scale = 1.0 / (head_dim**0.5) + sliding_window = self.hp.sliding_window if self.hp.sliding_window else 0 + + with torch.no_grad(): + # Reference implementation + ref_out = sdpa( + q_ref, + k_ref, + v_ref, + S=ref_sinks, + sm_scale=sm_scale, + sliding_window=sliding_window, + ) + ref_out = ref_out.view(seq_len, n_kv_heads, q_mult, head_dim) + + # Paged attention setup + kv_cache = build_cache( + transformer_block_count=1, + attn_head_count=n_kv_heads, + attn_head_dim=head_dim, + block_seq_stride=seq_len, + cache_dtype=dtype, + ) + pa = PagedGQAttention( + kv_cache=kv_cache, + transformer_block_index=0, + attn_dtype=dtype, + activation_dtype=dtype, + use_rope=False, + attention_chunk_size=None, + ) + + # Convert to paged attention format + q_pa = q_ref.unsqueeze(0).reshape(bs, seq_len, n_heads, head_dim) + k_pa = k_ref.unsqueeze(0).reshape(bs, seq_len, n_kv_heads, head_dim) + v_pa = v_ref.unsqueeze(0).reshape(bs, seq_len, n_kv_heads, head_dim) + + # Cache setup + blocks = math.ceil(seq_len / kv_cache.block_seq_stride) + seq_block_ids = torch.arange(blocks, dtype=torch.int64).unsqueeze(0) + cache_state = pa.allocate(page_count=blocks) + seq_lens = torch.tensor([seq_len], dtype=torch.long) + + # Paged attention forward + pa_out = pa.forward_prefill( + q=q_pa, + k=k_pa, + v=v_pa, + cache_state=cache_state, + seq_lens=seq_lens, + seq_block_ids=seq_block_ids, + attention_kernel="decomposed", + head_count_attn=n_heads, + cache_quantizer=None, + fake_quant=False, + scale=sm_scale, + sliding_window=sliding_window, + sink=ref_sinks, + ) + pa_out = pa_out.squeeze(0).permute(1, 0, 2) + pa_out = pa_out.reshape(seq_len, n_kv_heads, q_mult, head_dim) + + torch.testing.assert_close( + ref_out, + pa_out, + rtol=2e-2, + atol=2e-2, + msg="SDPA implementations should match", + ) + + def test_moe_block(self): + """Test MoE block.""" + shark_moe = self.shark_model.attn_blocks[0].ffn + ref_moe = self.ref_model.block[0].mlp + + torch.manual_seed(self.seed) + test_input = torch.tensor( + [[[0.5, 1.0, -0.5, 0.25], [1.0, 0.5, 0.25, -0.5]]], dtype=torch.bfloat16 + ) + + with torch.no_grad(): + shark_moe_out = shark_moe(test_input) + + ref_input = test_input.squeeze(0) + ref_moe_out = ref_moe(ref_input) + ref_moe_out = ref_moe_out.unsqueeze(0) + + shark_moe_out_bf16 = shark_moe_out.to(torch.bfloat16) + torch.testing.assert_close( + shark_moe_out_bf16, ref_moe_out, rtol=1e-2, atol=1e-2 + ) + + self.logger.debug( + f"MoE - Experts: {self.hp.expert_count}, Used: {self.hp.expert_used_count}" + ) + + def test_e2e_prefill_cross_entropy(self): + """Compare sharktank vs reference prefill cross-entropy. 
+ + Manual calculation ce and ppl: calculate_cross_entropy_manual(instance, self.sequence, use_prefill=False) + """ + + shark_ce = 3.3252 + shark_ppl = 27.8750 + + with torch.no_grad(): + input_ids = torch.tensor([self.test_sequence], dtype=torch.long) + ref_input = input_ids.squeeze(0) + ref_logits = self.ref_model(ref_input) + + shift_logits = ref_logits[:-1, :].contiguous() + shift_labels = ref_input[1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss(reduction="mean") + ref_ce = loss_fct(shift_logits, shift_labels).item() + ref_ppl = float(torch.exp(torch.tensor(ref_ce))) + + self.logger.info( + f"Prefill CE - Sharktank: {shark_ce:.4f}, Reference: {ref_ce:.4f}, Diff: {abs(shark_ce - ref_ce):.4f}" + ) + self.logger.info( + f"Prefill PPL - Sharktank: {shark_ppl:.4f}, Reference: {ref_ppl:.4f}, Diff: {abs(shark_ppl - ref_ppl):.4f}" + ) + + torch.testing.assert_close( + torch.tensor(shark_ce), + torch.tensor(ref_ce), + rtol=0.15, + atol=0.15, + msg=f"Prefill CE mismatch: shark={shark_ce:.4f} vs ref={ref_ce:.4f}", + ) + + torch.testing.assert_close( + torch.tensor(shark_ppl), + torch.tensor(ref_ppl), + rtol=0.15, + atol=0.15, + msg=f"Prefill PPL mismatch: shark={shark_ppl:.4f} vs ref={ref_ppl:.4f}", + ) + + def test_e2e_decode_cross_entropy(self): + """Compare sharktank vs reference decode cross-entropy. + + Manual calculation ce and ppl: calculate_cross_entropy_manual(instance, self.sequence, use_prefill=False) + """ + + shark_ce = 3.4136 + shark_ppl = 30.1250 + + with torch.no_grad(): + total_loss = 0.0 + count = 0 + + for i in range(1, len(self.test_sequence)): + prefix = self.test_sequence[:i] + target = self.test_sequence[i] + + ref_input = torch.tensor(prefix, dtype=torch.long) + ref_logits = self.ref_model(ref_input) + last_logits = ref_logits[-1, :] + + log_probs = torch.nn.functional.log_softmax(last_logits, dim=-1) + token_loss = -log_probs[target].item() + + total_loss += token_loss + count += 1 + + ref_ce = total_loss / count + ref_ppl = float(torch.exp(torch.tensor(ref_ce))) + + self.logger.info( + f"Decode CE - Sharktank: {shark_ce:.4f}, Reference: {ref_ce:.4f}, Diff: {abs(shark_ce - ref_ce):.4f}" + ) + self.logger.info( + f"Decode PPL - Sharktank: {shark_ppl:.4f}, Reference: {ref_ppl:.4f}, Diff: {abs(shark_ppl - ref_ppl):.4f}" + ) + + torch.testing.assert_close( + torch.tensor(shark_ce), + torch.tensor(ref_ce), + rtol=0.15, + atol=0.15, + msg=f"Decode CE mismatch: shark={shark_ce:.4f} vs ref={ref_ce:.4f}", + ) + + torch.testing.assert_close( + torch.tensor(shark_ppl), + torch.tensor(ref_ppl), + rtol=0.15, + atol=0.15, + msg=f"Decode PPL mismatch: shark={shark_ppl:.4f} vs ref={ref_ppl:.4f}", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/sharktank/tests/models/gpt_oss/toy_gpt_oss_test.py b/sharktank/tests/models/gpt_oss/toy_gpt_oss_test.py new file mode 100644 index 00000000000..05e5be3b05b --- /dev/null +++ b/sharktank/tests/models/gpt_oss/toy_gpt_oss_test.py @@ -0,0 +1,223 @@ +"""Tests for toy GPT-OSS model generation and inference.""" + +import torch +import unittest +import logging +from sharktank.models.gpt_oss.toy_gpt_oss import generate, copy_weights_to_reference +from sharktank.utils.llm_utils import ( + LlmInstance, + TorchInstance, + llama_config_page_sizes, +) +from sharktank.models.gpt_oss.orig_pytorch_model import ModelConfig, Transformer + + +class ToyGptOssTest(unittest.TestCase): + def setUp(self): + torch.set_default_dtype(torch.bfloat16) + self.seed = 12345 + + # Hardcoded for CI performance - regenerate with 
self.generate_sequence() if weights change + self.sequence = [0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6] + self.initialized_model() + + def initialized_model(self): + + theta, config = generate(self.seed) + + model = TorchInstance(theta=theta, config=config) + page_sizes = llama_config_page_sizes(config) + block_count = 128 + self.shark_instance = LlmInstance( + model_instance=model, + page_sizes=page_sizes, + block_seq_stride=config.block_seq_stride, + block_count=block_count, + ) + + def generate_sequence(self): + """Generate test sequence dynamically. Repetitive output expected with random weights.""" + theta, config = generate(self.seed) + model = TorchInstance(theta=theta, config=config) + page_sizes = llama_config_page_sizes(config) + block_count = 128 + instance = LlmInstance( + model_instance=model, + page_sizes=page_sizes, + block_seq_stride=config.block_seq_stride, + block_count=block_count, + ) + decoder = instance.make_decoder() + generated_tokens = decoder.greedy_decode([[0]], steps=14)[0] + + full_sequence = [0] + generated_tokens + print(f"Generated tokens: {generated_tokens}") + print(f"Full test sequence: {full_sequence}") + return full_sequence + + def testDecodeSequence(self): + """Test deterministic token generation.""" + + decoder = self.shark_instance.make_decoder() + expected = self.sequence[1:] + + decoded = decoder.greedy_decode([[0]], steps=len(expected))[0] + decoded2 = decoder.greedy_decode([[0]], steps=len(expected))[0] + + self.assertEqual(decoded, decoded2) + self.assertEqual(decoded, expected) + + def testPrefillPerplexity(self): + """Test prefill perplexity calculation. + Manual calculation ce and ppl: calculate_cross_entropy_manual(instance, self.sequence, use_prefill=True) + """ + + decoder = self.shark_instance.make_perplexity_eval() + result = decoder.prefill_cross_entropy([self.sequence])[0] + assert result.valid + + shark_ce = 4.6970133781433105 + torch.testing.assert_close(result.score, shark_ce, atol=1e-2, rtol=1e-2) + + result2 = decoder.prefill_cross_entropy([self.sequence])[0] + self.assertEqual(result.score, result2.score) + + def testDecodePerplexity(self): + """Test decode perplexity calculation. 
+ Manual calculation ce and ppl: calculate_cross_entropy_manual(instance, self.sequence, use_prefill=False) + """ + + decoder = self.shark_instance.make_perplexity_eval() + result = decoder.decode_cross_entropy([self.sequence])[0] + assert result.valid + + shark_ce = 4.6970133781433105 + torch.testing.assert_close(result.score, shark_ce, atol=1e-2, rtol=1e-2) + + result2 = decoder.decode_cross_entropy([self.sequence])[0] + self.assertEqual(result.score, result2.score) + + +class RefSharktankE2ETest(unittest.TestCase): + """Test reference and sharktank model e2e comparison.""" + + def setUp(self): + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger(__name__) + + torch.set_default_dtype(torch.bfloat16) + self.seed = 12345 + torch.manual_seed(self.seed) + + self.initialized_model() + # Hardcoded for CI performance - regenerate with self.generate_sequence() if weights change + self.sequence = [0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6] + + def initialized_model(self): + + # Initialize sharktank model + self.shark_theta, self.shark_config = generate(self.seed) + self.hp = self.shark_config.hp + model = TorchInstance(theta=self.shark_theta, config=self.shark_config) + page_sizes = llama_config_page_sizes(self.shark_config) + block_count = 128 + instance = LlmInstance( + model_instance=model, + page_sizes=page_sizes, + block_seq_stride=self.shark_config.block_seq_stride, + block_count=block_count, + ) + self.shark_instance = instance + + # Configure reference model + expert_weight_sample = self.shark_theta( + "blk", 0, "ffn_gate_exps", "weight" + ).as_torch() + actual_intermediate_size = expert_weight_sample.shape[1] + + ref_config = ModelConfig( + num_hidden_layers=self.hp.block_count, + num_experts=self.hp.expert_count, + experts_per_token=self.hp.expert_used_count, + vocab_size=self.hp.vocab_size, + hidden_size=self.hp.embedding_length, + intermediate_size=actual_intermediate_size, + head_dim=self.hp.attn_head_dim, + num_attention_heads=self.hp.attention_head_count, + num_key_value_heads=self.hp.attention_head_count_kv, + sliding_window=self.hp.sliding_window if self.hp.sliding_window else 128, + initial_context_length=self.hp.context_length, + rope_theta=self.hp.rope_freq_base, + rope_scaling_factor=1.0, + ) + + self.ref_model = Transformer(ref_config, device=torch.device("cpu")) + self.ref_model.eval() + + copy_weights_to_reference(self.shark_theta, self.ref_model, self.hp) + + def test_ref_sharktank_prefill_cross_entropy(self): + """Test prefill cross-entropy matches expected values + Manual calculation ce and ppl: calculate_cross_entropy_manual(instance, self.sequence, use_prefill=True) + """ + + decoder = self.shark_instance.make_perplexity_eval() + shark_result = decoder.prefill_cross_entropy([self.sequence])[0] + assert shark_result.valid + expected_ce = 4.6970133781433105 + + with torch.no_grad(): + input_ids = torch.tensor([self.sequence], dtype=torch.long) + ref_input = input_ids.squeeze(0) + ref_logits = self.ref_model(ref_input) + + shift_logits = ref_logits[:-1, :].contiguous() + shift_labels = ref_input[1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss(reduction="mean") + ref_ce = loss_fct(shift_logits, shift_labels).item() + ref_ppl = float(torch.exp(torch.tensor(ref_ce))) + + torch.testing.assert_close( + shark_result.score, expected_ce, atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(ref_ce, expected_ce, atol=1e-2, rtol=1e-2) + + def test_ref_sharktank_decode_cross_entropy(self): + """Test decode cross-entropy matches expected values + 
Manual calculation ce and ppl: calculate_cross_entropy_manual(instance, self.sequence, use_prefill=False) + """ + decoder = self.shark_instance.make_perplexity_eval() + shark_result = decoder.decode_cross_entropy([self.sequence])[0] + assert shark_result.valid + expected_ce = 4.6970133781433105 + + with torch.no_grad(): + total_loss = 0.0 + count = 0 + + for i in range(1, len(self.sequence)): + prefix = self.sequence[:i] + target = self.sequence[i] + + ref_input = torch.tensor(prefix, dtype=torch.long) + ref_logits = self.ref_model(ref_input) + + last_logits = ref_logits[-1, :] + log_probs = torch.nn.functional.log_softmax(last_logits, dim=-1) + token_loss = -log_probs[target].item() + + total_loss += token_loss + count += 1 + + ref_ce = total_loss / count + ref_ppl = float(torch.exp(torch.tensor(ref_ce))) + + torch.testing.assert_close( + shark_result.score, expected_ce, atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(ref_ce, expected_ce, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + unittest.main()
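
Note: the comparison tests above follow one recipe: build a toy sharktank theta, mirror its weights into the reference Transformer, and run both on the same tokens. Below is a minimal sketch of that recipe outside the unit tests, assuming the modules land at the import paths used in this diff and that the LlamaHParams field names match those in toy_gpt_oss.py; it is illustrative only, not part of the patch.

    import torch
    from sharktank.models.gpt_oss.toy_gpt_oss import (
        generate_analytical,
        copy_weights_to_reference,
    )
    from sharktank.models.gpt_oss.orig_pytorch_model import ModelConfig, Transformer

    # Tiny analytical model: simple weights (0, 1, -1, 0.5, 2) for hand checking.
    theta, config = generate_analytical(seed=12345)
    hp = config.hp

    # Reference model sized from the same hyperparameters; rope_scaling_factor is left
    # at 1.0 so the tiny head_dim passes the NTK-by-parts assertion (as in the tests).
    ref = Transformer(
        ModelConfig(
            num_hidden_layers=hp.block_count,
            num_experts=hp.expert_count,
            experts_per_token=hp.expert_used_count,
            vocab_size=hp.vocab_size,
            hidden_size=hp.embedding_length,
            intermediate_size=hp.feed_forward_length,
            head_dim=hp.attn_head_dim,
            num_attention_heads=hp.attention_head_count,
            num_key_value_heads=hp.attention_head_count_kv,
            sliding_window=hp.sliding_window,
            initial_context_length=hp.context_length,
            rope_theta=hp.rope_freq_base,
            rope_scaling_factor=1.0,
        ),
        device=torch.device("cpu"),
    )
    ref.eval()
    copy_weights_to_reference(theta, ref, hp)

    # Run the reference model on a short token sequence.
    with torch.no_grad():
        logits = ref(torch.tensor([0, 1, 2], dtype=torch.long))
    print(logits.shape)  # torch.Size([3, hp.vocab_size])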