diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py
index 748f2f89e..8703ef920 100644
--- a/fast_llm/models/multimodal/conversion/llava.py
+++ b/fast_llm/models/multimodal/conversion/llava.py
@@ -167,7 +167,7 @@ class LlavaVisionAdapterConverter:
     @classmethod
     def import_config(cls, config: dict) -> dict:
         return {
-            "intermediate_size": config["vision_config"]["hidden_size"],
+            "intermediate_size": config["text_config"]["hidden_size"],
             "add_linear_biases": config["multimodal_projector_bias"],
             "gated": False,
             "activation": ActivationType.from_hf_name(config["projector_hidden_act"]),
@@ -183,8 +183,6 @@ def export_config(cls, config: MLPConfig) -> dict:
         return {
             "projector_hidden_act": config.activation.hf_name,
             "multimodal_projector_bias": config.add_linear_biases,
-            # Not in LlavaConfig, but needed for consistency check in LlavaBaseModelConverter.
-            "projector_intermediate_size": config.intermediate_size,
         }

     @classmethod
@@ -243,13 +241,13 @@ def export_config(cls, config: VisionEncoderConfig) -> dict:
     def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]:
         return [
             *cls.embeddings_converter_class.get_converters(
-                config.embeddings, "vision_encoder.embeddings", "model.vision_tower"
+                config.embeddings, "vision_encoder.embeddings", "vision_tower"
             ),
             *cls.encoder_converter_class.get_converters(
-                config.encoder, "vision_encoder.encoder", "model.vision_tower.transformer.layers"
+                config.encoder, "vision_encoder.encoder", "vision_tower.transformer.layers"
             ),
             *cls.vision_adapter_converter_class.get_converters(
-                config.adapter, "vision_encoder.adapter", "model.multi_modal_projector"
+                config.adapter, "vision_encoder.adapter", "multi_modal_projector"
             ),
         ]

@@ -266,11 +264,11 @@ def get_converters(
             *cls.normalization_converter_class.get_converters(
                 config.normalization,
                 f"{fast_llm_prefix}.final_norm",
-                f"model.language_model.norm",
+                f"language_model.model.norm",
             ),
             get_parameter_converter(
                 f"{fast_llm_prefix}.output_weights",
-                "lm_head.weight",
+                "language_model.lm_head.weight",
                 drop_on_import=exported_config["tie_word_embeddings"],
             ),
         ]
@@ -309,7 +307,6 @@ def export_config(cls, config: MultiModalBaseModelConfig) -> dict:
                 "vision_feature_layer": -1,
             },
         )
-        Assert.eq(out.pop("projector_intermediate_size"), out["text_config"]["hidden_size"])
         return out

     @classmethod
@@ -317,10 +314,10 @@ def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict
         return [
             *cls.vision_model_converter_class.get_converters(config.vision_encoder),
             *cls.language_model_converter_class.embeddings_converter_class.get_converters(
-                config.embeddings, "embeddings", "model.language_model"
+                config.embeddings, "embeddings", "language_model.model"
             ),
             *cls.language_model_converter_class.decoder_converter_class.get_converters(
-                config.decoder, "decoder", "model.language_model.layers"
+                config.decoder, "decoder", "language_model.model.layers"
             ),
             *cls.language_model_converter_class.head_converter_class.get_converters(
                 config.head, {"tie_word_embeddings": False}, "head"
diff --git a/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py b/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py
index 9d1f014d8..eeeb0bca5 100644
--- a/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py
+++ b/fast_llm_external_models/llava_hybrid/configuration_llava_hybrid.py
@@ -59,7 +59,6 @@ def __init__(
         text_config=None,
         image_token_index=32000,
         projector_hidden_act="gelu",
-        projector_intermediate_size=4096,
         vision_feature_select_strategy="default",
         vision_feature_layer=-2,
         image_seq_length=576,
@@ -68,8 +67,6 @@ def __init__(
     ):
         self.image_token_index = image_token_index
         self.projector_hidden_act = projector_hidden_act
-        # projector_intermediate_size is an addition to the original Llava config
-        self.projector_intermediate_size = projector_intermediate_size
         self.image_seq_length = image_seq_length

         if vision_feature_select_strategy not in ["default", "full"]:
diff --git a/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py b/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py
index 243413a33..e51915321 100644
--- a/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py
+++ b/fast_llm_external_models/llava_hybrid/modeling_llava_hybrid.py
@@ -22,12 +22,12 @@ def __init__(self, config: LlavaHybridConfig):
         num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
         self.linear_1 = nn.Linear(
             config.vision_config.hidden_size * num_feature_layers,
-            config.projector_intermediate_size,
+            config.text_config.hidden_size,
             bias=config.multimodal_projector_bias,
         )
         self.act = ACT2FN[config.projector_hidden_act]
         self.linear_2 = nn.Linear(
-            config.projector_intermediate_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+            config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
         )

     def forward(self, image_features):
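Note for reviewers (not part of the diff): a minimal sketch of what the modeling change implies for the projector shape, plus the exported weight-name prefix renames from the converter hunks above. The nn.Sequential form and the example sizes (1024, 4096) are illustrative assumptions; the actual module is LlavaMultiModalProjector and takes its sizes from config.vision_config / config.text_config.

import torch.nn as nn

# Illustrative sizes only; real values come from the vision/text configs.
vision_hidden_size, text_hidden_size = 1024, 4096

# With projector_intermediate_size removed, both linears are sized by the text hidden size.
projector = nn.Sequential(
    nn.Linear(vision_hidden_size, text_hidden_size),  # was vision_hidden -> projector_intermediate_size
    nn.GELU(),                                        # projector_hidden_act="gelu"
    nn.Linear(text_hidden_size, text_hidden_size),    # was projector_intermediate_size -> text_hidden
)

# Exported HF weight-name prefixes also change (old -> new), per the converter hunks above:
prefix_renames = {
    "model.vision_tower": "vision_tower",
    "model.multi_modal_projector": "multi_modal_projector",
    "model.language_model": "language_model.model",
    "lm_head.weight": "language_model.lm_head.weight",
}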