From 38619fd8591fe27a06b19d2c5af63eed47636e68 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 20 Nov 2025 17:15:28 -0500 Subject: [PATCH 1/3] Llava conversion --- fast_llm/layers/vision/config.py | 3 +- fast_llm/models/gpt/conversion/llama.py | 4 +- fast_llm/models/multimodal/config.py | 7 +- .../models/multimodal/conversion/__init__.py | 0 fast_llm/models/multimodal/conversion/auto.py | 17 ++ .../models/multimodal/conversion/config.py | 25 ++ .../models/multimodal/conversion/llava.py | 279 ++++++++++++++++++ .../multimodal/conversion/llava_hybrid.py | 46 +++ .../configuration_llava_hybrid.py | 119 ++++++++ .../llava_hybrid/modeling_llava_hybrid.py | 132 +++++++++ tests/utils/model_configs.py | 9 +- 11 files changed, 633 insertions(+), 8 deletions(-) create mode 100644 fast_llm/models/multimodal/conversion/__init__.py create mode 100644 fast_llm/models/multimodal/conversion/auto.py create mode 100644 fast_llm/models/multimodal/conversion/config.py create mode 100644 fast_llm/models/multimodal/conversion/llava.py create mode 100644 fast_llm/models/multimodal/conversion/llava_hybrid.py create mode 100644 fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py create mode 100644 fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index fb05a520c..13e23829e 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -85,7 +85,7 @@ class PatchConvolutionConfig(BlockConfig): ) @functools.cached_property - def input_channels(self): + def input_channels(self) -> int: # Number of input channels. Currently hard-coded to 3 (RGB). return 3 @@ -99,6 +99,7 @@ def layer_class(self) -> "type[PatchConvolution]": @config_class(registry=True) class VisionEncoderConfig(BlockConfig): _abstract = False + # TODO: ====== Rename to patch_embeddings? ====== patch_convolution: PatchConvolutionConfig = Field( desc="Configuration for the patch convolution layer.", hint=FieldHint.architecture, diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index a92492260..530f3359e 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -10,7 +10,7 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig @@ -498,7 +498,7 @@ def get_converters( ] -class LlamaBaseModelConverter: +class LlamaBaseModelConverter(HuggingFaceBaseModelConverter): # TODO: Peft? 
decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py index 7bce78853..e07f596ad 100644 --- a/fast_llm/models/multimodal/config.py +++ b/fast_llm/models/multimodal/config.py @@ -14,6 +14,7 @@ GPTTrainerConfig, PretrainedGPTModelConfig, ) +from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat, LlavaHybridSSMCheckpointFormat if typing.TYPE_CHECKING: from fast_llm.models.multimodal.model import MultiModalBaseModel, MultiModalModel @@ -41,8 +42,10 @@ class MultiModalModelConfig(GPTModelConfig): _abstract = False model_name: typing.ClassVar[str] = "multimodal" base_model: MultiModalBaseModelConfig = FieldUpdate() - # TODO: ====== Conversion ====== - checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( + LlavaCheckpointFormat, + LlavaHybridSSMCheckpointFormat, + ) @classmethod def get_model_class(cls) -> type["MultiModalModel"]: diff --git a/fast_llm/models/multimodal/conversion/__init__.py b/fast_llm/models/multimodal/conversion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/models/multimodal/conversion/auto.py b/fast_llm/models/multimodal/conversion/auto.py new file mode 100644 index 000000000..3660ef5f5 --- /dev/null +++ b/fast_llm/models/multimodal/conversion/auto.py @@ -0,0 +1,17 @@ +import abc + +from fast_llm.engine.checkpoint.external import AutoStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat, LlavaHybridSSMCheckpointFormat +from fast_llm.models.multimodal.conversion.llava import LlavaHuggingfaceCheckpointHandler +from fast_llm.models.multimodal.conversion.llava_hybrid import LlavaHybridSSMHuggingfaceCheckpointHandler + + +class AutoMultimodalHuggingfaceCheckpointHandler( + AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC +): + + handler_map = { + LlavaCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + LlavaHybridSSMCheckpointFormat.name: LlavaHybridSSMHuggingfaceCheckpointHandler, + } diff --git a/fast_llm/models/multimodal/conversion/config.py b/fast_llm/models/multimodal/conversion/config.py new file mode 100644 index 000000000..b8663e113 --- /dev/null +++ b/fast_llm/models/multimodal/conversion/config.py @@ -0,0 +1,25 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler + + +class MultimodalHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.multimodal.conversion.auto import AutoMultimodalHuggingfaceCheckpointHandler + + return AutoMultimodalHuggingfaceCheckpointHandler.get_handler_class(cls.name) + + +class AutoMultimodalHuggingfaceCheckpointFormat(MultimodalHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "auto" + + +class LlavaCheckpointFormat(MultimodalHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llava" + + +class LlavaHybridSSMCheckpointFormat(MultimodalHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = 
"llava_hybrid_ssm" diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py new file mode 100644 index 000000000..549f367e6 --- /dev/null +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -0,0 +1,279 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.rotary.config import Rotary2DConfig +from fast_llm.layers.common.normalization.config import RMSNormalizationConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.vision.config import PatchConvolutionConfig, VisionEncoderConfig +from fast_llm.models.gpt.conversion.llama import ( + LlamaAttentionConverter, + LlamaBlockConverter, + LlamaDecoderConverter, + LlamaMLPConverter, + LlamaNormalizationConverter, + MLPLayer2Converter, + get_weight_and_bias_converters, +) +from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter +from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig +from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat +from fast_llm.models.multimodal.model import MultiModalModel +from fast_llm.utils import Assert, div, safe_merge_dicts + + +class PixtralNormalizationConverter(LlamaNormalizationConverter): + """ + epsilon hard-coded to 1e-5. + """ + + @classmethod + def import_config(cls, config: dict) -> dict: + return {"type": "rms_norm", "epsilon": 1e-5} + + @classmethod + def export_config(cls, config: RMSNormalizationConfig) -> dict: + Assert.custom(isinstance, config, RMSNormalizationConfig) + assert not config.zero_centered + # TODO: Too strict? + Assert.eq(config.epsilon, 1e-5) + return {} + + +# TODO: ====== MistralAttentionConverter (#391 / #382) ====== +class PixtralAttentionConverter(LlamaAttentionConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["num_key_value_heads"] = config["num_attention_heads"] + config["attention_bias"] = False + out = super().import_config(config) + out["rotary"]["type"] = "default_2d" + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + cls._check_config(config) + Assert.eq(config.softmax_scale_power, 0.5) + Assert.is_(type(config.rotary), Rotary2DConfig) + assert not config.add_linear_biases + Assert.eq(config.head_groups, config.heads) + return { + "num_attention_heads": config.heads, + "attention_dropout": config.dropout, + "rope_theta": config.rotary.theta, + # Not in PixtralConfig, but needed for consistency check in LlavaVisionModelConverter. 
+ "head_dim": config.head_size, + } + + +class PixtralBlockConverter(LlamaBlockConverter): + mixer_converter_class: typing.ClassVar[type[PixtralAttentionConverter]] = PixtralAttentionConverter + # TODO: ====== MistralMLPConverter (#391 / #382) ====== + mlp_converter_class: typing.ClassVar[type[LlamaMLPConverter]] = LlamaMLPConverter + normalization_converter_class: typing.ClassVar[type[PixtralNormalizationConverter]] = PixtralNormalizationConverter + hf_mixer_name: typing.ClassVar[str] = "attention" + hf_mlp_name: typing.ClassVar[str] = "feed_forward" + hf_norm_1_name: typing.ClassVar[str] = "attention_norm" + hf_norm_2_name: typing.ClassVar[str] = "ffn_norm" + + +class PixtralEncoderConverter(LlamaDecoderConverter): + block_converter_class: typing.ClassVar[type[PixtralBlockConverter]] = PixtralBlockConverter + + +class PixtralPatchConvolutionConverter: + normalization_converter_class: typing.ClassVar[type[PixtralNormalizationConverter]] = PixtralNormalizationConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + Assert.eq(config["num_channels"], 3) + return { + "normalization": cls.normalization_converter_class.import_config(config), + "patch_height": config["patch_size"], + "patch_width": config["patch_size"], + } + + @classmethod + def export_config(cls, config: PatchConvolutionConfig) -> dict: + Assert.custom(isinstance, config, PatchConvolutionConfig) + Assert.eq(config.patch_height, config.patch_width) + Assert.incl(config.convolution.bias.enabled, (None, False)) + + return safe_merge_dicts( + { + "patch_size": config.patch_height, + "num_channels": config.input_channels, + }, + cls.normalization_converter_class.export_config(config.normalization), + ) + + @classmethod + def get_converters( + cls, config: PatchConvolutionConfig, fast_llm_prefix: str, hf_prefix: str + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.patch_conv", + False, + WeightConverter, + ), + *cls.normalization_converter_class.get_converters( + config, f"{fast_llm_prefix}.normalization", f"{hf_prefix}.ln_pre" + ), + ] + + +class LlavaVisionAdapterConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "intermediate_size": config["projector_intermediate_size"], + "add_linear_biases": config["multimodal_projector_bias"], + "gated": False, + "activation": ActivationType.from_hf_name(config["projector_hidden_act"]), + } + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + Assert.custom(isinstance, config, MLPConfig) + Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) + assert not config.gated + + return { + "projector_hidden_act": config.activation.hf_name, + "projector_intermediate_size": config.intermediate_size, + "multimodal_projector_bias": config.add_linear_biases, + } + + @classmethod + def get_converters(cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + f"{hf_prefix}.linear_1", + config.add_linear_biases, + WeightConverter, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.linear_2", + config.add_linear_biases, + MLPLayer2Converter, + ), + ] + + +class LlavaVisionModelConverter: + vision_adapter_converter_class: typing.ClassVar[type[LlavaVisionAdapterConverter]] = LlavaVisionAdapterConverter + 
patch_convolution_converter_class: typing.ClassVar[type[PixtralPatchConvolutionConverter]] = ( + PixtralPatchConvolutionConverter + ) + encoder_converter_class: typing.ClassVar[type[PixtralEncoderConverter]] = PixtralEncoderConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "patch_convolution": cls.patch_convolution_converter_class.import_config(config["vision_config"]), + "encoder": cls.encoder_converter_class.import_config(config["vision_config"]), + "adapter": cls.vision_adapter_converter_class.import_config(config), + "hidden_size": config["vision_config"]["hidden_size"], + } + + @classmethod + def export_config(cls, config: VisionEncoderConfig) -> dict: + Assert.custom(isinstance, config, VisionEncoderConfig) + # TODO: ====== image_size? ====== + vision_config = safe_merge_dicts( + cls.patch_convolution_converter_class.export_config(config.patch_convolution), + cls.encoder_converter_class.export_config(config.encoder), + {"hidden_size": config.hidden_size}, + ) + + Assert.eq( + vision_config.pop("head_dim"), div(vision_config["hidden_size"], vision_config["num_attention_heads"]) + ) + + return safe_merge_dicts( + {"vision_config": vision_config}, + cls.vision_adapter_converter_class.export_config(config.adapter), + # TODO: ====== What about these? ====== + # { + # "image_token_index":32000, + # "vision_feature_select_strategy":"default", + # "vision_feature_layer":-2, + # "image_seq_length":576, + # } + ) + + @classmethod + def get_converters( + cls, config: VisionEncoderConfig, fast_llm_prefix: str, hf_prefix: str + ) -> list[WeightConverter]: + return [ + *cls.patch_convolution_converter_class.get_converters( + config.patch_convolution, f"{fast_llm_prefix}.patch_convolution", hf_prefix + ), + *cls.encoder_converter_class.get_converters( + config.encoder, f"{fast_llm_prefix}.encoder", f"{hf_prefix}.transformer" + ), + *cls.vision_adapter_converter_class.get_converters( + config.adapter, f"{fast_llm_prefix}.adapter", f"{hf_prefix}.multi_modal_projector" + ), + ] + + +class LlavaBaseModelConverter(HuggingFaceBaseModelConverter): + vision_model_converter_class: typing.ClassVar[type[LlavaVisionModelConverter]] = LlavaVisionModelConverter + # TODO: Make it flexible? + language_model_converter_class: typing.ClassVar[type[MistralBaseModelConverter]] = MistralBaseModelConverter + # TODO: ====== Is tie_word_embeddings supported? 
====== + + @classmethod + def import_config(cls, config: dict) -> dict: + return safe_merge_dicts( + {"vision_encoder": cls.vision_model_converter_class.import_config(config)}, + cls.language_model_converter_class.import_config(config["text_config"]), + ) + + @classmethod + def export_config(cls, config: MultiModalBaseModelConfig) -> dict: + Assert.custom(isinstance, config, MultiModalBaseModelConfig) + return safe_merge_dicts( + cls.vision_model_converter_class.export_config(config.vision_encoder), + {"text_config": cls.language_model_converter_class.export_config(config)}, + ) + + @classmethod + def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + return [ + *cls.vision_model_converter_class.get_converters(config.vision_encoder, "vision_encoder", "model"), + *cls.language_model_converter_class.embeddings_converter_class.get_converters( + config.embeddings, "embeddings", "model.language_model" + ), + *cls.language_model_converter_class.decoder_converter_class.get_converters( + config.decoder, "decoder", "model.language_model.layers" + ), + *cls.language_model_converter_class.head_converter_class.get_converters( + config.head, {"tie_word_embeddings": False}, "head" + ), + ] + + +class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model: MultiModalModel + _model_class: typing.ClassVar[FastLLMModelConfig] = MultiModalModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = LlavaCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaForConditionalGeneration" + base_model_converter_class: typing.ClassVar[type[LlavaBaseModelConverter]] = LlavaBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.LlavaConfig diff --git a/fast_llm/models/multimodal/conversion/llava_hybrid.py b/fast_llm/models/multimodal/conversion/llava_hybrid.py new file mode 100644 index 000000000..da84455a6 --- /dev/null +++ b/fast_llm/models/multimodal/conversion/llava_hybrid.py @@ -0,0 +1,46 @@ +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.models.gpt.conversion.apriel import AprielBaseModelConverter +from fast_llm.models.multimodal.config import MultiModalModelConfig +from fast_llm.models.multimodal.conversion.config import LlavaHybridSSMCheckpointFormat +from fast_llm.models.multimodal.conversion.llava import LlavaBaseModelConverter, LlavaHuggingfaceCheckpointHandler +from fast_llm.utils import safe_merge_dicts + + +class LlavaHybridBaseModelConverter(LlavaBaseModelConverter): + language_model_converter_class: typing.ClassVar[type[AprielBaseModelConverter]] = AprielBaseModelConverter + + +class LlavaHybridSSMHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridSSMCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaHybridForConditionalGeneration" + base_model_converter_class: typing.ClassVar[type[LlavaHybridBaseModelConverter]] = LlavaHybridBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.llava_hybrid.configuration_llava_hybrid import LlavaHybridConfig + + return LlavaHybridConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid + + return 
configuration_llava_hybrid.__file__, modeling_llava_hybrid.__file__, None + + @classmethod + def _export_config(cls, config: MultiModalModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + }, + }, + ) diff --git a/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py b/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py new file mode 100644 index 000000000..9d1f014d8 --- /dev/null +++ b/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py @@ -0,0 +1,119 @@ +import transformers +from transformers import MistralConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +# Copied from configuration_ssm_hybrid_apriel15b.py +# TODO: split into mamba 2 and discrete mamba 2 configs with a base dict +ssm_config_default = { + # discrete mamba2 + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 32 * 128, + # mamba2 + "d_xb": None, # will be set to model dim + "dt_rank": "auto", + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init": "random", + "dt_scale": 1.0, + "dt_init_floor": 1e-4, + "conv_bias": True, +} +transformers.CLIPModel + + +class AprielSSMHybridConfig(MistralConfig): + model_type = "apriel_ssm_thinker_hybrid" + + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 + self.ssm_cfg = ssm_cfg or ssm_config_default + + for k, v in ssm_config_default.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v # to make sure all elements are present in the config + + +class LlavaHybridConfig(PretrainedConfig): + """ + Configuration class for Llava SSM-Hybrid-decoder model. + """ + + model_type = "llava_hybrid" + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + projector_intermediate_size=4096, + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + # projector_intermediate_size is an addition to the original Llava config + self.projector_intermediate_size = projector_intermediate_size + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." 
+ f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + # Load the custom SSM hybrid config if specified + if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": + text_config = AprielSSMHybridConfig(**text_config) + else: + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) + + +__all__ = ["LlavaHybridConfig"] diff --git a/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py b/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py new file mode 100644 index 000000000..243413a33 --- /dev/null +++ b/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py @@ -0,0 +1,132 @@ +from torch import nn +from transformers import AutoModel, LlavaForConditionalGeneration, LlavaModel +from transformers.activations import ACT2FN + +from .configuration_llava_hybrid import LlavaHybridConfig + +try: + # In the fast-llm repo, import from the SSM modeling file + from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import ( + AprielThinkerSSMHybridModel, + HybridMambaAttentionDynamicCache, + ) +except ImportError: + # In the exported checkpoint, import from local file + from .modeling_apriel_hybrid_ssm import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaHybridConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.projector_intermediate_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.projector_intermediate_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaHybridModel(LlavaModel): + """ + Llava SSM-Hybrid-decoder model. 
+ """ + + config_class = LlavaHybridConfig + + def __init__(self, config: LlavaHybridConfig): + super(LlavaModel, self).__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + assert ( + config.text_config.model_type == "apriel_ssm_thinker_hybrid" + ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" + + self.language_model = AprielThinkerSSMHybridModel(config.text_config) + self.post_init() + + +class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): + config_class = LlavaHybridConfig + + def __init__(self, config: LlavaHybridConfig): + super(LlavaForConditionalGeneration, self).__init__(config) + self.model = LlavaHybridModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Copy of the method from `AprielThinkerSSMHybridForCausalLM` + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+        # (we can't check exception 3 while compiling)
+        if not empty_past_kv:
+            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1  # Exception 3
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+        else:
+            past_key_values = HybridMambaAttentionDynamicCache(
+                self.config.text_config, input_ids.shape[0], self.dtype, device=self.device
+            )
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if not empty_past_kv:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and empty_past_kv:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
+
+        # Copy from `LlavaForConditionalGeneration.prepare_inputs_for_generation`
+        if cache_position[0] == 0:
+            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+            # Otherwise we need pixel values to be passed to model
+            model_inputs["pixel_values"] = pixel_values
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "output_router_logits": output_router_logits,
+                # "logits_to_keep": self.config.num_logits_to_keep,
+                "cache_position": cache_position,
+            }
+        )
+        return model_inputs
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 1ed99416e..143fa7bab 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -22,6 +22,7 @@
     MTPLlamaCheckpointFormat,
     Qwen2CheckpointFormat,
 )
+from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat
 from tests.utils.dataset import get_model_test_dataset, get_multimodal_test_dataset
 from tests.utils.distributed_configs import DistributedTestingConfig
 from tests.utils.global_variables import MODEL_TEST_VOCAB_SIZE
@@ -689,7 +690,7 @@ def _update_and_add_testing_config(
     model_type="multimodal",
     updates={
         ("model", "base_model", "vision_encoder"): {
-            "patch_convolution": {"patch_height": 4, "patch_width": 4},
+            "patch_convolution": {"patch_height": 4, "patch_width": 4, "normalization": {"type": "rms_norm"}},
             "encoder": copy.deepcopy(MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]),
             "adapter": {"intermediate_size": 512},
             "hidden_size": 256,
@@ -699,14 +700,16 @@ def _update_and_add_testing_config(
         ("model", "base_model", "vision_encoder", "encoder", "num_blocks"): 1,
         ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "causal"): False,
         ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "cross_document_attention"): False,
+        # Pixtral doesn't support GQA
+        ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "head_groups"): 8,
     },
     get_dataset=get_multimodal_test_dataset,
     megatron_args=None,
-    checkpoint_format=None,
+    checkpoint_format=LlavaCheckpointFormat,
     groups={
         ModelTestingGroup.basic: ModelTestingGroupAction.normal,
         ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
-        ModelTestingGroup.convert:
ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement From a0fac2d400a2d84176ea4b2ed691e3e228c0ec6d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 20 Nov 2025 17:30:05 -0500 Subject: [PATCH 2/3] fix --- fast_llm/models/gpt/conversion/mistral.py | 4 ++-- fast_llm/models/multimodal/conversion/llava.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index a9a0909ec..d4a669b22 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -40,7 +40,7 @@ def _check_config(cls, config: AttentionConfig) -> None: assert not config.add_linear_biases -class MistrallMLPConverter(LlamaMLPConverter): +class MistralMLPConverter(LlamaMLPConverter): @classmethod def import_config(cls, config: dict) -> dict: config["mlp_bias"] = False @@ -56,7 +56,7 @@ def export_config(cls, config: MLPConfig) -> dict: class MistralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter - mlp_converter_class: typing.ClassVar[type[MistrallMLPConverter]] = MistrallMLPConverter + mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter class MistralDecoderConverter(LlamaDecoderConverter): diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 549f367e6..58077bfd4 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -14,12 +14,11 @@ LlamaAttentionConverter, LlamaBlockConverter, LlamaDecoderConverter, - LlamaMLPConverter, LlamaNormalizationConverter, MLPLayer2Converter, get_weight_and_bias_converters, ) -from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter +from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter, MistralMLPConverter from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat from fast_llm.models.multimodal.model import MultiModalModel @@ -44,7 +43,6 @@ def export_config(cls, config: RMSNormalizationConfig) -> dict: return {} -# TODO: ====== MistralAttentionConverter (#391 / #382) ====== class PixtralAttentionConverter(LlamaAttentionConverter): @classmethod def import_config(cls, config: dict) -> dict: @@ -73,7 +71,7 @@ def export_config(cls, config: AttentionConfig) -> dict: class PixtralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[PixtralAttentionConverter]] = PixtralAttentionConverter # TODO: ====== MistralMLPConverter (#391 / #382) ====== - mlp_converter_class: typing.ClassVar[type[LlamaMLPConverter]] = LlamaMLPConverter + mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter normalization_converter_class: typing.ClassVar[type[PixtralNormalizationConverter]] = PixtralNormalizationConverter hf_mixer_name: typing.ClassVar[str] = "attention" hf_mlp_name: typing.ClassVar[str] = "feed_forward" From f7197678ff80ba4df31034c3aff7350ddb758f68 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 21 Nov 2025 19:12:09 -0500 Subject: [PATCH 3/3] fixes --- fast_llm/layers/vision/config.py | 5 ++ 
fast_llm/models/gpt/huggingface.py | 1 + .../models/multimodal/conversion/llava.py | 77 +++++++++++++------ tests/utils/model_configs.py | 5 +- 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index 13e23829e..bd1c69160 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -133,6 +133,11 @@ class VisionMultiModalModelConfig(LanguageModelConfig): hint=FieldHint.architecture, desc="Configuration for the vision encoder.", ) + image_token_index: int | None = Field( + default=None, + hint=FieldHint.optional, + desc="Index of the image token. Unused, but required for Hugging Face conversion.", + ) @property def layer_class(self) -> "type[VisionMultiModalModel]": diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 34e38469a..0756f62ac 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -45,6 +45,7 @@ def inner_forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: # TODO: Most of this is generalizable. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 58077bfd4..3342fe5e8 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -9,6 +9,7 @@ from fast_llm.layers.attention.rotary.config import Rotary2DConfig from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.vision.config import PatchConvolutionConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -16,9 +17,10 @@ LlamaDecoderConverter, LlamaNormalizationConverter, MLPLayer2Converter, + get_parameter_converter, get_weight_and_bias_converters, ) -from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter, MistralMLPConverter +from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter, MistralHeadConverter, MistralMLPConverter from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat from fast_llm.models.multimodal.model import MultiModalModel @@ -130,7 +132,7 @@ class LlavaVisionAdapterConverter: @classmethod def import_config(cls, config: dict) -> dict: return { - "intermediate_size": config["projector_intermediate_size"], + "intermediate_size": config["vision_config"]["hidden_size"], "add_linear_biases": config["multimodal_projector_bias"], "gated": False, "activation": ActivationType.from_hf_name(config["projector_hidden_act"]), @@ -145,8 +147,9 @@ def export_config(cls, config: MLPConfig) -> dict: return { "projector_hidden_act": config.activation.hf_name, - "projector_intermediate_size": config.intermediate_size, "multimodal_projector_bias": config.add_linear_biases, + # Not in LlavaConfig, but needed for consistency check in LlavaBaseModelConverter. 
+ "projector_intermediate_size": config.intermediate_size, } @classmethod @@ -173,9 +176,11 @@ class LlavaVisionModelConverter: PixtralPatchConvolutionConverter ) encoder_converter_class: typing.ClassVar[type[PixtralEncoderConverter]] = PixtralEncoderConverter + model_type: typing.ClassVar[str] = "pixtral" @classmethod def import_config(cls, config: dict) -> dict: + Assert.eq(config["vision_config"]["model_type"], cls.model_type) return { "patch_convolution": cls.patch_convolution_converter_class.import_config(config["vision_config"]), "encoder": cls.encoder_converter_class.import_config(config["vision_config"]), @@ -190,7 +195,7 @@ def export_config(cls, config: VisionEncoderConfig) -> dict: vision_config = safe_merge_dicts( cls.patch_convolution_converter_class.export_config(config.patch_convolution), cls.encoder_converter_class.export_config(config.encoder), - {"hidden_size": config.hidden_size}, + {"hidden_size": config.hidden_size, "model_type": cls.model_type}, ) Assert.eq( @@ -200,57 +205,85 @@ def export_config(cls, config: VisionEncoderConfig) -> dict: return safe_merge_dicts( {"vision_config": vision_config}, cls.vision_adapter_converter_class.export_config(config.adapter), - # TODO: ====== What about these? ====== - # { - # "image_token_index":32000, - # "vision_feature_select_strategy":"default", - # "vision_feature_layer":-2, - # "image_seq_length":576, - # } ) @classmethod - def get_converters( - cls, config: VisionEncoderConfig, fast_llm_prefix: str, hf_prefix: str - ) -> list[WeightConverter]: + def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: return [ *cls.patch_convolution_converter_class.get_converters( - config.patch_convolution, f"{fast_llm_prefix}.patch_convolution", hf_prefix + config.patch_convolution, "vision_encoder.patch_convolution", "model.vision_tower" ), *cls.encoder_converter_class.get_converters( - config.encoder, f"{fast_llm_prefix}.encoder", f"{hf_prefix}.transformer" + config.encoder, "vision_encoder.encoder", "model.vision_tower.transformer.layers" ), *cls.vision_adapter_converter_class.get_converters( - config.adapter, f"{fast_llm_prefix}.adapter", f"{hf_prefix}.multi_modal_projector" + config.adapter, "vision_encoder.adapter", "model.multi_modal_projector" + ), + ] + + +class LlavaHeadConverter(MistralHeadConverter): + @classmethod + def get_converters( + cls, + config: LanguageModelHeadConfig, + exported_config: dict, + fast_llm_prefix: str, + ) -> list[WeightConverter]: + return [ + *cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.final_norm", + f"model.language_model.norm", + ), + get_parameter_converter( + f"{fast_llm_prefix}.output_weights", + "lm_head.weight", + drop_on_import=exported_config["tie_word_embeddings"], ), ] +class LlavaLanguageModelConverter(MistralBaseModelConverter): + head_converter_class: typing.ClassVar[type[LlavaHeadConverter]] = LlavaHeadConverter + + class LlavaBaseModelConverter(HuggingFaceBaseModelConverter): vision_model_converter_class: typing.ClassVar[type[LlavaVisionModelConverter]] = LlavaVisionModelConverter # TODO: Make it flexible? - language_model_converter_class: typing.ClassVar[type[MistralBaseModelConverter]] = MistralBaseModelConverter + language_model_converter_class: typing.ClassVar[type[LlavaLanguageModelConverter]] = LlavaLanguageModelConverter # TODO: ====== Is tie_word_embeddings supported? 
====== @classmethod def import_config(cls, config: dict) -> dict: return safe_merge_dicts( - {"vision_encoder": cls.vision_model_converter_class.import_config(config)}, + { + "vision_encoder": cls.vision_model_converter_class.import_config(config), + "image_token_index": config["image_token_index"], + }, cls.language_model_converter_class.import_config(config["text_config"]), ) @classmethod def export_config(cls, config: MultiModalBaseModelConfig) -> dict: Assert.custom(isinstance, config, MultiModalBaseModelConfig) - return safe_merge_dicts( + assert config.image_token_index is not None + out = safe_merge_dicts( cls.vision_model_converter_class.export_config(config.vision_encoder), - {"text_config": cls.language_model_converter_class.export_config(config)}, + { + "text_config": cls.language_model_converter_class.export_config(config), + "image_token_index": config.image_token_index, + "vision_feature_select_strategy": "full", + "vision_feature_layer": -1, + }, ) + Assert.eq(out.pop("projector_intermediate_size"), out["text_config"]["hidden_size"]) + return out @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: return [ - *cls.vision_model_converter_class.get_converters(config.vision_encoder, "vision_encoder", "model"), + *cls.vision_model_converter_class.get_converters(config.vision_encoder), *cls.language_model_converter_class.embeddings_converter_class.get_converters( config.embeddings, "embeddings", "model.language_model" ), diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 143fa7bab..ab59505f5 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -692,10 +692,13 @@ def _update_and_add_testing_config( ("model", "base_model", "vision_encoder"): { "patch_convolution": {"patch_height": 4, "patch_width": 4, "normalization": {"type": "rms_norm"}}, "encoder": copy.deepcopy(MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]), - "adapter": {"intermediate_size": 512}, + "adapter": {"intermediate_size": 256}, "hidden_size": 256, }, ("model", "base_model", "decoder", "num_blocks"): 1, + # Extend the vocab size to ensure the token id is not in the mock dataset. + ("model", "base_model", "embeddings", "vocab_size"): 386, + ("model", "base_model", "image_token_index"): 384, ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "rotary", "type"): "default_2d", ("model", "base_model", "vision_encoder", "encoder", "num_blocks"): 1, ("model", "base_model", "vision_encoder", "encoder", "block", "mixer", "causal"): False,