sharktank/sharktank/models/gpt_oss/orig_pytorch_model.py (386 additions, 0 deletions)

Review comment (Collaborator):
Can rename to ref_pytorch_model.py and if only used in testing, can be moved to tests/models/gpt_oss/

import json

Review comment from @archana-ramalingam (Collaborator), Oct 21, 2025:
Add the following to all the new files:

# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import math
import os
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class ModelConfig:
    num_hidden_layers: int = 36
    num_experts: int = 128
    experts_per_token: int = 4
    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    swiglu_limit: float = 7.0
    head_dim: int = 64
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    sliding_window: int = 128
    initial_context_length: int = 4096
    rope_theta: float = 150000.0
    rope_scaling_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0
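
    # These defaults appear to correspond to the large gpt-oss configuration
    # (36 layers, 128 experts, hidden size 2880); treat that mapping as an
    # assumption, not something stated in this file. Other variants would pass
    # their own values when constructing ModelConfig.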


class RMSNorm(torch.nn.Module):
    def __init__(
        self, num_features: int, eps: float = 1e-05, device: torch.device | None = None
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.scale = torch.nn.Parameter(
            torch.ones(num_features, device=device, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[-1] == self.num_features
        t, dtype = x.float(), x.dtype
        t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
        return (t * self.scale).to(dtype)

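# RMSNorm (above) normalizes by the root-mean-square of the features,
# y = x / sqrt(mean(x**2) + eps) * scale, upcasting to float32 for the
# reduction and casting back to the input dtype at the end.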

def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    x1, x2 = torch.chunk(x, 2, dim=-1)
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    return torch.cat((o1, o2), dim=-1)

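# _apply_rotary_emb (above) uses the "half-split" rotary layout: the head
# dimension is chunked into two halves and element i of the first half is
# paired with element i of the second half, each pair being rotated as a 2-D
# vector by the corresponding angle in cos/sin.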

class RotaryEmbedding(torch.nn.Module):
    def __init__(
        self,
        head_dim: int,
        base: int,
        dtype: torch.dtype,
        initial_context_length: int = 4096,
        scaling_factor: float = 1.0,
        ntk_alpha: float = 1.0,
        ntk_beta: float = 32.0,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        self.dtype = dtype
        self.initial_context_length = initial_context_length
        self.scaling_factor = scaling_factor
        self.ntk_alpha = ntk_alpha
        self.ntk_beta = ntk_beta
        self.device = device

    def _compute_concentration_and_inv_freq(self) -> tuple[float, torch.Tensor]:
        """See YaRN paper: https://arxiv.org/abs/2309.00071"""
        freq = self.base ** (
            torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)
            / self.head_dim
        )
        if self.scaling_factor > 1.0:
            concentration = (
                0.1 * math.log(self.scaling_factor) + 1.0
            )  # YaRN concentration

            d_half = self.head_dim / 2
            # NTK by parts
            low = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
                / math.log(self.base)
            )
            high = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
                / math.log(self.base)
            )
            assert 0 < low < high < d_half - 1

            interpolation = 1.0 / (self.scaling_factor * freq)
            extrapolation = 1.0 / freq

            ramp = (
                torch.arange(d_half, dtype=torch.float32, device=freq.device) - low
            ) / (high - low)
            mask = 1 - ramp.clamp(0, 1)

            inv_freq = interpolation * (1 - mask) + extrapolation * mask
        else:
            concentration = 1.0
            inv_freq = 1.0 / freq

        return concentration, inv_freq

    def _compute_cos_sin(self, num_tokens: int):
        concentration, inv_freq = self._compute_concentration_and_inv_freq()
        t = torch.arange(num_tokens, dtype=torch.float32, device=self.device)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos() * concentration
        sin = freqs.sin() * concentration
        return cos, sin

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        num_tokens = query.shape[0]
        cos, sin = self._compute_cos_sin(num_tokens)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_dim)
        query = _apply_rotary_emb(query, cos, sin)
        query = query.reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_dim)
        key = _apply_rotary_emb(key, cos, sin)
        key = key.reshape(key_shape)
        return query, key

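# RotaryEmbedding (above) implements YaRN-style context extension when
# scaling_factor > 1: high-frequency components are extrapolated (kept
# unscaled), low-frequency components are interpolated (divided by
# scaling_factor), with a linear ramp between the ntk_beta/ntk_alpha cutoffs,
# and the cos/sin tables are scaled by a "concentration" factor of
# 0.1 * ln(scaling_factor) + 1.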

def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    # sliding_window == 0 means no sliding window
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K)
    QK *= sm_scale
    QK += mask[None, None, :, :]
    QK = torch.cat([QK, S], dim=-1)
    W = torch.softmax(QK, dim=-1)
    W = W[..., :-1]
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)

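# sdpa (above) computes attention over a full token block:
#   * K/V are broadcast across the q_mult query heads per KV head
#     (grouped-query attention), and the mask combines a causal
#     upper-triangular ban with an optional sliding-window lower-triangular ban.
#   * S holds one learned "attention sink" logit per head; it is appended as an
#     extra softmax column so each row can leak probability mass to the sink,
#     and that column is dropped before the weighted sum over V.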

class AttentionBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int = 0,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Only apply sliding window to every other layer
        self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads, device=device, dtype=torch.bfloat16)
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        qkv_dim = config.head_dim * (
            config.num_attention_heads + 2 * config.num_key_value_heads
        )
        self.qkv = torch.nn.Linear(
            config.hidden_size, qkv_dim, device=device, dtype=torch.bfloat16
        )
        self.out = torch.nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        self.sm_scale = 1 / math.sqrt(config.head_dim)
        self.rope = RotaryEmbedding(
            config.head_dim,
            config.rope_theta,
            torch.float32,
            initial_context_length=config.initial_context_length,
            scaling_factor=config.rope_scaling_factor,
            ntk_alpha=config.rope_ntk_alpha,
            ntk_beta=config.rope_ntk_beta,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.norm(x)
        qkv = self.qkv(t)
        q = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()
        k = qkv[
            :,
            self.num_attention_heads
            * self.head_dim : (self.num_attention_heads + self.num_key_value_heads)
            * self.head_dim,
        ].contiguous()
        v = qkv[
            :,
            (self.num_attention_heads + self.num_key_value_heads)
            * self.head_dim : (self.num_attention_heads + 2 * self.num_key_value_heads)
            * self.head_dim,
        ].contiguous()

        q = q.view(
            -1,
            self.num_key_value_heads,
            self.num_attention_heads // self.num_key_value_heads,
            self.head_dim,
        )
        k = k.view(-1, self.num_key_value_heads, self.head_dim)
        v = v.view(-1, self.num_key_value_heads, self.head_dim)
        q, k = self.rope(q, k)
        t = sdpa(q, k, v, self.sinks, self.sm_scale, self.sliding_window)
        t = self.out(t)
        t = x + t
        return t

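# AttentionBlock (above) fuses the Q, K and V projections into a single
# self.qkv linear, slices the result back apart, and reshapes the queries to
# (tokens, num_key_value_heads, queries_per_kv_head, head_dim) so sdpa can run
# grouped-query attention. Sliding-window masking is applied only on
# even-indexed layers; odd layers attend over the full causal context.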

def swiglu(x, alpha: float = 1.702, limit: float = 7.0):
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    # Clamp the input values
    x_glu = x_glu.clamp(min=None, max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    # Note we add an extra bias of 1 to the linear layer
    return out_glu * (x_linear + 1)

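# swiglu (above) expects the gate and linear channels interleaved along the
# last dimension (even indices gate, odd indices linear). The gate path is a
# SiLU-style activation x * sigmoid(alpha * x) clamped from above at `limit`,
# the linear path is clamped to [-limit, limit] and shifted by +1, and the two
# are multiplied elementwise, halving the feature dimension.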

class MLPBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token
        self.swiglu_limit = config.swiglu_limit
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.gate = torch.nn.Linear(
            config.hidden_size, config.num_experts, device=device, dtype=torch.bfloat16
        )
        assert config.intermediate_size % self.world_size == 0
        self.mlp1_weight = torch.nn.Parameter(
            torch.empty(
                (
                    config.num_experts,
                    config.intermediate_size * 2 // self.world_size,
                    config.hidden_size,
                ),
                device=device,
                dtype=torch.bfloat16,
            )
        )
        self.mlp1_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.intermediate_size * 2 // self.world_size),
                device=device,
                dtype=torch.bfloat16,
            )
        )
        self.mlp2_weight = torch.nn.Parameter(
            torch.empty(
                (
                    config.num_experts,
                    config.hidden_size,
                    config.intermediate_size // self.world_size,
                ),
                device=device,
                dtype=torch.bfloat16,
            )
        )
        self.mlp2_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.hidden_size),
                device=device,
                dtype=torch.bfloat16,
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.norm(x)
        g = self.gate(t)
        experts = torch.topk(g, k=self.experts_per_token, dim=-1, sorted=True)
        expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
        expert_indices = experts.indices

        # MLP #1
        mlp1_weight = self.mlp1_weight[expert_indices, ...]
        mlp1_bias = self.mlp1_bias[expert_indices, ...]
        t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
        t = swiglu(t, limit=self.swiglu_limit)

        # MLP #2
        mlp2_weight = self.mlp2_weight[expert_indices, ...]
        mlp2_bias = self.mlp2_bias[expert_indices, ...]
        t = torch.einsum("beck,bek->bec", mlp2_weight, t)
        if self.world_size > 1:
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
        t += mlp2_bias

        # Weighted sum of experts
        t = torch.einsum("bec,be->bc", t, expert_weights)

        return x + t

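# MLPBlock (above) is a mixture-of-experts feed-forward layer: the gate picks
# the top experts_per_token experts per token, the selected expert weights are
# gathered and applied with batched einsums, and the per-expert outputs are
# combined using the softmax of the gate scores. The intermediate dimension is
# sharded across dist.get_world_size() ranks, so the partial MLP #2 products
# are summed with all_reduce before the bias is added.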

class TransformerBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.attn = AttentionBlock(config, layer_idx, device)
        self.mlp = MLPBlock(config, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn(x)
        x = self.mlp(x)
        return x


class Transformer(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
        )
        self.block = torch.nn.ModuleList(
            [
                TransformerBlock(config, layer_idx, device)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.unembedding = torch.nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            device=device,
            dtype=torch.bfloat16,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        for block in self.block:
            x = block(x)
        x = self.norm(x)
        x = self.unembedding(x)
        return x
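

# Minimal smoke-test sketch showing how the classes above fit together; it is
# not part of the reference model itself. The tiny ModelConfig values below are
# hypothetical, chosen only so the shapes stay consistent and the forward pass
# is cheap; real use would load an actual configuration and checkpoint weights
# instead of the uninitialized torch.empty parameters, so the printed logits
# are meaningless here and only the output shape is checked.
if __name__ == "__main__":
    _demo_config = ModelConfig(
        num_hidden_layers=2,
        num_experts=8,
        experts_per_token=2,
        vocab_size=128,
        hidden_size=64,
        intermediate_size=64,
        head_dim=16,
        num_attention_heads=8,
        num_key_value_heads=2,
        sliding_window=8,
        initial_context_length=64,
        rope_theta=10000.0,
        rope_scaling_factor=1.0,  # keep YaRN scaling disabled for this demo
    )
    _demo_model = Transformer(_demo_config)
    # The model runs over a flat (num_tokens,) sequence of token ids.
    _demo_tokens = torch.randint(0, _demo_config.vocab_size, (16,))
    _demo_logits = _demo_model(_demo_tokens)
    print(_demo_logits.shape)  # expected: torch.Size([16, 128])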