ADD RWKV7 #2421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
keras_hub/src/models/rwkv7/rwkv7_backbone.py
@@ -0,0 +1,185 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block


def rwkv7_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
    """The [RWKV-7](https://arxiv.org/abs/2503.14456) core architecture.

    This network implements a modern RNN architecture based on linear
    attention mechanisms with recurrent processing, as described in the
    RWKV papers. It includes the embedding lookups and RWKV-7 blocks.

    The default constructor gives a fully customizable, randomly initialized
    RWKV-7 model with any number of layers, heads, and embedding dimensions.
    To load preset architectures and weights, use the `from_preset`
    constructor.

    Args:
        hidden_size: int. The size of the transformer encoding and pooling
            layers.
        head_size: int. The size of each attention head.
        num_layers: int. The number of transformer layers.
        vocabulary_size: int. The size of the token vocabulary.
        intermediate_dim: int. The output dimension of the first Dense layer
            in a two-layer feedforward network for each transformer.
        gate_lora: int. LoRA dimension for gating.
        mv_lora: int. LoRA dimension for value mixing.
        aaa_lora: int. LoRA dimension for alpha parameters.
        decay_lora: int. LoRA dimension for decay parameters.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to
            use for model computations and weights. Note that some
            computations, such as softmax and layer normalization, will
            always be done at float32 precision regardless of dtype.
        dropout_rate: float. Dropout rate for the dropout layer.

    Examples:

    ```python
    input_data = np.ones(shape=(1, 12), dtype="int32")

    # Randomly initialized RWKV-7 decoder with custom config.
    model = keras_hub.models.RWKV7Backbone(
        vocabulary_size=10,
        hidden_size=512,
        num_layers=2,
        head_size=64,
        intermediate_dim=1024,
        dtype="float32",
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        hidden_size,
        head_size,
        num_layers,
        vocabulary_size,
        intermediate_dim,
        gate_lora=128,
        mv_lora=32,
        aaa_lora=64,
        decay_lora=64,
        dtype=None,
        dropout_rate=0,
        **kwargs,
    ):
        """Initialize the RWKV-7 backbone.

        Args:
            hidden_size: Hidden dimension size.
            head_size: Attention head size.
            num_layers: Number of RWKV blocks.
            vocabulary_size: Size of the vocabulary.
            intermediate_dim: Intermediate dimension for the FFN.
            gate_lora: LoRA dimension for gating.
            mv_lora: LoRA dimension for value mixing.
            aaa_lora: LoRA dimension for alpha parameters.
            decay_lora: LoRA dimension for decay parameters.
            dtype: Data type for the layer.
            dropout_rate: Dropout rate for regularization.
            **kwargs: Additional arguments.
        """
        # === Layers ===
        self.token_embedding = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=hidden_size,
            embeddings_initializer=rwkv7_kernel_initializer(),
            dtype=dtype,
            name="token_embedding",
        )
        self.token_embedding.build([None, None])

        self.output_layer_norm = keras.layers.LayerNormalization(
            epsilon=1e-5, name="output_norm"
        )
        self.output_layer_norm.build([None, None, hidden_size])
        self.dropout = keras.layers.Dropout(
            dropout_rate,
            dtype=dtype,
            name="dropout",
        )
        self.rwkv_layers = []
        for i in range(num_layers):
            layer = RWKV7_Block(
                hidden_size,
                head_size,
                intermediate_dim,
                gate_lora,
                mv_lora,
                aaa_lora,
                decay_lora,
                use_initial_norm=i == 0,
                kernel_initializer=rwkv7_kernel_initializer(),
                dtype=dtype,
                name=f"rwkv_layer_{i}",
            )
            self.rwkv_layers.append(layer)
        self.head = keras.layers.Dense(
            units=vocabulary_size,
            kernel_initializer=rwkv7_kernel_initializer(),
            use_bias=False,
            name="head",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask = ops.not_equal(token_id_input, 0)

        x = self.token_embedding(token_id_input)
        padding_mask = ops.cast(padding_mask, dtype=x.dtype)
        v_first = None
        for rwkv_layer in self.rwkv_layers:
            x, v_first = rwkv_layer(x, v_first, padding_mask)
        x = self.dropout(x)
        sequence_output = self.output_layer_norm(x)
        sequence_output = self.head(sequence_output)
        super().__init__(
            inputs=token_id_input,
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )
        # Initialize the graph to avoid potential errors in some cases.
        self.call(ops.ones([1, 16], "int32"))

        # === Config ===
        self.num_layers = num_layers
        self.head_size = head_size
        self.hidden_size = hidden_size
        self.gate_lora = gate_lora
        self.mv_lora = mv_lora
        self.aaa_lora = aaa_lora
        self.decay_lora = decay_lora
        self.vocabulary_size = vocabulary_size
        self.dropout_rate = dropout_rate
        self.intermediate_dim = intermediate_dim

    def get_config(self):
        config = {
            "hidden_size": self.hidden_size,
            "head_size": self.head_size,
            "gate_lora": self.gate_lora,
            "mv_lora": self.mv_lora,
            "aaa_lora": self.aaa_lora,
            "decay_lora": self.decay_lora,
            "vocabulary_size": self.vocabulary_size,
            "dropout_rate": self.dropout_rate,
            "intermediate_dim": self.intermediate_dim,
            "num_layers": self.num_layers,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
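
To make the intended usage concrete, here is a minimal sketch of how the backbone added above can be exercised end to end. It assumes only what this diff defines (plus the `RWKV7_Block` layer it imports being available on the `keras_hub` source path); the `from_config` round trip assumes the base `Backbone.from_config` behaves as it does for other keras-hub backbones.

```python
import numpy as np

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone

# Small, randomly initialized RWKV-7 backbone (sizes kept tiny for speed).
model = RWKV7Backbone(
    vocabulary_size=10,
    hidden_size=16,
    num_layers=2,
    head_size=4,
    intermediate_dim=32,
)

# Forward pass: token ids in, per-token vocabulary logits out, because the
# final Dense "head" projects the hidden states back to vocabulary_size.
token_ids = np.ones((1, 12), dtype="int32")
logits = model(token_ids)
print(logits.shape)  # (1, 12, 10)

# get_config() should carry everything needed to rebuild the same topology.
restored = RWKV7Backbone.from_config(model.get_config())
assert restored.count_params() == model.count_params()
```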
keras_hub/src/models/rwkv7/rwkv7_backbone_test.py
@@ -0,0 +1,37 @@
from keras import ops

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.tests.test_case import TestCase


class RWKV7BackboneTest(TestCase):
    def setUp(self):
        """Set up the test case with default arguments and input data."""
        self.init_kwargs = {
            "vocabulary_size": 10,
            "hidden_size": 16,
            "num_layers": 2,
            "head_size": 4,
            "intermediate_dim": 32,
            "gate_lora": 32,
            "mv_lora": 16,
            "aaa_lora": 16,
            "decay_lora": 16,
        }
        self.input_data = ops.ones((2, 5), dtype="int32")
        self.backbone = RWKV7Backbone(**self.init_kwargs)

    def test_backbone_basics(self):
        """Test basic functionality of the RWKV7 backbone."""
        y = self.backbone(self.input_data)
        self.assertEqual(y.shape, (2, 5, 10))

    def test_num_parameters(self):
        """Test that the model has the expected number of parameters."""
        self.assertEqual(self.backbone.count_params(), 10208)
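
A possible follow-up (not part of this diff) would be to also cover variable sequence lengths and the config round trip. The sketch below adds two methods to `RWKV7BackboneTest` and relies only on what the backbone above already exposes:

```python
    def test_variable_sequence_length(self):
        """The functional graph should accept any sequence length."""
        for seq_len in (3, 8):
            data = ops.ones((2, seq_len), dtype="int32")
            y = self.backbone(data)
            self.assertEqual(y.shape, (2, seq_len, 10))

    def test_config_round_trip(self):
        """A backbone rebuilt from its config keeps the same parameter count."""
        restored = RWKV7Backbone.from_config(self.backbone.get_config())
        self.assertEqual(
            restored.count_params(), self.backbone.count_params()
        )
```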
The `RWKV7Backbone` class is missing a docstring. Please add a Google-style docstring explaining the model's architecture, its parameters, and include a usage example, as specified in the style guide.

Style Guide Reference: All public classes, methods, and functions must have Google-style docstrings, including a concise summary, comprehensive examples, and documentation for all parameters, return values, and exceptions.
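
For reference, a minimal illustration of the Google-style shape the style guide describes; the class, argument names, and wording here are placeholders, not the text that should land in the PR:

```python
class ExampleBackbone:
    """One-line summary of what the class does.

    A short paragraph describing the architecture and when to use it,
    followed by documented arguments and a runnable usage example.

    Args:
        hidden_size: int. Width of the residual stream.
        num_layers: int. Number of stacked blocks.

    Example:
        >>> model = ExampleBackbone(hidden_size=16, num_layers=2)
    """

    def __init__(self, hidden_size, num_layers):
        self.hidden_size = hidden_size
        self.num_layers = num_layers
```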