-
Notifications
You must be signed in to change notification settings - Fork 301
Added Voxtral backbone and it's test #2394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sanskarmodi8
wants to merge
3
commits into
keras-team:master
Choose a base branch
from
sanskarmodi8:adding_voxtral/pr#1
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
import tensorflow as tf | ||
from keras import initializers | ||
from keras import layers | ||
from keras import mixed_precision | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder | ||
from keras_hub.src.models.backbone import Backbone | ||
|
||
|
||
def voxtral_kernel_initializer(stddev=0.02): | ||
"""Initializer for VoxTral layers (TruncatedNormal).""" | ||
return initializers.TruncatedNormal(stddev=stddev) | ||
|
||
|
||
class ChunkAndPad(layers.Layer): | ||
sanskarmodi8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Pads and splits spectrogram into fixed-length chunks.""" | ||
|
||
def __init__(self, frames_per_chunk, **kwargs): | ||
super().__init__(**kwargs) | ||
self.frames_per_chunk = int(frames_per_chunk) | ||
|
||
def call(self, x): | ||
B, T = tf.shape(x)[0], tf.shape(x)[1] | ||
pad_len = (-T) % self.frames_per_chunk | ||
x = tf.pad(x, [[0, 0], [0, pad_len], [0, 0]]) | ||
n_chunks = tf.math.floordiv(T + pad_len, self.frames_per_chunk) | ||
return tf.reshape( | ||
x, [B * n_chunks, self.frames_per_chunk, tf.shape(x)[2]] | ||
) | ||
sanskarmodi8 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class PositionalEmbedding(layers.Layer): | ||
"""Learnable positional embedding per chunk.""" | ||
|
||
def __init__(self, length, dim, **kwargs): | ||
super().__init__(**kwargs) | ||
self.length = int(length) | ||
self.dim = int(dim) | ||
|
||
def build(self, input_shape): | ||
self.pos_emb = self.add_weight( | ||
name="pos_emb", | ||
shape=(self.length, self.dim), | ||
initializer=initializers.RandomNormal(stddev=0.02), | ||
trainable=True, | ||
) | ||
super().build(input_shape) | ||
|
||
def call(self, x): | ||
return x + self.pos_emb[None, :, :] | ||
|
||
|
||
class ReassembleChunks(layers.Layer): | ||
"""Reassembles chunked outputs back into (B, T, H).""" | ||
|
||
def __init__(self, frames_per_chunk, postproc_chunk_len=None, **kwargs): | ||
super().__init__(**kwargs) | ||
self.frames_per_chunk = int(frames_per_chunk) | ||
self.postproc_chunk_len = postproc_chunk_len | ||
|
||
def call(self, processed_chunks, orig_spectrogram): | ||
B, T = tf.shape(orig_spectrogram)[0], tf.shape(orig_spectrogram)[1] | ||
n_chunks = tf.cast( | ||
tf.math.floordiv( | ||
T + self.frames_per_chunk - 1, self.frames_per_chunk | ||
), | ||
tf.int32, | ||
) | ||
T_chunk, H = ( | ||
tf.shape(processed_chunks)[1], | ||
tf.shape(processed_chunks)[2], | ||
) | ||
return tf.reshape(processed_chunks, [B, n_chunks * T_chunk, H]) | ||
sanskarmodi8 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
|
||
@keras_hub_export("keras_hub.models.VoxTralBackbone") | ||
class VoxTralBackbone(Backbone): | ||
"""VoxTral audio encoder + adapter backbone.""" | ||
sanskarmodi8 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
def __init__( | ||
self, | ||
num_layers=32, | ||
num_heads=20, | ||
hidden_dim=1280, | ||
intermediate_dim=5120, | ||
adapter_downsample=4, | ||
dropout=0.1, | ||
max_chunk_seconds=30, | ||
sr=16000, | ||
hop_length=160, | ||
dtype="float32", | ||
**kwargs, | ||
): | ||
self.num_layers = int(num_layers) | ||
self.num_heads = int(num_heads) | ||
self.hidden_dim = int(hidden_dim) | ||
self.intermediate_dim = int(intermediate_dim) | ||
self.adapter_downsample = int(adapter_downsample) | ||
self.dropout = float(dropout) | ||
self.max_chunk_seconds = int(max_chunk_seconds) | ||
self.sr = int(sr) | ||
self.hop_length = int(hop_length) | ||
|
||
# Frames per chunk before conv | ||
self.frames_per_chunk_preconv = int( | ||
self.max_chunk_seconds * (self.sr / self.hop_length) | ||
) | ||
self.postconv_frames_per_chunk = self.frames_per_chunk_preconv // 2 | ||
|
||
# Determine layer dtype for mixed precision | ||
if isinstance(dtype, mixed_precision.Policy): | ||
self.layer_dtype = dtype.compute_dtype | ||
else: | ||
self.layer_dtype = dtype | ||
|
||
# Conv1D stem | ||
self.conv_stem_1 = layers.Conv1D( | ||
filters=self.hidden_dim, | ||
kernel_size=3, | ||
strides=2, | ||
padding="same", | ||
activation="relu", | ||
kernel_initializer=voxtral_kernel_initializer(), | ||
dtype=self.layer_dtype, | ||
name="conv_stem_1", | ||
) | ||
self.conv_stem_2 = layers.Conv1D( | ||
filters=self.hidden_dim, | ||
kernel_size=3, | ||
strides=1, | ||
padding="same", | ||
activation="relu", | ||
kernel_initializer=voxtral_kernel_initializer(), | ||
dtype=self.layer_dtype, | ||
name="conv_stem_2", | ||
) | ||
|
||
# Transformer layers | ||
self.transformer_layers = [ | ||
TransformerEncoder( | ||
num_heads=self.num_heads, | ||
intermediate_dim=self.intermediate_dim, | ||
dropout=self.dropout, | ||
name=f"transformer_layer_{i}", | ||
) | ||
for i in range(self.num_layers) | ||
] | ||
|
||
# Adapter | ||
self.adapter_dense = layers.Dense( | ||
self.hidden_dim, | ||
activation="relu", | ||
kernel_initializer=voxtral_kernel_initializer(), | ||
dtype=self.layer_dtype, | ||
name="adapter_dense", | ||
) | ||
self.adapter_pool = layers.AveragePooling1D( | ||
pool_size=self.adapter_downsample, | ||
strides=self.adapter_downsample, | ||
padding="valid", | ||
name="adapter_downsample", | ||
) | ||
|
||
# Positional embeddings | ||
self.pos_emb = PositionalEmbedding( | ||
self.postconv_frames_per_chunk, self.hidden_dim, name="pos_emb" | ||
) | ||
|
||
# Functional model | ||
spectrogram_input = tf.keras.Input( | ||
shape=(None, 128), dtype=self.layer_dtype, name="spectrogram" | ||
) | ||
sanskarmodi8 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
x = ChunkAndPad(self.frames_per_chunk_preconv, name="chunk_and_pad")( | ||
spectrogram_input | ||
) | ||
x = self.conv_stem_1(x) | ||
x = self.conv_stem_2(x) | ||
x = self.pos_emb(x) | ||
for transformer_layer in self.transformer_layers: | ||
x = transformer_layer(x) | ||
x = self.adapter_dense(x) | ||
x = self.adapter_pool(x) | ||
outputs = ReassembleChunks( | ||
self.frames_per_chunk_preconv, name="reassemble_chunks" | ||
)(x, spectrogram_input) | ||
|
||
super().__init__( | ||
inputs=spectrogram_input, outputs=outputs, dtype=dtype, **kwargs | ||
) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"num_layers": self.num_layers, | ||
"num_heads": self.num_heads, | ||
"hidden_dim": self.hidden_dim, | ||
"intermediate_dim": self.intermediate_dim, | ||
"adapter_downsample": self.adapter_downsample, | ||
"dropout": self.dropout, | ||
"max_chunk_seconds": self.max_chunk_seconds, | ||
"sr": self.sr, | ||
"hop_length": self.hop_length, | ||
} | ||
) | ||
return config |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pytest | ||
from keras import mixed_precision | ||
from keras import ops | ||
|
||
from keras_hub.src.models.voxtral.voxtral_backbone import VoxTralBackbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class VoxTralBackboneTest(TestCase): | ||
"""Unit tests for VoxTralBackbone.""" | ||
|
||
def setUp(self): | ||
"""Initialize default backbone arguments and input data.""" | ||
self.init_kwargs = { | ||
"num_layers": 2, | ||
"num_heads": 2, | ||
"hidden_dim": 16, | ||
"intermediate_dim": 32, | ||
"adapter_downsample": 2, | ||
"dropout": 0.0, | ||
"max_chunk_seconds": 1, | ||
"sr": 16000, | ||
"hop_length": 160, | ||
"dtype": "float32", | ||
} | ||
# Dummy input: shape (batch, time, features) | ||
self.input_data = ops.ones((1, 2542, 128), dtype="float32") | ||
|
||
def test_backbone_basics(self): | ||
"""Test forward pass and output shape with float32.""" | ||
mixed_precision.set_global_policy("float32") | ||
model = VoxTralBackbone(**self.init_kwargs) | ||
output = model(self.input_data) | ||
assert tuple(output.shape) == (1, 650, 16) | ||
assert output.dtype.name == "float32" | ||
sanskarmodi8 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
"""Test saving and loading the model.""" | ||
self.run_model_saving_test( | ||
cls=VoxTralBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) | ||
sanskarmodi8 marked this conversation as resolved.
Show resolved
Hide resolved
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.