9 changes: 9 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -605,6 +605,15 @@
from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
RoformerV2Tokenizer as RoformerV2Tokenizer,
)
from keras_hub.src.models.rwkv7.rwkv7_backbone import (
RWKV7Backbone as RWKV7Backbone,
)
from keras_hub.src.models.rwkv7.rwkv7_causal_lm import (
RWKV7CausalLM as RWKV7CausalLM,
)
from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import (
RWKV7CausalLMPreprocessor as RWKV7CausalLMPreprocessor,
)
from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone
from keras_hub.src.models.sam.sam_image_segmenter import (
SAMImageSegmenter as SAMImageSegmenter,
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -90,6 +90,9 @@
from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
RoformerV2Tokenizer as RoformerV2Tokenizer,
)
from keras_hub.src.models.rwkv7.rwkv7_tokenizer import (
RWKVTokenizer as RWKVTokenizer,
)
from keras_hub.src.models.siglip.siglip_tokenizer import (
SigLIPTokenizer as SigLIPTokenizer,
)
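With these export diffs applied, the new symbols become reachable from the public namespace. A quick sanity check (a minimal sketch, assuming an editable install of this branch):

```python
import keras_hub

# The new RWKV-7 classes are exposed under the standard public paths.
print(keras_hub.models.RWKV7Backbone)
print(keras_hub.models.RWKV7CausalLM)
print(keras_hub.models.RWKV7CausalLMPreprocessor)
print(keras_hub.tokenizers.RWKVTokenizer)
```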
185 changes: 185 additions & 0 deletions keras_hub/src/models/rwkv7/rwkv7_backbone.py
@@ -0,0 +1,185 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block


def rwkv7_kernel_initializer(stddev=0.02):
return keras.initializers.TruncatedNormal(stddev=stddev)


@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
Contributor review comment (severity: medium):

The `RWKV7Backbone` class is missing a docstring. Please add a Google-style docstring explaining the model's architecture, its parameters, and include a usage example, as specified in the style guide.[1]

Style Guide References

[1]: All public classes, methods, and functions must have Google-style docstrings, including a concise summary, comprehensive examples, and documentation for all parameters, return values, and exceptions.

"""The [RWKV-7](https://arxiv.org/abs/2503.14456) core architecture.

    This network implements a modern recurrent architecture based on
    linear attention with recurrent state updates, as described in the
    RWKV papers. It includes the token embedding lookup and a stack of
    RWKV-7 blocks.

The default constructor gives a fully customizable, randomly initialized
RWKV-7 model with any number of layers, heads, and embedding dimensions.
To load preset architectures and weights, use the `from_preset`
constructor.

Args:
        hidden_size: int. The dimensionality of the token embeddings and
            hidden states.
        head_size: int. The size of each attention head.
        num_layers: int. The number of RWKV-7 blocks.
        vocabulary_size: int. The size of the token vocabulary.
        intermediate_dim: int. The output dimension of the first Dense
            layer in the two-layer feedforward network of each RWKV-7
            block.
gate_lora: int. LoRA dimension for gating.
mv_lora: int. LoRA dimension for value mixing.
aaa_lora: int. LoRA dimension for alpha parameters.
decay_lora: int. LoRA dimension for decay parameters.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.
dropout_rate: float. Dropout rate for the dropout layer.

Examples:

```python
input_data = np.ones(shape=(1, 12), dtype="int32")


# Randomly initialized RWKV-7 decoder with custom config.
model = keras_hub.models.RWKV7Backbone(
vocabulary_size=10,
hidden_size=512,
num_layers=2,
head_size=64,
intermediate_dim=1024,
dtype="float32"
)
model(input_data)
```
"""

def __init__(
self,
hidden_size,
head_size,
num_layers,
vocabulary_size,
intermediate_dim,
gate_lora=128,
mv_lora=32,
aaa_lora=64,
decay_lora=64,
dtype=None,
dropout_rate=0,
**kwargs,
):
"""Initialize RWKV7 backbone.

Args:
hidden_size: Hidden dimension size.
head_size: Attention head size.
num_layers: Number of RWKV blocks.
vocabulary_size: Size of vocabulary.
intermediate_dim: Intermediate dimension for FFN.
gate_lora: LoRA dimension for gating.
mv_lora: LoRA dimension for value mixing.
aaa_lora: LoRA dimension for alpha parameters.
decay_lora: LoRA dimension for decay parameters.
dtype: Data type for the layer.
dropout_rate: Dropout rate for regularization.
**kwargs: Additional arguments.
"""
# === Layers ===
self.token_embedding = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_size,
embeddings_initializer=rwkv7_kernel_initializer(),
dtype=dtype,
name="token_embedding",
)
self.token_embedding.build([None, None])

self.output_layer_norm = keras.layers.LayerNormalization(
epsilon=1e-5, name="output_norm"
)
self.output_layer_norm.build([None, None, hidden_size])
self.dropout = keras.layers.Dropout(
dropout_rate,
dtype=dtype,
name="dropout",
)
self.rwkv_layers = []
for i in range(num_layers):
layer = RWKV7_Block(
hidden_size,
head_size,
intermediate_dim,
gate_lora,
mv_lora,
aaa_lora,
decay_lora,
use_initial_norm=i == 0,
kernel_initializer=rwkv7_kernel_initializer(),
dtype=dtype,
name=f"rwkv_layer_{i}",
)

self.rwkv_layers.append(layer)
self.head = keras.layers.Dense(
units=vocabulary_size,
kernel_initializer=rwkv7_kernel_initializer(),
use_bias=False,
name="head",
)
# === Functional Model ===
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)

padding_mask = ops.not_equal(token_id_input, 0)

x = self.token_embedding(token_id_input)
padding_mask = ops.cast(padding_mask, dtype=x.dtype)
v_first = None
for rwkv_layer in self.rwkv_layers:
x, v_first = rwkv_layer(x, v_first, padding_mask)
x = self.dropout(x)
sequence_output = self.output_layer_norm(x)
sequence_output = self.head(sequence_output)
super().__init__(
inputs=token_id_input,
outputs=sequence_output,
dtype=dtype,
**kwargs,
)
Comment on lines +151 to +156

Contributor review comment (severity: medium):

The backbone's `__init__` method only accepts a single `token_ids` tensor as input. For consistency with other models in keras_hub and to improve interoperability, the backbone should be modified to accept a dictionary of inputs, including `token_ids` and `padding_mask`.[1] The `padding_mask` is currently computed inside the backbone, but it is better practice to have it as an explicit input.

Style Guide References

[1]: Use standardized names for model input arguments to ensure interoperability. For text models, this includes `token_ids` and `padding_mask`. The backbone should accept a dictionary of these inputs.
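A minimal sketch of the suggested pattern, following other keras_hub backbones (variable names here are illustrative, not the PR's final code):

```python
import keras

# Declare both tensors as explicit functional inputs instead of deriving
# `padding_mask` from `token_ids` inside the backbone.
token_id_input = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask_input = keras.Input(shape=(None,), dtype="int32", name="padding_mask")

# ... embedding, RWKV-7 blocks, layer norm, and head as in the PR ...

# The functional model is then constructed from a dictionary of named
# inputs, so callers pass `{"token_ids": ..., "padding_mask": ...}`:
# super().__init__(
#     inputs={
#         "token_ids": token_id_input,
#         "padding_mask": padding_mask_input,
#     },
#     outputs=sequence_output,
#     **kwargs,
# )
```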

        # Run one forward pass up front so the whole graph is built and
        # lazy-build errors are avoided.
self.call(ops.ones([1, 16], "int32"))

self.num_layers = num_layers
self.head_size = head_size
self.hidden_size = hidden_size
self.gate_lora = gate_lora
self.mv_lora = mv_lora
self.aaa_lora = aaa_lora
self.decay_lora = decay_lora
self.vocabulary_size = vocabulary_size
self.dropout_rate = dropout_rate
self.intermediate_dim = intermediate_dim

def get_config(self):
config = {
"hidden_size": self.hidden_size,
"head_size": self.head_size,
"gate_lora": self.gate_lora,
"mv_lora": self.mv_lora,
"aaa_lora": self.aaa_lora,
"decay_lora": self.decay_lora,
"vocabulary_size": self.vocabulary_size,
"dropout_rate": self.dropout_rate,
"intermediate_dim": self.intermediate_dim,
"num_layers": self.num_layers,
}
base_config = super().get_config()
        return {**base_config, **config}
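Because `get_config` captures every constructor argument, a serialization round trip should rebuild an equivalent model. A quick sketch using the same hyperparameters as the test below:

```python
from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone

backbone = RWKV7Backbone(
    vocabulary_size=10,
    hidden_size=16,
    num_layers=2,
    head_size=4,
    intermediate_dim=32,
    gate_lora=32,
    mv_lora=16,
    aaa_lora=16,
    decay_lora=16,
)
# Rebuild from config and confirm the restored model matches in size.
restored = RWKV7Backbone.from_config(backbone.get_config())
assert restored.count_params() == backbone.count_params()
```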
37 changes: 37 additions & 0 deletions keras_hub/src/models/rwkv7/rwkv7_backbone_test.py
@@ -0,0 +1,37 @@
from keras import ops

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.tests.test_case import TestCase


class RWKV7BackboneTest(TestCase):
def setUp(self):
"""
Set up the test case with default arguments and input data.
"""
self.init_kwargs = {
"vocabulary_size": 10,
"hidden_size": 16,
"num_layers": 2,
"head_size": 4,
"intermediate_dim": 32,
"gate_lora": 32,
"mv_lora": 16,
"aaa_lora": 16,
"decay_lora": 16,
}
self.input_data = ops.ones((2, 5), dtype="int32")
self.backbone = RWKV7Backbone(**self.init_kwargs)

def test_backbone_basics(self):
"""
Test basic functionality of the RWKV7 backbone.
"""
y = self.backbone(self.input_data)
self.assertEqual(y.shape, (2, 5, 10))

def test_num_parameters(self):
"""
Test that the model has the expected number of parameters.
"""
self.assertEqual(self.backbone.count_params(), 10208)
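To run just this file locally, one option is to invoke pytest programmatically (a sketch, assuming pytest is installed in the dev environment):

```python
import pytest

# Run only the RWKV-7 backbone tests, with verbose output.
pytest.main(["-v", "keras_hub/src/models/rwkv7/rwkv7_backbone_test.py"])
```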