11 changes: 9 additions & 2 deletions keras_hub/src/models/mistral/mistral_attention.py
@@ -26,6 +26,7 @@ def __init__(
rope_scaling_factor=1.0,
kernel_initializer="glorot_uniform",
sliding_window=512,
mistral_type="Mistral",
dropout=0,
**kwargs,
):
@@ -34,7 +35,7 @@ def __init__(
self._num_key_value_heads = num_key_value_heads
self._sliding_window = sliding_window
self._dropout = dropout

self._type = mistral_type
self._num_key_value_groups = num_query_heads // num_key_value_heads
self._rope_max_wavelength = rope_max_wavelength

@@ -57,13 +58,17 @@ def build(self, inputs_shape):
self._head_dim = self._hidden_dim // self._num_query_heads
self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

self._query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self._num_query_heads, self._head_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="query",
name=
"query",
)
# Devstral presets build the query projection against a fixed 4096-wide input.
if self._type == "devstral":
    inputs_shape = (None, None, 4096)
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.EinsumDense(
@@ -110,6 +115,8 @@ def build(self, inputs_shape):
dtype=self.dtype_policy,
name="attention_output",
)
# Devstral presets fix the output projection's head dimension at 128.
if self._type == "devstral":
    self._head_dim = 128
self._output_dense.build(
(None, None, self._num_query_heads, self._head_dim)
)
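Reviewer note: a minimal sketch of how the new flag is expected to exercise these branches, assuming the layer in this file keeps its CachedMistralAttention name and that a Devstral-sized configuration pairs 32 query heads with a 4096-wide attention input (so head_dim resolves to 128); illustrative only, not part of the diff.

from keras_hub.src.models.mistral.mistral_attention import CachedMistralAttention

# With mistral_type="devstral", build() takes the new code paths above: the
# query projection is built against (None, None, 4096) and the output
# projection against head_dim=128.
attention = CachedMistralAttention(
    num_query_heads=32,
    num_key_value_heads=8,
    mistral_type="devstral",
)
attention.build((None, None, 4096))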
2 changes: 2 additions & 0 deletions keras_hub/src/models/mistral/mistral_backbone.py
@@ -101,6 +101,7 @@ def __init__(
layer_norm_epsilon=1e-6,
sliding_window=512,
dropout=0,
mistral_type="Mistral",
dtype=None,
**kwargs,
):
@@ -127,6 +128,7 @@ def __init__(
sliding_window=sliding_window,
dropout=dropout,
dtype=dtype,
mistral_type=mistral_type,
name=f"transformer_layer_{i}",
)
self.transformer_layers.append(layer)
4 changes: 3 additions & 1 deletion keras_hub/src/models/mistral/mistral_transformer_decoder.py
@@ -31,6 +31,7 @@ def __init__(
kernel_initializer="glorot_uniform",
sliding_window=512,
dropout=0,
mistral_type="Mistral",
**kwargs,
):
super().__init__(**kwargs)
@@ -40,7 +41,7 @@ def __init__(

self.rope_max_wavelength = rope_max_wavelength
self.rope_scaling_factor = rope_scaling_factor

self.mistral_type = mistral_type
self.dropout = dropout

self.sliding_window = sliding_window
@@ -64,6 +65,7 @@ def build(self, decoder_sequence_shape):
kernel_initializer=clone_initializer(self.kernel_initializer),
dropout=self.dropout,
dtype=self.dtype_policy,
mistral_type=self.mistral_type,
name="self_attention",
)
self._self_attention_layer.build(decoder_sequence_shape)
9 changes: 7 additions & 2 deletions keras_hub/src/utils/transformers/convert_mistral.py
@@ -1,3 +1,4 @@
import re
import numpy as np

from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone
@@ -50,7 +51,6 @@ def convert_weights(backbone, loader, transformers_config):
hf_weight_key=f"model.layers.{index}.post_attention_layernorm.weight",
hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
)

# Attention layers
loader.port_weight(
keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
@@ -59,6 +59,7 @@ def convert_weights(backbone, loader, transformers_config):
np.transpose(hf_tensor.astype(np.float16)), keras_shape
),
)

loader.port_weight(
keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
hf_weight_key=f"model.layers.{index}.self_attn.k_proj.weight",
@@ -112,5 +113,9 @@ def convert_weights(backbone, loader, transformers_config):
)


def convert_tokenizer(cls, preset, **kwargs):
    tokenizer_name = "tokenizer.model"
    if re.search(r"devstral", preset, re.I):
        tokenizer_name = "tekken.json"
    return cls(get_file(preset, tokenizer_name), **kwargs)
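Reviewer note: a quick illustration of the preset-name matching above; the handles below are made-up examples, not shipped presets.

import re

for preset in ["hf://mistralai/Devstral-Small-2505", "mistral_7b_en"]:
    # Any handle containing "devstral" (case-insensitive) selects tekken.json.
    name = "tekken.json" if re.search(r"devstral", preset, re.I) else "tokenizer.model"
    print(preset, "->", name)
# hf://mistralai/Devstral-Small-2505 -> tekken.json
# mistral_7b_en -> tokenizer.model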
7 changes: 6 additions & 1 deletion keras_hub/src/utils/transformers/preset_loader.py
@@ -20,7 +20,7 @@
from keras_hub.src.utils.transformers import convert_qwen_moe
from keras_hub.src.utils.transformers import convert_vit
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader

import re

class TransformersPresetLoader(PresetLoader):
def __init__(self, preset, config):
@@ -70,7 +70,12 @@ def check_backbone_class(self):

def load_backbone(self, cls, load_weights, **kwargs):
keras_config = self.converter.convert_backbone_config(self.config)

# Preset handles that mention "devstral" take the Devstral-specific Mistral path.
if re.search(r"devstral", self.preset, re.I):
    keras_config["mistral_type"] = "devstral"

backbone = cls(**{**keras_config, **kwargs})

if load_weights:
jax_memory_cleanup(backbone)
with SafetensorLoader(self.preset) as loader:
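Reviewer note: at the user level, any Hugging Face handle containing "devstral" (case-insensitive) is routed through the Mistral converter with mistral_type="devstral"; a hedged usage sketch with a placeholder repo id, not a verified preset.

import keras_hub

# The "devstral" substring in the handle triggers the new branch in
# load_backbone(), which injects mistral_type="devstral" into the backbone config.
backbone = keras_hub.models.MistralBackbone.from_preset(
    "hf://mistralai/Devstral-Small-2505"  # placeholder handle
)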