Generated GPT_OSS model files through porter script. #2384
Draft: laxmareddyp wants to merge 19 commits into keras-team:master from laxmareddyp:test_gpt_oss_model
Commits (19, all by laxmareddyp):
26867ba  Test GPT_OSS files through porter
f1c055b  generate API and moved files to respective folders
d4da96c  Fix format issues
b14cfb5  Add gpt_oss to preset loader and Fix format issues
b675610  Add gpt_oss to preset loader
8cf71ce  generated files through 2.5-pro model
2242ef4  Format fix
eb25d19  Add converter, RoPE update
ba50a9f  Fix format
1854d80  Fix BPE tests
76139cd  Merge branch 'keras-team:master' into test_gpt_oss_model
00ec305  Merge branch 'keras-team:master' into test_gpt_oss_model
9447990  Update converter
340aa85  Fix converter, checkpoints conversion and attention
b02cfea  Merge branch 'keras-team:master' into test_gpt_oss_model
47dcdda  Fix the parameter count and debug code
5e16f80  Add dequantization logic to converter
79c5664  Merge branch 'keras-team:master' into test_gpt_oss_model
59b6930  Add YaRN support, Fix Serialisation, Fix dequantization
The first new file (5 lines) registers the GPT-OSS backbone presets:
```python
from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone
from keras_hub.src.models.gpt_oss.gpt_oss_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, GptOssBackbone)
```
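Registration is what makes the presets loadable by name. As a rough sketch of how a registered preset would typically be used, following the usual keras-hub `from_preset` pattern (the preset name below is a placeholder; the real names live in `gpt_oss_presets.py`, which is not shown in this diff):

```python
from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone

# "gpt_oss_20b_en" is a hypothetical name for illustration only; the
# actual names come from the `backbone_presets` dict registered above.
backbone = GptOssBackbone.from_preset("gpt_oss_20b_en")
```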
The second new file (313 lines) implements the cached GPT-OSS attention layer:
```python
import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer


class CachedGptOssAttention(keras.layers.Layer):
    """A cached attention layer for GPT-OSS with sink tokens and a sliding
    window.

    This layer implements the attention mechanism for the GPT-OSS model,
    including grouped query attention (GQA), rotary positional embeddings
    (RoPE), and specific handling for "sink" tokens, which are added to the
    attention logits before softmax. It also supports caching for efficient
    generation.

    Args:
        num_query_heads: Number of attention heads for queries.
        num_key_value_heads: Number of attention heads for keys and values.
            If `num_query_heads != num_key_value_heads`, grouped query
            attention is used.
        rope_max_wavelength: The maximum wavelength for the rotary embedding.
        rope_scaling_factor: Scaling factor for rotary embeddings.
        kernel_initializer: Initializer for the dense layer kernels.
        sliding_window: The size of the sliding window for attention.
            Tokens outside this window are masked. This parameter is used
            for configuration, but the actual masking should be handled by
            the `attention_mask` input.
        dropout: Dropout rate for attention probabilities.
        use_bias: Whether to include bias terms in the dense projections.
        **kwargs: Additional keyword arguments passed to the base Layer class.
    """

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        kernel_initializer="glorot_uniform",
        sliding_window=4096,
        dropout=0,
        use_bias=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.sliding_window = sliding_window
        self.dropout = dropout
        self.use_bias = use_bias

        if self.num_query_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_query_heads ({self.num_query_heads}) must be divisible "
                f"by num_key_value_heads ({self.num_key_value_heads})."
            )
        self.num_key_value_groups = (
            self.num_query_heads // self.num_key_value_heads
        )
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor

        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )
```
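The divisibility check above is what makes grouped query attention work: each key/value head serves a fixed-size group of query heads. A minimal sketch of the arithmetic, with illustrative head counts that are not taken from this PR:

```python
# Hypothetical GQA configuration (illustrative numbers only).
num_query_heads = 64
num_key_value_heads = 8

# The constructor rejects configurations where this fails.
assert num_query_heads % num_key_value_heads == 0

# Each of the 8 key/value heads is shared by a group of 8 query heads;
# `call` later repeats keys/values `num_key_value_groups` times along
# the head axis so every query head has a matching key/value head.
num_key_value_groups = num_query_heads // num_key_value_heads  # 8
```

The `build` method, continuing the same file, creates the einsum projections, the learned sink logits, and the rotary embedding layer: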
```python
    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        self._hidden_dim = inputs_shape[-1]
        self._head_dim = self._hidden_dim // self.num_query_heads
        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

        self.query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, self._head_dim),
            kernel_initializer=self._kernel_initializer,
            use_bias=self.use_bias,
            dtype=self.dtype_policy,
            name="q_proj",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            use_bias=self.use_bias,
            dtype=self.dtype_policy,
            name="k_proj",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            use_bias=self.use_bias,
            dtype=self.dtype_policy,
            name="v_proj",
        )
        self.value_dense.build(inputs_shape)

        stddev = (
            self._kernel_initializer.stddev
            if hasattr(self._kernel_initializer, "stddev")
            else 0.02
        )
        self.sinks = self.add_weight(
            name="sinks",
            shape=(self.num_query_heads,),
            initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=stddev
            ),
            dtype=self.dtype_policy,
        )

        self.softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self.dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self.output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, self._hidden_dim),
            kernel_initializer=self._kernel_initializer,
            use_bias=self.use_bias,
            dtype=self.dtype_policy,
            name="o_proj",
        )
        self.output_dense.build(
            (None, None, self.num_query_heads, self._head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self.rope_max_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True
```
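To make the `EinsumDense` equations concrete, here is the query projection (`"bqm,muh->bquh"`) reproduced with NumPy, using small illustrative dimensions rather than real model sizes:

```python
import numpy as np

batch, seq_len, hidden_dim = 2, 10, 64
num_query_heads = 8
head_dim = hidden_dim // num_query_heads  # 8, as computed in `build`

x = np.zeros((batch, seq_len, hidden_dim))                    # "bqm"
q_kernel = np.zeros((hidden_dim, num_query_heads, head_dim))  # "muh"

# (2, 10, 64) x (64, 8, 8) -> (2, 10, 8, 8): per-head query vectors.
q = np.einsum("bqm,muh->bquh", x, q_kernel)
assert q.shape == (batch, seq_len, num_query_heads, head_dim)
```

Continuing the same file, `call` applies RoPE, updates the key/value cache during generation, and expands key/value heads for GQA: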
```python
    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self.query_dense(hidden_states)

        # Compute RoPE for queries.
        query = self.rotary_embedding_layer(query, start_index=start_index)

        def _compute_key_value(x):
            key, value = self.key_dense(x), self.value_dense(x)
            # Compute RoPE for keys.
            key = self.rotary_embedding_layer(key, start_index=start_index)
            return key, value

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                key = key_cache
                value = value_cache
            else:
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        if self.num_key_value_groups > 1:
            key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
            value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query, key, value, attention_mask, training=training
        )

        attention_output = self.dropout_layer(
            attention_output, training=training
        )

        attention_output = self.output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output
```
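The cache stacks keys and values along axis 1, giving shape `(batch, 2, max_len, num_key_value_heads, head_dim)`, and `ops.slice_update` writes the new token's key/value at `cache_update_index` along the sequence axis. An in-place NumPy analogue of the same update pattern, with illustrative sizes:

```python
import numpy as np

batch, max_len, num_kv_heads, head_dim = 1, 8, 2, 4

# Keys at index 0, values at index 1 along axis 1.
cache = np.zeros((batch, 2, max_len, num_kv_heads, head_dim))
key_cache = cache[:, 0, ...]

# During generation, write one new token's key at `cache_update_index`,
# mirroring ops.slice_update(key_cache, [0, idx, 0, 0], key_update)
# (slice_update is functional and returns a new tensor; NumPy mutates).
cache_update_index = 3
key_update = np.ones((batch, 1, num_kv_heads, head_dim))
key_cache[:, cache_update_index : cache_update_index + 1, ...] = key_update
```

The remaining methods implement the sink-token attention itself and the layer's serialization config: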
```python
    def _use_fused_attention_op(self):
        # GPT-OSS attention includes "sink" tokens, which are added to the
        # logits before softmax. The Keras `ops.dot_product_attention` does
        # not support this custom modification to the logits, so we must
        # use the manual attention calculation path.
        return False

    def _compute_attention(
        self, query, key, value, attention_mask=None, training=None
    ):
        # `_use_fused_attention_op` is explicitly False for GptOssAttention
        # due to the sink token mechanism.

        # 1. Calculate raw attention scores.
        attention_scores = ops.einsum(self._dot_product_equation, query, key)
        attention_scores = ops.multiply(
            attention_scores,
            ops.cast(self._inv_norm_factor, self.compute_dtype),
        )

        # 2. Apply the attention mask (if any).
        if attention_mask is not None:
            if ops.ndim(attention_mask) == 3:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
            attention_scores = attention_scores + attention_mask

        # 3. Prepare and concatenate sink tokens.
        # sinks shape: (num_query_heads,)
        sinks_expanded = ops.reshape(
            self.sinks, (1, self.num_query_heads, 1, 1)
        )
        # attention_scores shape: (batch, num_heads, query_len, key_len)
        sinks_expanded = ops.broadcast_to(
            sinks_expanded, ops.shape(attention_scores)[:-1] + (1,)
        )

        # Concatenate attention scores with sinks along the last dimension.
        # Resulting shape: (batch, num_query_heads, query_len, key_len + 1)
        combined_logits = ops.concatenate(
            [attention_scores, sinks_expanded], axis=-1
        )

        # 4. Apply numerical-stability clamping before softmax.
        max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
        combined_logits = combined_logits - max_logits

        # 5. Apply softmax to the combined logits (scores + sinks).
        probs = self.softmax(combined_logits)  # self.softmax is float32.

        # 6. Drop the sink token probability to get final attention
        # weights, i.e. scores = probs[..., :-1].
        scores = ops.slice(
            probs,
            [0, 0, 0, 0],
            ops.shape(probs)[:-1] + (ops.shape(probs)[-1] - 1,),
        )

        # 7. Cast to compute_dtype (dropout is handled outside this method).
        attention_weights = ops.cast(scores, self.compute_dtype)

        # 8. Compute the weighted sum of values.
        attention_output = ops.einsum(
            self._combine_equation, attention_weights, value
        )

        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self._kernel_initializer
                ),
                "sliding_window": self.sliding_window,
                "dropout": self.dropout,
                "use_bias": self.use_bias,
            }
        )
        return config
```
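To see why the sink column matters, here is the softmax-with-sink computation from `_compute_attention` on scalar toy values: appending the sink logit before softmax and dropping its probability afterward leaves attention weights that sum to less than one, so a head can park probability mass on the sink instead of being forced to spread it over real tokens.

```python
import numpy as np

scores = np.array([2.0, 1.0, 0.5])  # logits for 3 keys (toy values)
sink = np.array([1.5])              # a head's learned sink logit (toy value)

combined = np.concatenate([scores, sink])          # step 3: append sink
combined = combined - combined.max()               # step 4: stability clamp
probs = np.exp(combined) / np.exp(combined).sum()  # step 5: softmax
weights = probs[:-1]                               # step 6: drop sink column

print(weights.sum())  # ~0.72 < 1.0: the sink absorbed the remaining mass
```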