diff --git a/keras_hub/src/models/mistral/mistral_attention.py b/keras_hub/src/models/mistral/mistral_attention.py
index 6916133b78..02322663bd 100644
--- a/keras_hub/src/models/mistral/mistral_attention.py
+++ b/keras_hub/src/models/mistral/mistral_attention.py
@@ -26,6 +26,7 @@ def __init__(
         rope_scaling_factor=1.0,
         kernel_initializer="glorot_uniform",
         sliding_window=512,
+        mistral_type="Mistral",
         dropout=0,
         **kwargs,
     ):
@@ -34,7 +35,7 @@ def __init__(
         self._num_key_value_heads = num_key_value_heads
         self._sliding_window = sliding_window
         self._dropout = dropout
-
+        self._type = mistral_type
         self._num_key_value_groups = num_query_heads // num_key_value_heads
 
         self._rope_max_wavelength = rope_max_wavelength
@@ -57,13 +58,15 @@ def build(self, inputs_shape):
         self._head_dim = self._hidden_dim // self._num_query_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
 
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
             output_shape=(None, self._num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
         )
+        if self._type == "devstral":
+            inputs_shape = (None, None, 4096)
         self._query_dense.build(inputs_shape)
 
         self._key_dense = keras.layers.EinsumDense(
@@ -110,6 +113,8 @@ def build(self, inputs_shape):
             dtype=self.dtype_policy,
             name="attention_output",
         )
+        if self._type == "devstral":
+            self._head_dim = 128
         self._output_dense.build(
             (None, None, self._num_query_heads, self._head_dim)
         )
diff --git a/keras_hub/src/models/mistral/mistral_backbone.py b/keras_hub/src/models/mistral/mistral_backbone.py
index 09a5d38129..76319729e4 100644
--- a/keras_hub/src/models/mistral/mistral_backbone.py
+++ b/keras_hub/src/models/mistral/mistral_backbone.py
@@ -101,6 +101,7 @@ def __init__(
         layer_norm_epsilon=1e-6,
         sliding_window=512,
         dropout=0,
+        mistral_type="Mistral",
         dtype=None,
         **kwargs,
     ):
@@ -127,6 +128,7 @@ def __init__(
                 sliding_window=sliding_window,
                 dropout=dropout,
                 dtype=dtype,
+                mistral_type=mistral_type,
                 name=f"transformer_layer_{i}",
             )
             self.transformer_layers.append(layer)
diff --git a/keras_hub/src/models/mistral/mistral_transformer_decoder.py b/keras_hub/src/models/mistral/mistral_transformer_decoder.py
index 79d5e93f7a..ba17dc5352 100644
--- a/keras_hub/src/models/mistral/mistral_transformer_decoder.py
+++ b/keras_hub/src/models/mistral/mistral_transformer_decoder.py
@@ -31,6 +31,7 @@ def __init__(
         kernel_initializer="glorot_uniform",
         sliding_window=512,
         dropout=0,
+        mistral_type="Mistral",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -40,7 +41,7 @@ def __init__(
 
         self.rope_max_wavelength = rope_max_wavelength
         self.rope_scaling_factor = rope_scaling_factor
-
+        self.mistral_type = mistral_type
         self.dropout = dropout
 
         self.sliding_window = sliding_window
@@ -64,6 +65,7 @@ def build(self, decoder_sequence_shape):
             kernel_initializer=clone_initializer(self.kernel_initializer),
             dropout=self.dropout,
             dtype=self.dtype_policy,
+            mistral_type=self.mistral_type,
             name="self_attention",
         )
         self._self_attention_layer.build(decoder_sequence_shape)
diff --git a/keras_hub/src/utils/transformers/convert_mistral.py b/keras_hub/src/utils/transformers/convert_mistral.py
index 9c52a708ef..7f7f421de5 100644
--- a/keras_hub/src/utils/transformers/convert_mistral.py
+++ b/keras_hub/src/utils/transformers/convert_mistral.py
@@ -1,3 +1,4 @@
+import re
 import numpy as np
 
 from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone
@@ -50,7 +51,6 @@ def convert_weights(backbone, loader, transformers_config):
             hf_weight_key=f"model.layers.{index}.post_attention_layernorm.weight",
             hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
         )
-
         # Attention layers
         loader.port_weight(
             keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
@@ -59,6 +59,7 @@ def convert_weights(backbone, loader, transformers_config):
                 np.transpose(hf_tensor.astype(np.float16)), keras_shape
             ),
         )
+
         loader.port_weight(
             keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
             hf_weight_key=f"model.layers.{index}.self_attn.k_proj.weight",
@@ -112,5 +113,8 @@ def convert_weights(backbone, loader, transformers_config):
     )
 
 
 def convert_tokenizer(cls, preset, **kwargs):
-    return cls(get_file(preset, "tokenizer.model"), **kwargs)
+    tokenizer_name = "tokenizer.model"
+    if re.search(r"devstral", preset, re.I):
+        tokenizer_name = "tekken.json"
+    return cls(get_file(preset, tokenizer_name), **kwargs)
diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py
index 4accea67a1..5365c37bd7 100644
--- a/keras_hub/src/utils/transformers/preset_loader.py
+++ b/keras_hub/src/utils/transformers/preset_loader.py
@@ -20,7 +20,8 @@
 from keras_hub.src.utils.transformers import convert_qwen_moe
 from keras_hub.src.utils.transformers import convert_vit
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
+import re
 
 
 class TransformersPresetLoader(PresetLoader):
     def __init__(self, preset, config):
@@ -70,7 +71,12 @@ def check_backbone_class(self):
 
     def load_backbone(self, cls, load_weights, **kwargs):
         keras_config = self.converter.convert_backbone_config(self.config)
+
+        if re.search(r"devstral", self.preset, re.I):
+            keras_config["mistral_type"] = "devstral"
+
         backbone = cls(**{**keras_config, **kwargs})
+
         if load_weights:
             jax_memory_cleanup(backbone)
             with SafetensorLoader(self.preset) as loader: