From fc27f3c46647d67141e7ad0c8ca1c1a2735480fb Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Thu, 12 Jun 2025 12:26:29 +0200 Subject: [PATCH 01/42] update processors to new signatures --- colpali_engine/utils/processing_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index 537b21a2..98571c68 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -28,6 +28,7 @@ class BaseVisualRetrieverProcessor(ABC): def process_images( self, images: List[Image.Image], + contexts: Optional[List[str]] = None, ) -> Union[BatchFeature, BatchEncoding]: """ Process a list of images into a format suitable for the model. @@ -56,6 +57,7 @@ def process_queries( texts: Optional[List[str]] = None, queries: Optional[List[str]] = None, max_length: int = 50, + contexts: Optional[List[str]] = None, suffix: Optional[str] = None, ) -> Union[BatchFeature, BatchEncoding]: """ From 31c0709ef63c2b911b8d28057eecad1affbe8cbf Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Thu, 12 Jun 2025 12:29:03 +0200 Subject: [PATCH 02/42] lint --- colpali_engine/utils/processing_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index 98571c68..ff452db4 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -28,7 +28,6 @@ class BaseVisualRetrieverProcessor(ABC): def process_images( self, images: List[Image.Image], - contexts: Optional[List[str]] = None, ) -> Union[BatchFeature, BatchEncoding]: """ Process a list of images into a format suitable for the model. From 32de63cc1234dc165d31b02257c851e6cd26c9d6 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Thu, 12 Jun 2025 14:24:47 +0200 Subject: [PATCH 03/42] keep process_queries for back comp --- colpali_engine/utils/processing_utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index ff452db4..d6bc4558 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -89,6 +89,30 @@ def process_queries( return self.process_texts(texts=texts) + def process_queries( + self, + texts: List[str], + max_length: int = 50, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process a list of queries into a format suitable for the model. + Args: + texts (List[str]): List of texts to process. + max_length (int, optional): Maximum length of the texts. Defaults to 50. + suffix (Optional[str], optional): Optional suffix to append to each text. + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + + NOTE: This function maintains back-compatibility, use `process_texts` for better control on context. + """ + return self.process_texts( + texts=texts, + contexts=[self.query_prefix] * len(texts), + max_length=max_length, + suffix=suffix, + ) + @abstractmethod def score( self, From d00b26724f31d5bb45bbf44bb55beead11173e72 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Mon, 16 Jun 2025 11:29:00 +0200 Subject: [PATCH 04/42] add vbert/vllama modeling --- colpali_engine/models/vbert/__init__.py | 2 + .../models/vbert/bivbert/__init__.py | 2 + .../models/vbert/bivbert/modeling_bivbert.py | 61 ++ .../vbert/bivbert/processing_bivbert.py | 51 + .../models/vbert/colvbert/__init__.py | 2 + .../vbert/colvbert/modeling_colvbert.py | 51 + .../vbert/colvbert/processing_colvbert.py | 96 ++ .../models/vbert/configuration_vbert.py | 232 +++++ colpali_engine/models/vbert/modeling_vbert.py | 930 ++++++++++++++++++ colpali_engine/models/vllama/__init__.py | 2 + .../models/vllama/bivllama/__init__.py | 2 + .../vllama/bivllama/modeling_bivllama.py | 61 ++ .../vllama/bivllama/processing_bivllama.py | 51 + .../models/vllama/colvllama/__init__.py | 2 + .../vllama/colvllama/modeling_colvllama.py | 51 + .../vllama/colvllama/processing_colvllama.py | 96 ++ .../models/vllama/configuration_vllama.py | 232 +++++ .../models/vllama/modeling_vllama.py | 883 +++++++++++++++++ 18 files changed, 2807 insertions(+) create mode 100644 colpali_engine/models/vbert/__init__.py create mode 100644 colpali_engine/models/vbert/bivbert/__init__.py create mode 100644 colpali_engine/models/vbert/bivbert/modeling_bivbert.py create mode 100644 colpali_engine/models/vbert/bivbert/processing_bivbert.py create mode 100644 colpali_engine/models/vbert/colvbert/__init__.py create mode 100644 colpali_engine/models/vbert/colvbert/modeling_colvbert.py create mode 100644 colpali_engine/models/vbert/colvbert/processing_colvbert.py create mode 100644 colpali_engine/models/vbert/configuration_vbert.py create mode 100644 colpali_engine/models/vbert/modeling_vbert.py create mode 100644 colpali_engine/models/vllama/__init__.py create mode 100644 colpali_engine/models/vllama/bivllama/__init__.py create mode 100644 colpali_engine/models/vllama/bivllama/modeling_bivllama.py create mode 100644 colpali_engine/models/vllama/bivllama/processing_bivllama.py create mode 100644 colpali_engine/models/vllama/colvllama/__init__.py create mode 100644 colpali_engine/models/vllama/colvllama/modeling_colvllama.py create mode 100644 colpali_engine/models/vllama/colvllama/processing_colvllama.py create mode 100644 colpali_engine/models/vllama/configuration_vllama.py create mode 100644 colpali_engine/models/vllama/modeling_vllama.py diff --git a/colpali_engine/models/vbert/__init__.py b/colpali_engine/models/vbert/__init__.py new file mode 100644 index 00000000..064334ea --- /dev/null +++ b/colpali_engine/models/vbert/__init__.py @@ -0,0 +1,2 @@ +from .bivbert import BiVBert, BiVBertProcessor +from .colvbert import ColVBert, ColVBertProcessor diff --git a/colpali_engine/models/vbert/bivbert/__init__.py b/colpali_engine/models/vbert/bivbert/__init__.py new file mode 100644 index 00000000..23bc11e3 --- /dev/null +++ b/colpali_engine/models/vbert/bivbert/__init__.py @@ -0,0 +1,2 @@ +from .modeling_bivbert import BiVBert +from .processing_bivbert import BiVBertProcessor diff --git a/colpali_engine/models/vbert/bivbert/modeling_bivbert.py b/colpali_engine/models/vbert/bivbert/modeling_bivbert.py new file mode 100644 index 00000000..470df18b --- /dev/null +++ b/colpali_engine/models/vbert/bivbert/modeling_bivbert.py @@ -0,0 +1,61 @@ +from typing import Literal + +import torch + +from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel + + +class BiVBert(VBertPreTrainedModel): + """ + Initializes the BiIdefics3 model. + + Args: + config : The model configuration. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, **kwargs): + super().__init__(config=config) + self.model = VBertModel(config, **kwargs) + self.post_init() + + def forward( + self, + pooling_strategy: Literal["cls", "last", "mean"] = "last", + *args, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass through model and pooling. + + Args: + - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean". + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + + # Get CLS token embedding, last token, or mean pool over sequence + if pooling_strategy == "cls": + # Use CLS token (first token) embedding + pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) + elif pooling_strategy == "last": + # use last token since we are left padding + pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) + elif pooling_strategy == "mean": + # Mean pooling over sequence length + mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) + pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) + else: + raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") + + # L2 normalization + pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True) + return pooled_output diff --git a/colpali_engine/models/vbert/bivbert/processing_bivbert.py b/colpali_engine/models/vbert/bivbert/processing_bivbert.py new file mode 100644 index 00000000..3bbf7750 --- /dev/null +++ b/colpali_engine/models/vbert/bivbert/processing_bivbert.py @@ -0,0 +1,51 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.vbert.colvbert import ColVBertProcessor + + +class BiVBertProcessor(ColVBertProcessor): # noqa: N801 + """ + Processor for BiVBert. + """ + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiVBert. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + suffix = self.query_augmentation_token # we remove buffer tokens + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/vbert/colvbert/__init__.py b/colpali_engine/models/vbert/colvbert/__init__.py new file mode 100644 index 00000000..2d05a989 --- /dev/null +++ b/colpali_engine/models/vbert/colvbert/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colvbert import ColVBert +from .processing_colvbert import ColVBertProcessor diff --git a/colpali_engine/models/vbert/colvbert/modeling_colvbert.py b/colpali_engine/models/vbert/colvbert/modeling_colvbert.py new file mode 100644 index 00000000..2beb2958 --- /dev/null +++ b/colpali_engine/models/vbert/colvbert/modeling_colvbert.py @@ -0,0 +1,51 @@ +from torch import nn + +from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel + + +class ColVBert(VBertPreTrainedModel): + """ + Initializes the ColVBert model. + + Args: + config : The model configuration. + mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings + except those of the image at inference. + Defaults to False --> Do not mask any embeddings during forward pass. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): + super().__init__(config=config) + self.model = VBertModel(config, **kwargs) + self.dim = 128 + self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + self.mask_non_image_embeddings = mask_non_image_embeddings + self.main_input_name = "doc_input_ids" + + def forward(self, *args, **kwargs): + """ + Forward pass through the model and the linear layer for dimensionality reduction + + Args: + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + proj = self.linear(last_hidden_states) + # normalize l2 norm + proj = proj / proj.norm(dim=-1, keepdim=True) + proj = proj * kwargs["attention_mask"].unsqueeze(-1) + + if "pixel_values" in kwargs and self.mask_non_image_embeddings: + # Pools only the image embeddings + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + return proj diff --git a/colpali_engine/models/vbert/colvbert/processing_colvbert.py b/colpali_engine/models/vbert/colvbert/processing_colvbert.py new file mode 100644 index 00000000..cb9c96f2 --- /dev/null +++ b/colpali_engine/models/vbert/colvbert/processing_colvbert.py @@ -0,0 +1,96 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature, Idefics3Processor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class ColVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): + """ + Processor for ColIdefics3. + """ + + query_augmentation_token: ClassVar[str] = "<|end_of_text|>" + image_token: ClassVar[str] = "" + visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def image_token_id(self) -> int: + return self.tokenizer.convert_tokens_to_ids(self.image_token) + + def process_images( + self, + images: List[Image.Image], + contexts: Optional[List[str]] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process images for ColVBert. + + Args: + images: List of PIL images. + contexts: List of optional context prompts, i.e. some text description of the context of the image. + """ + # if contexts is None: + # contexts = [self.visual_prompt_prefix] * len(images) + contexts = [self.visual_prompt_prefix] * len(images) + + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=contexts, + images=images, + padding="longest", + return_tensors="pt", + ) + return batch_doc + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for ColVBert. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + suffix = self.query_augmentation_token * 10 + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + patch_size: int, + ) -> Tuple[int, int]: + raise NotImplementedError("This method is not implemented for ColIdefics3.") diff --git a/colpali_engine/models/vbert/configuration_vbert.py b/colpali_engine/models/vbert/configuration_vbert.py new file mode 100644 index 00000000..504f333b --- /dev/null +++ b/colpali_engine/models/vbert/configuration_vbert.py @@ -0,0 +1,232 @@ +import copy +import os +from typing import Any, Dict, Union + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +def collect_arg_in_candidates(config, candidates, default = None) -> Any: + """ Gets the argument in a config given a list of candidates """ + for c in candidates: + if hasattr(config, c): + return getattr(config, c) + elif c in config: + return config[c] + if default is not None: + return default + raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config)) + +class VBertTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + """ + model_type = "vbert" + + def __init__( + self, + # Case for when vllama3 is from the hub with no vision_model_name + text_model_name="EuroBERT/EuroBERT-210m", + **kwargs, + ): + self.text_model_name = text_model_name + text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + if hasattr(text_config, "text_config"): + text_config = text_config.text_config + + self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) + self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) + self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) + self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False) + self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) + + super().__init__(text_model_name=text_model_name, **kwargs) + +class VBertVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + """ + model_type = "vbert" + attribute_map = { + "hidden_size": "embed_dim", + } + + def __init__( + self, + # Case for when vllama3 is from the hub with no vision_model_name + vision_model_name="google/siglip2-base-patch16-512", + **kwargs, + ): + self.vision_model_name = vision_model_name + vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + if hasattr(vision_config, "vision_config"): + vision_config = vision_config.vision_config + + self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) + self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) + self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) + self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) + self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) + + super().__init__(vision_model_name=vision_model_name, **kwargs) + +class VBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a + SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM + [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should cache the key/value pairs of the attention mechanism. Only + relevant if `config.is_decoder=True`. + image_token_id (`int`, *optional*, defaults to 128257): + The id of the "image" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the word embeddings with the token embeddings. + vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): + Custom vision config or dict for the vision tower + text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): + Custom text config or dict for the text model + scale_factor (`int`, *optional*, defaults to 2): + The scale factor for the image encoder. + pad_token_id (`int`, *optional*, defaults to 128002): + The id of the padding token. + + Example: + ```python + >>> from transformers import SmolVLMModel, SmolVLMConfig + >>> # Initializing configuration + >>> configuration = SmolVLMConfig() + >>> # Initializing a model from the configuration + >>> model = SmolVLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vbert" + is_composition = True + # sub_configs = {"text_config": VBertTextConfig, "vision_config": VBertVisionConfig} + + DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m" + DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" + + def __init__( + self, + text_config: Union[PretrainedConfig, Dict[str, Any]] = None, + vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, + image_token_id: int = 128_257, + vocab_size=128_256, + use_cache = True, + tie_word_embeddings = False, + freeze_config = None, + pad_token_id = None, + initializer_range = 0.02, + pixel_shuffle_factor = 4, + use_resampler = False, + additional_vocab_size = 0, + neftune_noise_alpha = 0.0, + **kwargs, + ): + self.image_token_id = image_token_id + self.use_cache = use_cache + self.tie_word_embeddings = tie_word_embeddings + self.scale_factor = pixel_shuffle_factor + self.additional_vocab_size = additional_vocab_size + + if text_config is None: + text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) + elif isinstance(text_config, dict): + text_config = VBertTextConfig(text_config["text_model_name"]) + self.text_config = text_config + + if vision_config is None: + vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) + elif isinstance(vision_config, dict): + vision_config = VBertVisionConfig(vision_config["vision_model_name"]) + self.vision_config = vision_config + + self.freeze_config = freeze_config + + # Pixel shuffle factor + self.pixel_shuffle_factor = pixel_shuffle_factor + self.use_resampler = use_resampler + + self.neftune_noise_alpha = neftune_noise_alpha + + self.initializer_range = initializer_range + + hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) + + super().__init__( + **kwargs, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + vocab_size=vocab_size, + hidden_size=hidden_size, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + output["model_type"] = self.__class__.model_type + output["vision_config"] = self.vision_config.to_dict() + output["text_config"] = self.text_config.to_dict() + # output["freeze_config"] = self.freeze_config.to_dict() + + return output + + # @classmethod + # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + # outputs = super(VBertConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs) + # return outputs + + @classmethod + def from_pretrained_models( + cls, + text_model_name: Union[str, os.PathLike], + vision_model_name: Union[str, os.PathLike], + **kwargs + ) -> "PretrainedConfig": + # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + text_model_config = VBertTextConfig(text_model_name) + vision_model_config = VBertVisionConfig(vision_model_name) + return cls( + text_config=text_model_config, + vision_config=vision_model_config, + **kwargs + ) diff --git a/colpali_engine/models/vbert/modeling_vbert.py b/colpali_engine/models/vbert/modeling_vbert.py new file mode 100644 index 00000000..c2d6b380 --- /dev/null +++ b/colpali_engine/models/vbert/modeling_vbert.py @@ -0,0 +1,930 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging +from transformers.cache_utils import DynamicCache +from transformers.modeling_outputs import BaseModelOutput +from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput + +from .configuration_vbert import VBertConfig + +logger = logging.get_logger(__name__) + + +class DecoupledEmbedding(nn.Embedding): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. + In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained. + If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze=False, + device=None, + dtype=None, + padding_idx=None, + **kwargs, + ) -> None: + """ + num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`. + partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen. + + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these. + """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.weight.requires_grad_(False) + + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), + since the 2nd embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do + the padding, but then we have to create a new tensor and populate it with 2 tensors that are + spread out across various indices - i.e. not a simple concat - I haven't benchmarked the + complex case if it's any faster, given that seqlens are usually relatively short it's + probably not faster or if faster not by much - but might be a good idea to measure. + + """ + if self.num_additional_embeddings == 0: + return self.additional_embedding(input_ids) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.num_embeddings, + self.num_additional_embeddings, + self.embedding_dim, + self.partially_freeze, + ) + +@dataclass +class VBertBaseModelOutput(BaseModelOutput): + """ + Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class VBertMaskedLMOutput(MaskedLMOutput): + """ + Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + +class VBertSimpleMLP(nn.Module): + def __init__(self, input_size, output_size): + super().__init__() + self.proj = nn.Linear(input_size, output_size, bias=False) + + def forward(self, x): + return self.proj(x) + +class VBertConnector(nn.Module): + def __init__(self, config): + super().__init__() + self.scale_factor = config.pixel_shuffle_factor + self.modality_projection = VBertSimpleMLP( + input_size=config.vision_config.hidden_size * (config.scale_factor**2), + output_size=config.text_config.hidden_size + ) + + def pixel_shuffle(self, x, scale_factor): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + +class VBertPreTrainedModel(PreTrainedModel): + config_class = VBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["VBertDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + """Initialize the weights.""" + + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + +class VBertModel(VBertPreTrainedModel): + """ + A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger + in forward. Instead, we override inputs_merger here with custom logic. + """ + + def __init__(self, config: VBertConfig, **kwargs): + super().__init__(config) + + self.vision_model = VBertModel.init_vision_model(config, **kwargs) + self.connector = VBertConnector(config) + self.text_model = VBertModel.init_language_model(config, **kwargs) + + self.image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) + ) + self.image_token_id = self.config.image_token_id + + self.post_init() + + @staticmethod + def init_vision_model(config: VBertConfig, **kwargs): + vision_model_config = AutoConfig.from_pretrained( + config.vision_config.vision_model_name, + trust_remote_code=True, + **kwargs, + ) + + vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) + + if hasattr(vision_model, "vision_model"): + # If the model has a vision_model attribute, it means it's a wrapper around another model + vision_model = vision_model.vision_model + + return vision_model + + @staticmethod + def init_language_model(config: VBertConfig, **kwargs): + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + + text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) + # extractor = regex_lookup(language_model_name, language_model_name2model) + + embed_layer = DecoupledEmbedding( + num_embeddings=text_model_config.vocab_size, + num_additional_embeddings=config.additional_vocab_size, + embedding_dim=config.hidden_size, + partially_freeze=config.freeze_config["freeze_text_layers"], + padding_idx=config.pad_token_id, + ) + + text_model.set_input_embeddings(embed_layer) + + return text_model + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. + + This is useful for lora when using gradient checkpointing. + c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 + + Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. + """ + + def get_lowest_module(module): + if len(list(module.children())) == 0: + # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) + return module + else: + # Recursively call the function on each child module + return get_lowest_module(list(module.children())[0]) + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( + make_inputs_require_grads + ) + + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def inputs_merger( + self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor + ): + """ + This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. + The merging happens as follows: + - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. + - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space. + We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer. + - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. + - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. + """ + _, patch_size, _ = image_hidden_states.shape + + image_mask = input_ids == self.image_token_id + num_image_tokens = image_mask.sum(dim=1) + if not torch.all(num_image_tokens % patch_size == 0): + raise ValueError("At least one sample has tokens not divisible by patch_size.") + + blocks_per_sample = num_image_tokens // patch_size + + offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) + block_offset = offsets[:-1] + row_cum = image_mask.cumsum(dim=-1) + chunk_idx = (row_cum - 1) // patch_size + local_idx = (row_cum - 1) % patch_size + block_idx = block_offset.unsqueeze(1) + chunk_idx + + image_embeds = torch.zeros_like(inputs_embeds) + image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] + + merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) + return merged_embeds + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and self.text_model.gradient_checkpointing and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # retrieve input_ids and inputs_embeds + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_seen_tokens = 0 + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache() + past_seen_tokens = past_key_values.get_seq_length() + + if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: + raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + + if inputs_embeds is None: + inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + + if not any(real_images_inds): + # no images, leave one empty image. + real_images_inds[0] = True + + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=[pixel_values.shape[i] for i in (0, 2, 3)], + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + # patch_size = self.config.vision_config.patch_size + # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + # patch_attention_mask=patch_attention_mask, + ).last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if inputs_embeds is not None and image_hidden_states is not None: + # When we embed, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) + + outputs = self.text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # past_key_values=past_key_values, + # use_cache=use_cache, + # cache_position=cache_position, + ) + + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + + return VBertBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_hidden_states, + ) + +class VBertForMaskedLM(VBertPreTrainedModel): + # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.model = VBertModel(config, **kwargs) + self.lm_head = VBertForMaskedLM.init_lm_head(config, **kwargs) + if self.out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=self.in_features, + out_features=self.out_additional_features, + bias=False, + ) + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def init_lm_head(config, **kwargs): + # Get the pretrained model config + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) + # Get the lm head + lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None + if lm_head is None: + logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") + lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) + return lm_head + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, VBertMaskedLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + # Pass the inputs to VBertModel + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Pass the outputs to the MLM head + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + additional_features = self.additional_fc(hidden_states) + logits = torch.cat((logits, additional_features), -1) + logits = logits.float() + + masked_lm_loss = None + if labels is not None: + # print the ratio of not ignored tokens + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return VBertMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + # @classmethod + # def from_pretrained_models( + # cls, + # text_model_name, + # vision_model_name, + # vl_config, + # *args, + # **kwargs + # ): + # """ + # Use this method when creating a new vloom model that hasn't been yet trained and it'll be + # composed of 2 pre-trained models - hence `pretrained_models`. + # """ + # model = super().from_pretrained_models( + # text_model_name=text_model_name, + # vision_model_name=vision_model_name, + # vl_config=vl_config, + # *args, + # **kwargs + # ) + # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # # fetch the pretrained text model w/o zero.Init + # pretrained_lm_head = AutoModelForMaskedLM.from_pretrained( + # text_model_name, trust_remote_code=True, **kwargs + # ).lm_head + + # # Load the lm_head + # load_state_dict_into_model(model.lm_head, pretrained_lm_head.state_dict(), start_prefix="") + + # return model + +class VModernBertLMHead(nn.Module): + def __init__(self, config, **kwargs): + super().__init__() + pretrained_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs) + + self.head = pretrained_model.head + self.decoder = pretrained_model.decoder + + def forward(self, hidden_states): + hidden_states = self.head(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + # @classmethod + # def from_pretrained( + # cls, + # text_model_name, + # vl_config, + # *args, + # **kwargs + # ): + # """ + # Use this method when creating a new vloom model that hasn't been yet trained and it'll be + # composed of 2 pre-trained models - hence `pretrained_models`. + # """ + # lm_head = cls(vl_config, *args, **kwargs) + + # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # # fetch the pretrained text model w/o zero.Init + # pretrained_model = AutoModelForMaskedLM.from_pretrained( + # text_model_name, trust_remote_code=True, **kwargs + # ) + + # pretrained_head = pretrained_model.head + # pretrained_decoder = pretrained_model.decoder + + # # Load the head + # load_state_dict_into_model(lm_head.head, pretrained_head.state_dict(), start_prefix="") + # load_state_dict_into_model(lm_head.decoder, pretrained_decoder.state_dict(), start_prefix="") + + # return lm_head + +class VModernBertForMaskedLM(VBertPreTrainedModel): + # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.model = VBertModel(config, **kwargs) + self.lm_head = VModernBertLMHead(config, **kwargs) + + if self.out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=self.in_features, + out_features=self.out_additional_features, + bias=False, + ) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, VBertMaskedLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + # Pass the inputs to VBertModel + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Pass the outputs to the MLM head + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + proj_states = self.lm_head.head(hidden_states) + additional_features = self.additional_fc(proj_states) + logits = torch.cat((logits, additional_features), -1) + logits = logits.float() + + masked_lm_loss = None + if labels is not None: + # print the ratio of not ignored tokens + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return VBertMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images, max_num_tokens=None): + config_vl_model = self.config + + lm_config = config_vl_model.text_config + + language_embed_size = lm_config.hidden_size + num_language_layers = lm_config.num_hidden_layers + ffn_inner_size = lm_config.intermediate_size + + vision_config = config_vl_model.vision_config + + # Get vision model blocks infos + vision_patch_size = vision_config.patch_size + vision_hidden_size = vision_config.embed_dim + num_vision_layers = vision_config.num_hidden_layers + # The +1 is for the CLS token + single_image_vision_encoder_seq_len = int(((vision_config.image_size // vision_patch_size) ** 2) // (self.config.pixel_shuffle_factor**2)) + vision_exp_factor = vision_config.intermediate_size // vision_hidden_size + + # Get language blocks infos + language_seq_len = max_num_tokens if max_num_tokens is not None else data_param.max_seq_len + language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4 + + # Get modality projection infos + vision_pipeline_output_seq_len = ( + self.config.perceiver_config.resampler_n_latents + if self.config.use_resampler + else single_image_vision_encoder_seq_len + ) + + language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( + num_layers=num_language_layers, + batch_size=hparams.batch_size_per_gpu, + q_seq_len=language_seq_len, + k_seq_len=language_seq_len, + hidden_size=language_embed_size, + kv_in_dim=language_embed_size, + ff_exp_factor=language_exp_factor, + grad_acc_size=hparams.grad_acc_size, + swiglu=True, + vocab_size=tokenizer.vocab_size, + count_backward=True, # Always True regardless of freezing, because gradients are computed for vision adaptor + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + modality_projection_tflops_per_batch_per_gpu = compute_linear_tflops_per_batch_per_gpu( + batch_size=hparams.batch_size_per_gpu * max_num_images, + seq_len=vision_pipeline_output_seq_len, + in_features=vision_hidden_size, + out_features=language_embed_size, + count_backward=True, + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + + vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( + num_layers=num_vision_layers, + batch_size=hparams.batch_size_per_gpu * max_num_images, + q_seq_len=single_image_vision_encoder_seq_len, + k_seq_len=single_image_vision_encoder_seq_len, + hidden_size=vision_hidden_size, + kv_in_dim=vision_hidden_size, + ff_exp_factor=vision_exp_factor, + grad_acc_size=hparams.grad_acc_size, + swiglu=False, + vocab_size=None, + count_backward=not hparams.model_config["freeze_config"]["freeze_vision_layers"], + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + if self.config.use_resampler: + perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu( + num_layers=self.config.perceiver_config.resampler_depth, + batch_size=hparams.batch_size_per_gpu * max_num_images, + q_seq_len=self.config.perceiver_config.resampler_n_latents, + vision_embed_seq_len=single_image_vision_encoder_seq_len, + q_k_v_input_dim=vision_hidden_size, + attention_hidden_size=self.config.perceiver_config.resampler_n_heads + * self.config.perceiver_config.resampler_head_dim, + ff_exp_factor=4, + count_backward=True, + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + tflop_count = ( + language_tflops_per_batch_per_gpu + + modality_projection_tflops_per_batch_per_gpu + + perceiver_tflops_per_batch_per_gpu + + vision_tflops_per_batch_per_gpu + ) + else: + tflop_count = ( + language_tflops_per_batch_per_gpu + + modality_projection_tflops_per_batch_per_gpu + + vision_tflops_per_batch_per_gpu + ) + return tflop_count + + @classmethod + def from_pretrained_models( + cls, + text_model_name, + vision_model_name, + vl_config, + *args, + **kwargs + ): + """ + Use this method when creating a new vloom model that hasn't been yet trained and it'll be + composed of 2 pre-trained models - hence `pretrained_models`. + """ + model = super().from_pretrained_models( + text_model_name=text_model_name, + vision_model_name=vision_model_name, + vl_config=vl_config, + *args, + **kwargs + ) + + # Load the lm_head + model.lm_head = VModernBertLMHead.from_pretrained( + text_model_name=text_model_name, + vl_config=vl_config, + *args, + **kwargs + ) + + return model diff --git a/colpali_engine/models/vllama/__init__.py b/colpali_engine/models/vllama/__init__.py new file mode 100644 index 00000000..534ea814 --- /dev/null +++ b/colpali_engine/models/vllama/__init__.py @@ -0,0 +1,2 @@ +from .bivllama import BiVLlama, BiVLlamaProcessor +from .colvllama import ColVLlama, ColVLlamaProcessor diff --git a/colpali_engine/models/vllama/bivllama/__init__.py b/colpali_engine/models/vllama/bivllama/__init__.py new file mode 100644 index 00000000..55f602b0 --- /dev/null +++ b/colpali_engine/models/vllama/bivllama/__init__.py @@ -0,0 +1,2 @@ +from .modeling_bivllama import BiVLlama +from .processing_bivllama import BiVLlamaProcessor diff --git a/colpali_engine/models/vllama/bivllama/modeling_bivllama.py b/colpali_engine/models/vllama/bivllama/modeling_bivllama.py new file mode 100644 index 00000000..92a37690 --- /dev/null +++ b/colpali_engine/models/vllama/bivllama/modeling_bivllama.py @@ -0,0 +1,61 @@ +from typing import Literal + +import torch + +from colpali_engine.models.vllama.modeling_vllama import VLlamaModel, VLlamaPreTrainedModel + + +class BiVLlama(VLlamaPreTrainedModel): + """ + Initializes the BiVLlama model. + + Args: + config : The model configuration. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, **kwargs): + super().__init__(config=config) + self.model = VLlamaModel(config, **kwargs) + self.post_init() + + def forward( + self, + pooling_strategy: Literal["cls", "last", "mean"] = "last", + *args, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass through model and pooling. + + Args: + - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean". + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + + # Get CLS token embedding, last token, or mean pool over sequence + if pooling_strategy == "cls": + # Use CLS token (first token) embedding + pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) + elif pooling_strategy == "last": + # use last token since we are left padding + pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) + elif pooling_strategy == "mean": + # Mean pooling over sequence length + mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) + pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) + else: + raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") + + # L2 normalization + pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True) + return pooled_output diff --git a/colpali_engine/models/vllama/bivllama/processing_bivllama.py b/colpali_engine/models/vllama/bivllama/processing_bivllama.py new file mode 100644 index 00000000..e6fa2908 --- /dev/null +++ b/colpali_engine/models/vllama/bivllama/processing_bivllama.py @@ -0,0 +1,51 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchFeature, BatchEncoding + +from colpali_engine.models.vllama.colvllama import ColVLlamaProcessor + + +class BiVLlamaProcessor(ColVLlamaProcessor): # noqa: N801 + """ + Processor for BiVLlama model. + """ + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiVLlama. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + suffix = self.query_augmentation_token # we remove buffer tokens + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/vllama/colvllama/__init__.py b/colpali_engine/models/vllama/colvllama/__init__.py new file mode 100644 index 00000000..00dae459 --- /dev/null +++ b/colpali_engine/models/vllama/colvllama/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colvllama import ColVLlama +from .processing_colvllama import ColVLlamaProcessor diff --git a/colpali_engine/models/vllama/colvllama/modeling_colvllama.py b/colpali_engine/models/vllama/colvllama/modeling_colvllama.py new file mode 100644 index 00000000..0d40391a --- /dev/null +++ b/colpali_engine/models/vllama/colvllama/modeling_colvllama.py @@ -0,0 +1,51 @@ +from torch import nn + +from colpali_engine.models.vllama.modeling_vllama import VLlamaModel, VLlamaPreTrainedModel + + +class ColVLlama(VLlamaPreTrainedModel): + """ + Initializes the ColVBert model. + + Args: + config : The model configuration. + mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings + except those of the image at inference. + Defaults to False --> Do not mask any embeddings during forward pass. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): + super().__init__(config=config) + self.model = VLlamaModel(config, **kwargs) + self.dim = 128 + self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + self.mask_non_image_embeddings = mask_non_image_embeddings + self.main_input_name = "doc_input_ids" + + def forward(self, *args, **kwargs): + """ + Forward pass through the model and the linear layer for dimensionality reduction + + Args: + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + proj = self.linear(last_hidden_states) + # normalize l2 norm + proj = proj / proj.norm(dim=-1, keepdim=True) + proj = proj * kwargs["attention_mask"].unsqueeze(-1) + + if "pixel_values" in kwargs and self.mask_non_image_embeddings: + # Pools only the image embeddings + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + return proj diff --git a/colpali_engine/models/vllama/colvllama/processing_colvllama.py b/colpali_engine/models/vllama/colvllama/processing_colvllama.py new file mode 100644 index 00000000..0be8d5f1 --- /dev/null +++ b/colpali_engine/models/vllama/colvllama/processing_colvllama.py @@ -0,0 +1,96 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature, Idefics3Processor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class ColVLlamaProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): + """ + Processor for ColVLlama. + """ + + query_augmentation_token: ClassVar[str] = "<|end_of_text|>" + image_token: ClassVar[str] = "" + visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:Describe the image.\nAssistant:" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def image_token_id(self) -> int: + return self.tokenizer.convert_tokens_to_ids(self.image_token) + + def process_images( + self, + images: List[Image.Image], + contexts: Optional[List[str]] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process images for ColVLlama. + + Args: + images: List of PIL images. + contexts: List of optional context prompts, i.e. some text description of the context of the image. + """ + # if contexts is None: + # contexts = [self.visual_prompt_prefix] * len(images) + contexts = [self.visual_prompt_prefix] * len(images) + + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=contexts, + images=images, + padding="longest", + return_tensors="pt", + ) + return batch_doc + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for ColVLlama. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + suffix = self.query_augmentation_token * 10 + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + patch_size: int, + ) -> Tuple[int, int]: + raise NotImplementedError("This method is not implemented for ColIdefics3.") diff --git a/colpali_engine/models/vllama/configuration_vllama.py b/colpali_engine/models/vllama/configuration_vllama.py new file mode 100644 index 00000000..576b6497 --- /dev/null +++ b/colpali_engine/models/vllama/configuration_vllama.py @@ -0,0 +1,232 @@ +import copy +import os +from typing import Any, Dict, Union + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +def collect_arg_in_candidates(config, candidates, default = None) -> Any: + """ Gets the argument in a config given a list of candidates """ + for c in candidates: + if hasattr(config, c): + return getattr(config, c) + elif c in config: + return config[c] + if default is not None: + return default + raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config)) + +class VLlamaTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + """ + model_type = "VLlama" + + def __init__( + self, + # Case for when vllama3 is from the hub with no vision_model_name + text_model_name="HuggingFaceTB/SmolLM2-135M-Instruct", + **kwargs, + ): + self.text_model_name = text_model_name + text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + if hasattr(text_config, "text_config"): + text_config = text_config.text_config + + self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) + self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) + self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) + self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False) + self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) + + super().__init__(text_model_name=text_model_name, **kwargs) + +class VLlamaVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + """ + model_type = "VLlama" + attribute_map = { + "hidden_size": "embed_dim", + } + + def __init__( + self, + # Case for when vllama3 is from the hub with no vision_model_name + vision_model_name="google/siglip2-base-patch16-512", + **kwargs, + ): + self.vision_model_name = vision_model_name + vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + if hasattr(vision_config, "vision_config"): + vision_config = vision_config.vision_config + + self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) + self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) + self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) + self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) + self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) + + super().__init__(vision_model_name=vision_model_name, **kwargs) + +class VLlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a + SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM + [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should cache the key/value pairs of the attention mechanism. Only + relevant if `config.is_decoder=True`. + image_token_id (`int`, *optional*, defaults to 128257): + The id of the "image" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the word embeddings with the token embeddings. + vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): + Custom vision config or dict for the vision tower + text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): + Custom text config or dict for the text model + scale_factor (`int`, *optional*, defaults to 2): + The scale factor for the image encoder. + pad_token_id (`int`, *optional*, defaults to 128002): + The id of the padding token. + + Example: + ```python + >>> from transformers import SmolVLMModel, SmolVLMConfig + >>> # Initializing configuration + >>> configuration = SmolVLMConfig() + >>> # Initializing a model from the configuration + >>> model = SmolVLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "VLlama" + is_composition = True + # sub_configs = {"text_config": VLlamaTextConfig, "vision_config": VLlamaVisionConfig} + + DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m" + DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" + + def __init__( + self, + text_config: Union[PretrainedConfig, Dict[str, Any]] = None, + vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, + image_token_id: int = 128_257, + vocab_size=128_256, + use_cache = True, + tie_word_embeddings = False, + freeze_config = None, + pad_token_id = None, + initializer_range = 0.02, + pixel_shuffle_factor = 4, + use_resampler = False, + additional_vocab_size = 0, + neftune_noise_alpha = 0.0, + **kwargs, + ): + self.image_token_id = image_token_id + self.use_cache = use_cache + self.tie_word_embeddings = tie_word_embeddings + self.scale_factor = pixel_shuffle_factor + self.additional_vocab_size = additional_vocab_size + + if text_config is None: + text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) + elif isinstance(text_config, dict): + text_config = VLlamaTextConfig(text_config["text_model_name"]) + self.text_config = text_config + + if vision_config is None: + vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) + elif isinstance(vision_config, dict): + vision_config = VLlamaVisionConfig(vision_config["vision_model_name"]) + self.vision_config = vision_config + + self.freeze_config = freeze_config + + # Pixel shuffle factor + self.pixel_shuffle_factor = pixel_shuffle_factor + self.use_resampler = use_resampler + + self.neftune_noise_alpha = neftune_noise_alpha + + self.initializer_range = initializer_range + + hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) + + super().__init__( + **kwargs, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + vocab_size=vocab_size, + hidden_size=hidden_size, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + output["model_type"] = self.__class__.model_type + output["vision_config"] = self.vision_config.to_dict() + output["text_config"] = self.text_config.to_dict() + # output["freeze_config"] = self.freeze_config.to_dict() + + return output + + # @classmethod + # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + # outputs = super(VLlamaConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs) + # return outputs + + @classmethod + def from_pretrained_models( + cls, + text_model_name: Union[str, os.PathLike], + vision_model_name: Union[str, os.PathLike], + **kwargs + ) -> "PretrainedConfig": + # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + text_model_config = VLlamaTextConfig(text_model_name) + vision_model_config = VLlamaVisionConfig(vision_model_name) + return cls( + text_config=text_model_config, + vision_config=vision_model_config, + **kwargs + ) diff --git a/colpali_engine/models/vllama/modeling_vllama.py b/colpali_engine/models/vllama/modeling_vllama.py new file mode 100644 index 00000000..b1fa576f --- /dev/null +++ b/colpali_engine/models/vllama/modeling_vllama.py @@ -0,0 +1,883 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, GenerationMixin, logging +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import LossKwargs + +# from transformers.models.smolvlm import SmolVLMModel, SmolVLMPreTrainedModel +from .configuration_vllama import VLlamaConfig + +logger = logging.get_logger(__name__) + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + +class DecoupledEmbedding(nn.Embedding): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. + In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained. + If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze=False, + device=None, + dtype=None, + padding_idx=None, + **kwargs, + ) -> None: + """ + num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`. + partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen. + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these. + """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.weight.requires_grad_(False) + + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), + since the 2nd embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + note: for the 1st embedding lookup we could have looked up only the low indices and not do + the padding, but then we have to create a new tensor and populate it with 2 tensors that are + spread out across various indices - i.e. not a simple concat - I haven't benchmarked the + complex case if it's any faster, given that seqlens are usually relatively short it's + probably not faster or if faster not by much - but might be a good idea to measure. + """ + if self.num_additional_embeddings == 0: + return self.additional_embedding(input_ids) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.num_embeddings, + self.num_additional_embeddings, + self.embedding_dim, + self.partially_freeze, + ) + +@dataclass +class VLlamaBaseModelOutputWithPast(BaseModelOutput): + """ + Base class for VLlama3 model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class VLlamaCausalLMOutputWithPast(BaseModelOutput): + """ + Base class for VLlama3 causal language model (or autoregressive) outputs. + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class VLlamaSimpleMLP(nn.Module): + def __init__(self, input_size, output_size): + super().__init__() + self.proj = nn.Linear(input_size, output_size, bias=False) + + def forward(self, x): + return self.proj(x) + +class VLlamaConnector(nn.Module): + def __init__(self, config): + super().__init__() + self.scale_factor = config.pixel_shuffle_factor + self.modality_projection = VLlamaSimpleMLP( + input_size=config.vision_config.hidden_size * (config.scale_factor**2), + output_size=config.text_config.hidden_size + ) + + def pixel_shuffle(self, x, scale_factor): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + +class VLlamaPreTrainedModel(PreTrainedModel): + config_class = VLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["VLlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + """Initialize the weights.""" + + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + +class VLlamaModel(VLlamaPreTrainedModel): + """ + A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger + in forward. Instead, we override inputs_merger here with custom logic. + """ + + def __init__(self, config: VLlamaConfig, **kwargs): + super().__init__(config) + + self.vision_model = VLlamaModel.init_vision_model(config, **kwargs) + self.connector = VLlamaConnector(config) + self.text_model = VLlamaModel.init_language_model(config, **kwargs) + + self.image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) + ) + self.image_token_id = self.config.image_token_id + + self.post_init() + + @staticmethod + def init_vision_model(config: VLlamaConfig, **kwargs): + vision_model_config = AutoConfig.from_pretrained( + config.vision_config.vision_model_name, + trust_remote_code=True, + **kwargs, + ) + + vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) + + if hasattr(vision_model, "vision_model"): + # If the model has a vision_model attribute, it means it's a wrapper around another model + vision_model = vision_model.vision_model + + return vision_model + + @staticmethod + def init_language_model(config: VLlamaConfig, **kwargs): + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + + text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) + # extractor = regex_lookup(language_model_name, language_model_name2model) + + embed_layer = DecoupledEmbedding( + num_embeddings=text_model_config.vocab_size, + num_additional_embeddings=config.additional_vocab_size, + embedding_dim=config.hidden_size, + partially_freeze=config.freeze_config["freeze_text_layers"], + padding_idx=config.pad_token_id, + ) + + text_model.set_input_embeddings(embed_layer) + + return text_model + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. + This is useful for lora when using gradient checkpointing. + c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 + Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. + """ + + def get_lowest_module(module): + if len(list(module.children())) == 0: + # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) + return module + else: + # Recursively call the function on each child module + return get_lowest_module(list(module.children())[0]) + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( + make_inputs_require_grads + ) + + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def inputs_merger( + self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor + ): + """ + This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. + The merging happens as follows: + - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. + - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space. + We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer. + - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. + - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. + """ + _, patch_size, _ = image_hidden_states.shape + + image_mask = input_ids == self.image_token_id + num_image_tokens = image_mask.sum(dim=1) + if not torch.all(num_image_tokens % patch_size == 0): + raise ValueError("At least one sample has tokens not divisible by patch_size.") + + blocks_per_sample = num_image_tokens // patch_size + + offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) + block_offset = offsets[:-1] + row_cum = image_mask.cumsum(dim=-1) + chunk_idx = (row_cum - 1) // patch_size + local_idx = (row_cum - 1) % patch_size + block_idx = block_offset.unsqueeze(1) + chunk_idx + + image_embeds = torch.zeros_like(inputs_embeds) + image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] + + merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) + return merged_embeds + + def embed_tokens(self, input_ids: torch.LongTensor) -> torch.FloatTensor: + """ + Override the embed_tokens method to use the text model's input embeddings. + This is necessary to ensure that the image token ID is correctly handled. + """ + if self.text_model.get_input_embeddings() is None: + raise ValueError("The text model does not have input embeddings.") + + return self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, VLlamaBaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.training and self.text_model.gradient_checkpointing and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if inputs_embeds is not None and input_ids is None: + raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + + if not any(real_images_inds): + # no images, leave one empty image. + real_images_inds[0] = True + + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=[pixel_values.shape[i] for i in (0, 2, 3)], + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + # patch_size = self.config.vision_config.patch_size + # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + # patch_attention_mask=patch_attention_mask, + ).last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if inputs_embeds is not None and image_hidden_states is not None: + # When we embed, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) + + outputs = self.text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + ) + + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + + return VLlamaBaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_hidden_states, + ) + +class VLlamaForCausalLM(VLlamaPreTrainedModel): + # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + + self.model = VLlamaModel(config, **kwargs) + self.lm_head = VLlamaForCausalLM.init_lm_head(config, **kwargs) + if self.out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=self.in_features, + out_features=self.out_additional_features, + bias=False, + ) + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def init_lm_head(config, **kwargs): + # Get the pretrained model config + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) + # Get the lm head + lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None + if lm_head is None: + logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") + lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) + return lm_head + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, VLlamaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + # Pass the inputs to VLlamaModel + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + # Pass the outputs to the MLM head + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + additional_features = self.additional_fc(hidden_states) + logits = torch.cat((logits, additional_features), -1) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return VLlamaCausalLMOutputWithPast( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + +class VLlamaForVision2Seq(VLlamaPreTrainedModel, GenerationMixin): + def __init__(self, config, **kwargs): + super().__init__(config) + + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + + self.model = VLlamaModel(config, **kwargs) + self.lm_head = VLlamaForVision2Seq.init_lm_head(config, **kwargs) + if self.out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=self.in_features, + out_features=self.out_additional_features, + bias=False, + ) + + self.loss_fct = CrossEntropyLoss() + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def init_lm_head(config, **kwargs): + # Get the pretrained model config + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) + # Get the lm head + lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None + if lm_head is None: + logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") + lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) + return lm_head + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping + the model weights fixed. + """ + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook( + make_inputs_require_grads + ) + + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + + def get_input_embeddings(self): + return self.model.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.text_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, VLlamaCausalLMOutputWithPast]: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The hidden states of the image encoder after modality projection. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `SmolVLMForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from transformers import AutoProcessor, AutoModelForImageTextToText + >>> from transformers.image_utils import load_image + + >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible + >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg") + >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg") + >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg") + + >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct") + >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto") + + >>> # Create inputs + >>> messages = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "video", "path": path/to/video}, + ... {"type": "text", "text": "What is happening in this video?"}, + ... ] + ... } + ... ] + + >>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=256) + >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True) + + >>> print(generated_texts) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + return_dict=True, + **kwargs, + ) + + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + hidden_states = hidden_states[:, slice_indices, :] + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + additional_features = self.additional_fc(hidden_states) + logits = torch.cat((logits, additional_features), -1) + logits = logits.float() + + loss = None + if labels is not None: + loss = self.loss_fct( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return VLlamaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + pixel_values=None, + pixel_attention_mask=None, + image_hidden_states=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take + # precedence is moved to the model, we can remove this fn) + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + # but IDEFICS requires both ids and embeds to be present + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs["input_ids"] = input_ids + + if image_hidden_states is not None: + model_inputs["pixel_values"] = None + model_inputs["pixel_attention_mask"] = None + + return model_inputs + + def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + # Get the precomputed image_hidden_states + model_kwargs["image_hidden_states"] = outputs.image_hidden_states + return model_kwargs + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past From 7fba1c668365185019f0066fcd0cdae128ca5da3 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Tue, 24 Jun 2025 17:44:42 +0200 Subject: [PATCH 05/42] stage --- colpali_engine/models/vbert/bivbert/modeling_bivbert.py | 7 +++++-- colpali_engine/models/vllama/bivllama/modeling_bivllama.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/colpali_engine/models/vbert/bivbert/modeling_bivbert.py b/colpali_engine/models/vbert/bivbert/modeling_bivbert.py index 470df18b..fcf5eb60 100644 --- a/colpali_engine/models/vbert/bivbert/modeling_bivbert.py +++ b/colpali_engine/models/vbert/bivbert/modeling_bivbert.py @@ -17,14 +17,15 @@ class BiVBert(VBertPreTrainedModel): _supports_sdpa = True _supports_cache_class = True - def __init__(self, config, **kwargs): + def __init__(self, config, pooling_strategy = "last", **kwargs): super().__init__(config=config) self.model = VBertModel(config, **kwargs) + self.pooling_strategy = pooling_strategy self.post_init() def forward( self, - pooling_strategy: Literal["cls", "last", "mean"] = "last", + pooling_strategy: Literal["cls", "last", "mean"] = None, *args, **kwargs, ) -> torch.Tensor: @@ -42,6 +43,8 @@ def forward( outputs = self.model(*args, **kwargs) last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + pooling_strategy = pooling_strategy or self.pooling_strategy + # Get CLS token embedding, last token, or mean pool over sequence if pooling_strategy == "cls": # Use CLS token (first token) embedding diff --git a/colpali_engine/models/vllama/bivllama/modeling_bivllama.py b/colpali_engine/models/vllama/bivllama/modeling_bivllama.py index 92a37690..a1ec56e6 100644 --- a/colpali_engine/models/vllama/bivllama/modeling_bivllama.py +++ b/colpali_engine/models/vllama/bivllama/modeling_bivllama.py @@ -17,14 +17,15 @@ class BiVLlama(VLlamaPreTrainedModel): _supports_sdpa = True _supports_cache_class = True - def __init__(self, config, **kwargs): + def __init__(self, config, pooling_strategy = "last", **kwargs): super().__init__(config=config) self.model = VLlamaModel(config, **kwargs) + self.pooling_strategy = pooling_strategy self.post_init() def forward( self, - pooling_strategy: Literal["cls", "last", "mean"] = "last", + pooling_strategy: Literal["cls", "last", "mean"] = None, *args, **kwargs, ) -> torch.Tensor: @@ -42,6 +43,8 @@ def forward( outputs = self.model(*args, **kwargs) last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + pooling_strategy = pooling_strategy or self.pooling_strategy + # Get CLS token embedding, last token, or mean pool over sequence if pooling_strategy == "cls": # Use CLS token (first token) embedding From 43d3d36261f8d68d88bbb3dcdcfd50e6224827bb Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Mon, 30 Jun 2025 16:14:30 +0200 Subject: [PATCH 06/42] fix typo in vbert modeling --- colpali_engine/models/vbert/colvbert/modeling_colvbert.py | 4 ++-- colpali_engine/models/vllama/colvllama/modeling_colvllama.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/colpali_engine/models/vbert/colvbert/modeling_colvbert.py b/colpali_engine/models/vbert/colvbert/modeling_colvbert.py index 2beb2958..dd0c68c7 100644 --- a/colpali_engine/models/vbert/colvbert/modeling_colvbert.py +++ b/colpali_engine/models/vbert/colvbert/modeling_colvbert.py @@ -22,7 +22,7 @@ def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): super().__init__(config=config) self.model = VBertModel(config, **kwargs) self.dim = 128 - self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) self.mask_non_image_embeddings = mask_non_image_embeddings self.main_input_name = "doc_input_ids" @@ -39,7 +39,7 @@ def forward(self, *args, **kwargs): """ outputs = self.model(*args, **kwargs) last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) - proj = self.linear(last_hidden_states) + proj = self.custom_text_proj(last_hidden_states) # normalize l2 norm proj = proj / proj.norm(dim=-1, keepdim=True) proj = proj * kwargs["attention_mask"].unsqueeze(-1) diff --git a/colpali_engine/models/vllama/colvllama/modeling_colvllama.py b/colpali_engine/models/vllama/colvllama/modeling_colvllama.py index 0d40391a..a5a0114e 100644 --- a/colpali_engine/models/vllama/colvllama/modeling_colvllama.py +++ b/colpali_engine/models/vllama/colvllama/modeling_colvllama.py @@ -22,7 +22,7 @@ def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): super().__init__(config=config) self.model = VLlamaModel(config, **kwargs) self.dim = 128 - self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) self.mask_non_image_embeddings = mask_non_image_embeddings self.main_input_name = "doc_input_ids" @@ -39,7 +39,7 @@ def forward(self, *args, **kwargs): """ outputs = self.model(*args, **kwargs) last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) - proj = self.linear(last_hidden_states) + proj = self.custom_text_proj(last_hidden_states) # normalize l2 norm proj = proj / proj.norm(dim=-1, keepdim=True) proj = proj * kwargs["attention_mask"].unsqueeze(-1) From ed11060d24e359afe0e12bab467297bcf67cc44a Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Tue, 1 Jul 2025 17:16:34 +0200 Subject: [PATCH 07/42] loss --- colpali_engine/loss/__init__.py | 1 + colpali_engine/loss/bi_encoder_losses.py | 2 + .../loss/late_interaction_losses.py | 60 +++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/colpali_engine/loss/__init__.py b/colpali_engine/loss/__init__.py index db060015..0ad15237 100644 --- a/colpali_engine/loss/__init__.py +++ b/colpali_engine/loss/__init__.py @@ -11,4 +11,5 @@ ColbertNegativeCELoss, ColbertPairwiseCELoss, ColbertPairwiseNegativeCELoss, + ColbertSigmoidLoss, ) diff --git a/colpali_engine/loss/bi_encoder_losses.py b/colpali_engine/loss/bi_encoder_losses.py index b8dbef34..b82423ff 100644 --- a/colpali_engine/loss/bi_encoder_losses.py +++ b/colpali_engine/loss/bi_encoder_losses.py @@ -206,6 +206,7 @@ def forward( self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, + offset: int = 0, ) -> torch.Tensor: """ Compute softplus(hardest_neg - pos) where hardest_neg is the highest off-diagonal score. @@ -267,6 +268,7 @@ def forward( query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, neg_doc_embeddings: torch.Tensor, + offset: int = 0, ) -> torch.Tensor: """ Compute softplus(neg-explicit - pos) plus optional pairwise in-batch loss. diff --git a/colpali_engine/loss/late_interaction_losses.py b/colpali_engine/loss/late_interaction_losses.py index 95bcf6ff..03dcfd84 100644 --- a/colpali_engine/loss/late_interaction_losses.py +++ b/colpali_engine/loss/late_interaction_losses.py @@ -395,3 +395,63 @@ def forward( loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight return loss + + +class ColbertSigmoidLoss(ColbertModule): + """ + Sigmoid loss for ColBERT with explicit negatives. + + Args: + temperature (float): Scaling for logits. + normalize_scores (bool): Normalize scores by query lengths. + use_smooth_max (bool): Use log-sum-exp instead of amax. + pos_aware_negative_filtering (bool): Apply pos-aware negative filtering. + """ + + def __init__( + self, + temperature: float = 0.02, + normalize_scores: bool = True, + use_smooth_max: bool = False, + pos_aware_negative_filtering: bool = False, + max_batch_size: int = 1024, + tau: float = 0.1, + norm_tol: float = 1e-3, + filter_threshold: float = 0.95, + filter_factor: float = 0.5, + ): + super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor) + self.temperature = temperature + self.normalize_scores = normalize_scores + self.use_smooth_max = use_smooth_max + self.pos_aware_negative_filtering = pos_aware_negative_filtering + self.ce_loss = CrossEntropyLoss() + + def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor: + """ + Compute sigmoid loss over positive and negative document pairs. + + Args: + query_embeddings (Tensor): [B, Nq, D] + doc_embeddings (Tensor): [B, Nd, D] positive docs + + Returns: + Tensor: Scalar loss value. + """ + + lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1) + raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings) + scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2) + + if self.normalize_scores: + scores = self._apply_normalization(scores, lengths) + + batch_size = scores.size(0) + idx, pos_idx = self._get_idx(batch_size, offset, scores.device) + + if self.pos_aware_negative_filtering: + self._filter_high_negatives(scores, pos_idx) + + loss = self.ce_loss(scores / self.temperature, pos_idx) + + return loss.mean() From 55ebd0c6f3ee6efab1e52e4ee4efa5002517c077 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Tue, 8 Jul 2025 16:39:47 +0200 Subject: [PATCH 08/42] models --- colpali_engine/models/__init__.py | 2 + colpali_engine/models/eurovbert/__init__.py | 2 + .../models/eurovbert/bivbert/__init__.py | 2 + .../eurovbert/bivbert/modeling_bivbert.py | 65 ++ .../eurovbert/bivbert/processing_bivbert.py | 51 + .../models/eurovbert/colvbert/__init__.py | 2 + .../colvbert/modeling_coleurovbert.py | 51 + .../colvbert/processing_coleurovbert.py | 96 ++ .../models/eurovbert/configuration_vbert.py | 232 +++++ .../models/eurovbert/modeling_vbert.py | 930 +++++++++++++++++ colpali_engine/models/modernvbert/__init__.py | 2 + .../models/modernvbert/bivbert/__init__.py | 2 + .../modernvbert/bivbert/modeling_bivbert.py | 65 ++ .../modernvbert/bivbert/processing_bivbert.py | 51 + .../models/modernvbert/colvbert/__init__.py | 2 + .../colvbert/modeling_colmodernvbert.py | 53 + .../colvbert/processing_colmodernvbert.py | 97 ++ .../models/modernvbert/configuration_vbert.py | 232 +++++ .../models/modernvbert/modeling_vbert.py | 931 ++++++++++++++++++ .../models/vbert/bivbert/modeling_bivbert.py | 7 +- .../vbert/bivbert/processing_bivbert.py | 2 +- colpali_engine/models/vbert/modeling_vbert.py | 2 +- 22 files changed, 2874 insertions(+), 5 deletions(-) create mode 100644 colpali_engine/models/eurovbert/__init__.py create mode 100644 colpali_engine/models/eurovbert/bivbert/__init__.py create mode 100644 colpali_engine/models/eurovbert/bivbert/modeling_bivbert.py create mode 100644 colpali_engine/models/eurovbert/bivbert/processing_bivbert.py create mode 100644 colpali_engine/models/eurovbert/colvbert/__init__.py create mode 100644 colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py create mode 100644 colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py create mode 100644 colpali_engine/models/eurovbert/configuration_vbert.py create mode 100644 colpali_engine/models/eurovbert/modeling_vbert.py create mode 100644 colpali_engine/models/modernvbert/__init__.py create mode 100644 colpali_engine/models/modernvbert/bivbert/__init__.py create mode 100644 colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py create mode 100644 colpali_engine/models/modernvbert/bivbert/processing_bivbert.py create mode 100644 colpali_engine/models/modernvbert/colvbert/__init__.py create mode 100644 colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py create mode 100644 colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py create mode 100644 colpali_engine/models/modernvbert/configuration_vbert.py create mode 100644 colpali_engine/models/modernvbert/modeling_vbert.py diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index 0f0c8118..3546b276 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -3,3 +3,5 @@ from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor +from .eurovbert import ColEuroVBert, ColEuroVBertProcessor +from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor diff --git a/colpali_engine/models/eurovbert/__init__.py b/colpali_engine/models/eurovbert/__init__.py new file mode 100644 index 00000000..84ab5f61 --- /dev/null +++ b/colpali_engine/models/eurovbert/__init__.py @@ -0,0 +1,2 @@ +from .bivbert import BiVBert, BiVBertProcessor +from .colvbert import ColEuroVBert, ColEuroVBertProcessor diff --git a/colpali_engine/models/eurovbert/bivbert/__init__.py b/colpali_engine/models/eurovbert/bivbert/__init__.py new file mode 100644 index 00000000..23bc11e3 --- /dev/null +++ b/colpali_engine/models/eurovbert/bivbert/__init__.py @@ -0,0 +1,2 @@ +from .modeling_bivbert import BiVBert +from .processing_bivbert import BiVBertProcessor diff --git a/colpali_engine/models/eurovbert/bivbert/modeling_bivbert.py b/colpali_engine/models/eurovbert/bivbert/modeling_bivbert.py new file mode 100644 index 00000000..33b7bd22 --- /dev/null +++ b/colpali_engine/models/eurovbert/bivbert/modeling_bivbert.py @@ -0,0 +1,65 @@ +from typing import Literal + +import torch + +from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel + + +class BiVBert(VBertPreTrainedModel): + """ + Initializes the BiIdefics3 model. + + Args: + config : The model configuration. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, pooling_strategy = "mean", **kwargs): + super().__init__(config=config) + self.model = VBertModel(config, **kwargs) + self.pooling_strategy = pooling_strategy + self.post_init() + + def forward( + self, + pooling_strategy: Literal["cls", "last", "mean"] = None, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass through model and pooling. + + Args: + - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean". + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + + pooling_strategy = pooling_strategy or self.pooling_strategy + + # Get CLS token embedding, last token, or mean pool over sequence + if pooling_strategy == "cls": + # Use CLS token (first token) embedding + pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) + elif pooling_strategy == "last": + # Use last token + last_unpadded_index = kwargs["attention_mask"].sum(dim=1) - 1 + pooled_output = last_hidden_states[:, last_unpadded_index.clamp(min=0)] # (batch_size, hidden_size) + elif pooling_strategy == "mean": + # Mean pooling over sequence length + mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) + pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) + else: + raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") + + # L2 normalization + pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True) + return pooled_output diff --git a/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py b/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py new file mode 100644 index 00000000..b1606f94 --- /dev/null +++ b/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py @@ -0,0 +1,51 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.eurovbert.colvbert import ColEuroVBertProcessor + + +class BiVBertProcessor(ColEuroVBertProcessor): # noqa: N801 + """ + Processor for BiVBert. + """ + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiVBert. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + suffix = self.query_augmentation_token # we remove buffer tokens + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/eurovbert/colvbert/__init__.py b/colpali_engine/models/eurovbert/colvbert/__init__.py new file mode 100644 index 00000000..4e0b32a9 --- /dev/null +++ b/colpali_engine/models/eurovbert/colvbert/__init__.py @@ -0,0 +1,2 @@ +from .modeling_coleurovbert import ColEuroVBert +from .processing_coleurovbert import ColEuroVBertProcessor diff --git a/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py b/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py new file mode 100644 index 00000000..d7e14bcb --- /dev/null +++ b/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py @@ -0,0 +1,51 @@ +from torch import nn + +from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel + + +class ColEuroVBert(VBertPreTrainedModel): + """ + Initializes the ColVBert model. + + Args: + config : The model configuration. + mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings + except those of the image at inference. + Defaults to False --> Do not mask any embeddings during forward pass. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): + super().__init__(config=config) + self.model = VBertModel(config, **kwargs) + self.dim = 128 + self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + self.mask_non_image_embeddings = mask_non_image_embeddings + self.main_input_name = "doc_input_ids" + + def forward(self, *args, **kwargs): + """ + Forward pass through the model and the linear layer for dimensionality reduction + + Args: + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + proj = self.custom_text_proj(last_hidden_states) + # normalize l2 norm + proj = proj / proj.norm(dim=-1, keepdim=True) + proj = proj * kwargs["attention_mask"].unsqueeze(-1) + + if "pixel_values" in kwargs and self.mask_non_image_embeddings: + # Pools only the image embeddings + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + return proj diff --git a/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py b/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py new file mode 100644 index 00000000..c4c78e7f --- /dev/null +++ b/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py @@ -0,0 +1,96 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature, Idefics3Processor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class ColEuroVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): + """ + Processor for ColIdefics3. + """ + + query_augmentation_token: ClassVar[str] = "<|end_of_text|>" + image_token: ClassVar[str] = "" + visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def image_token_id(self) -> int: + return self.tokenizer.convert_tokens_to_ids(self.image_token) + + def process_images( + self, + images: List[Image.Image], + contexts: Optional[List[str]] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process images for ColVBert. + + Args: + images: List of PIL images. + contexts: List of optional context prompts, i.e. some text description of the context of the image. + """ + # if contexts is None: + # contexts = [self.visual_prompt_prefix] * len(images) + contexts = [self.visual_prompt_prefix] * len(images) + + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=contexts, + images=images, + padding="longest", + return_tensors="pt", + ) + return batch_doc + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for ColVBert. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + suffix = self.query_augmentation_token * 10 + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + patch_size: int, + ) -> Tuple[int, int]: + raise NotImplementedError("This method is not implemented for ColIdefics3.") diff --git a/colpali_engine/models/eurovbert/configuration_vbert.py b/colpali_engine/models/eurovbert/configuration_vbert.py new file mode 100644 index 00000000..504f333b --- /dev/null +++ b/colpali_engine/models/eurovbert/configuration_vbert.py @@ -0,0 +1,232 @@ +import copy +import os +from typing import Any, Dict, Union + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +def collect_arg_in_candidates(config, candidates, default = None) -> Any: + """ Gets the argument in a config given a list of candidates """ + for c in candidates: + if hasattr(config, c): + return getattr(config, c) + elif c in config: + return config[c] + if default is not None: + return default + raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config)) + +class VBertTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + """ + model_type = "vbert" + + def __init__( + self, + # Case for when vllama3 is from the hub with no vision_model_name + text_model_name="EuroBERT/EuroBERT-210m", + **kwargs, + ): + self.text_model_name = text_model_name + text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + if hasattr(text_config, "text_config"): + text_config = text_config.text_config + + self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) + self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) + self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) + self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False) + self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) + + super().__init__(text_model_name=text_model_name, **kwargs) + +class VBertVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + """ + model_type = "vbert" + attribute_map = { + "hidden_size": "embed_dim", + } + + def __init__( + self, + # Case for when vllama3 is from the hub with no vision_model_name + vision_model_name="google/siglip2-base-patch16-512", + **kwargs, + ): + self.vision_model_name = vision_model_name + vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + if hasattr(vision_config, "vision_config"): + vision_config = vision_config.vision_config + + self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) + self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) + self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) + self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) + self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) + + super().__init__(vision_model_name=vision_model_name, **kwargs) + +class VBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a + SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM + [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should cache the key/value pairs of the attention mechanism. Only + relevant if `config.is_decoder=True`. + image_token_id (`int`, *optional*, defaults to 128257): + The id of the "image" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the word embeddings with the token embeddings. + vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): + Custom vision config or dict for the vision tower + text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): + Custom text config or dict for the text model + scale_factor (`int`, *optional*, defaults to 2): + The scale factor for the image encoder. + pad_token_id (`int`, *optional*, defaults to 128002): + The id of the padding token. + + Example: + ```python + >>> from transformers import SmolVLMModel, SmolVLMConfig + >>> # Initializing configuration + >>> configuration = SmolVLMConfig() + >>> # Initializing a model from the configuration + >>> model = SmolVLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vbert" + is_composition = True + # sub_configs = {"text_config": VBertTextConfig, "vision_config": VBertVisionConfig} + + DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m" + DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" + + def __init__( + self, + text_config: Union[PretrainedConfig, Dict[str, Any]] = None, + vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, + image_token_id: int = 128_257, + vocab_size=128_256, + use_cache = True, + tie_word_embeddings = False, + freeze_config = None, + pad_token_id = None, + initializer_range = 0.02, + pixel_shuffle_factor = 4, + use_resampler = False, + additional_vocab_size = 0, + neftune_noise_alpha = 0.0, + **kwargs, + ): + self.image_token_id = image_token_id + self.use_cache = use_cache + self.tie_word_embeddings = tie_word_embeddings + self.scale_factor = pixel_shuffle_factor + self.additional_vocab_size = additional_vocab_size + + if text_config is None: + text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) + elif isinstance(text_config, dict): + text_config = VBertTextConfig(text_config["text_model_name"]) + self.text_config = text_config + + if vision_config is None: + vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) + elif isinstance(vision_config, dict): + vision_config = VBertVisionConfig(vision_config["vision_model_name"]) + self.vision_config = vision_config + + self.freeze_config = freeze_config + + # Pixel shuffle factor + self.pixel_shuffle_factor = pixel_shuffle_factor + self.use_resampler = use_resampler + + self.neftune_noise_alpha = neftune_noise_alpha + + self.initializer_range = initializer_range + + hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) + + super().__init__( + **kwargs, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + vocab_size=vocab_size, + hidden_size=hidden_size, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + output["model_type"] = self.__class__.model_type + output["vision_config"] = self.vision_config.to_dict() + output["text_config"] = self.text_config.to_dict() + # output["freeze_config"] = self.freeze_config.to_dict() + + return output + + # @classmethod + # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + # outputs = super(VBertConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs) + # return outputs + + @classmethod + def from_pretrained_models( + cls, + text_model_name: Union[str, os.PathLike], + vision_model_name: Union[str, os.PathLike], + **kwargs + ) -> "PretrainedConfig": + # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + text_model_config = VBertTextConfig(text_model_name) + vision_model_config = VBertVisionConfig(vision_model_name) + return cls( + text_config=text_model_config, + vision_config=vision_model_config, + **kwargs + ) diff --git a/colpali_engine/models/eurovbert/modeling_vbert.py b/colpali_engine/models/eurovbert/modeling_vbert.py new file mode 100644 index 00000000..3d681d69 --- /dev/null +++ b/colpali_engine/models/eurovbert/modeling_vbert.py @@ -0,0 +1,930 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging +from transformers.cache_utils import DynamicCache +from transformers.modeling_outputs import BaseModelOutput +from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput + +from .configuration_vbert import VBertConfig + +logger = logging.get_logger(__name__) + + +class DecoupledEmbedding(nn.Embedding): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. + In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained. + If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze=False, + device=None, + dtype=None, + padding_idx=None, + **kwargs, + ) -> None: + """ + num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`. + partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen. + + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these. + """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.weight.requires_grad_(False) + + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), + since the 2nd embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do + the padding, but then we have to create a new tensor and populate it with 2 tensors that are + spread out across various indices - i.e. not a simple concat - I haven't benchmarked the + complex case if it's any faster, given that seqlens are usually relatively short it's + probably not faster or if faster not by much - but might be a good idea to measure. + + """ + if self.num_additional_embeddings == 0: + return self.additional_embedding(input_ids) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.num_embeddings, + self.num_additional_embeddings, + self.embedding_dim, + self.partially_freeze, + ) + +@dataclass +class VBertBaseModelOutput(BaseModelOutput): + """ + Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class VBertMaskedLMOutput(MaskedLMOutput): + """ + Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + +class VBertSimpleMLP(nn.Module): + def __init__(self, input_size, output_size): + super().__init__() + self.proj = nn.Linear(input_size, output_size, bias=False) + + def forward(self, x): + return self.proj(x) + +class VBertConnector(nn.Module): + def __init__(self, config): + super().__init__() + self.scale_factor = config.pixel_shuffle_factor + self.modality_projection = VBertSimpleMLP( + input_size=config.vision_config.hidden_size * (config.scale_factor**2), + output_size=config.text_config.hidden_size + ) + + def pixel_shuffle(self, x, scale_factor): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + +class VBertPreTrainedModel(PreTrainedModel): + config_class = VBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["VBertDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + """Initialize the weights.""" + + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + +class VBertModel(VBertPreTrainedModel): + """ + A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger + in forward. Instead, we override inputs_merger here with custom logic. + """ + + def __init__(self, config: VBertConfig, **kwargs): + super().__init__(config) + + self.vision_model = VBertModel.init_vision_model(config, **kwargs) + self.connector = VBertConnector(config) + self.text_model = VBertModel.init_language_model(config, **kwargs) + + self.image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) + ) + self.image_token_id = self.config.image_token_id + + self.post_init() + + @staticmethod + def init_vision_model(config: VBertConfig, **kwargs): + vision_model_config = AutoConfig.from_pretrained( + config.vision_config.vision_model_name, + trust_remote_code=True, + **kwargs, + ) + + vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) + + if hasattr(vision_model, "vision_model"): + # If the model has a vision_model attribute, it means it's a wrapper around another model + vision_model = vision_model.vision_model + + return vision_model + + @staticmethod + def init_language_model(config: VBertConfig, **kwargs): + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + + text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) + # extractor = regex_lookup(language_model_name, language_model_name2model) + + embed_layer = DecoupledEmbedding( + num_embeddings=text_model_config.vocab_size, + num_additional_embeddings=config.additional_vocab_size, + embedding_dim=config.hidden_size, + partially_freeze=config.freeze_config["freeze_text_layers"], + padding_idx=config.pad_token_id, + ) + + text_model.set_input_embeddings(embed_layer) + + return text_model + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. + + This is useful for lora when using gradient checkpointing. + c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 + + Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. + """ + + def get_lowest_module(module): + if len(list(module.children())) == 0: + # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) + return module + else: + # Recursively call the function on each child module + return get_lowest_module(list(module.children())[0]) + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( + make_inputs_require_grads + ) + + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def inputs_merger( + self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor + ): + """ + This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. + The merging happens as follows: + - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. + - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space. + We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer. + - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. + - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. + """ + _, patch_size, _ = image_hidden_states.shape + + image_mask = input_ids == self.image_token_id + num_image_tokens = image_mask.sum(dim=1) + if not torch.all(num_image_tokens % patch_size == 0): + raise ValueError("At least one sample has tokens not divisible by patch_size.") + + blocks_per_sample = num_image_tokens // patch_size + + offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) + block_offset = offsets[:-1] + row_cum = image_mask.cumsum(dim=-1) + chunk_idx = (row_cum - 1) // patch_size + local_idx = (row_cum - 1) % patch_size + block_idx = block_offset.unsqueeze(1) + chunk_idx + + image_embeds = torch.zeros_like(inputs_embeds) + image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] + + merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) + return merged_embeds + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and self.text_model.gradient_checkpointing and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # retrieve input_ids and inputs_embeds + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_seen_tokens = 0 + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache() + past_seen_tokens = past_key_values.get_seq_length() + + if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: + raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + + if inputs_embeds is None: + inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + + if not any(real_images_inds): + # no images, leave one empty image. + real_images_inds[0] = True + + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=[pixel_values.shape[i] for i in (0, 2, 3)], + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + # patch_size = self.config.vision_config.patch_size + # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + # patch_attention_mask=patch_attention_mask, + ).last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if inputs_embeds is not None and image_hidden_states is not None: + # When we embed, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) + + outputs = self.text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # past_key_values=past_key_values, + # use_cache=use_cache, + # cache_position=cache_position, + ) + + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + + return VBertBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_hidden_states, + ) + +class VBertForMaskedLM(VBertPreTrainedModel): + # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.model = VBertModel(config, **kwargs) + self.lm_head = VBertForMaskedLM.init_lm_head(config, **kwargs) + if self.out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=self.in_features, + out_features=self.out_additional_features, + bias=False, + ) + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def init_lm_head(config, **kwargs): + # Get the pretrained model config + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) + # Get the lm head + lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None + if lm_head is None: + logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") + lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) + return lm_head + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, VBertMaskedLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + # Pass the inputs to VBertModel + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Pass the outputs to the MLM head + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + additional_features = self.additional_fc(hidden_states) + logits = torch.cat((logits, additional_features), -1) + logits = logits.float() + + masked_lm_loss = None + if labels is not None: + # print the ratio of not ignored tokens + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return VBertMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + # @classmethod + # def from_pretrained_models( + # cls, + # text_model_name, + # vision_model_name, + # vl_config, + # *args, + # **kwargs + # ): + # """ + # Use this method when creating a new vloom model that hasn't been yet trained and it'll be + # composed of 2 pre-trained models - hence `pretrained_models`. + # """ + # model = super().from_pretrained_models( + # text_model_name=text_model_name, + # vision_model_name=vision_model_name, + # vl_config=vl_config, + # *args, + # **kwargs + # ) + # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # # fetch the pretrained text model w/o zero.Init + # pretrained_lm_head = AutoModelForMaskedLM.from_pretrained( + # text_model_name, trust_remote_code=True, **kwargs + # ).lm_head + + # # Load the lm_head + # load_state_dict_into_model(model.lm_head, pretrained_lm_head.state_dict(), start_prefix="") + + # return model + +class VModernBertLMHead(nn.Module): + def __init__(self, config, **kwargs): + super().__init__() + pretrained_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs) + + self.head = pretrained_model.head + self.decoder = pretrained_model.decoder + + def forward(self, hidden_states): + hidden_states = self.head(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + # @classmethod + # def from_pretrained( + # cls, + # text_model_name, + # vl_config, + # *args, + # **kwargs + # ): + # """ + # Use this method when creating a new vloom model that hasn't been yet trained and it'll be + # composed of 2 pre-trained models - hence `pretrained_models`. + # """ + # lm_head = cls(vl_config, *args, **kwargs) + + # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # # fetch the pretrained text model w/o zero.Init + # pretrained_model = AutoModelForMaskedLM.from_pretrained( + # text_model_name, trust_remote_code=True, **kwargs + # ) + + # pretrained_head = pretrained_model.head + # pretrained_decoder = pretrained_model.decoder + + # # Load the head + # load_state_dict_into_model(lm_head.head, pretrained_head.state_dict(), start_prefix="") + # load_state_dict_into_model(lm_head.decoder, pretrained_decoder.state_dict(), start_prefix="") + + # return lm_head + +class VModernBertForMaskedLM(VBertPreTrainedModel): + # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.model = VBertModel(config, **kwargs) + self.lm_head = VModernBertLMHead(config, **kwargs) + + if self.out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=self.in_features, + out_features=self.out_additional_features, + bias=False, + ) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, VBertMaskedLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + # Pass the inputs to VBertModel + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Pass the outputs to the MLM head + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + proj_states = self.lm_head.head(hidden_states) + additional_features = self.additional_fc(proj_states) + logits = torch.cat((logits, additional_features), -1) + logits = logits.float() + + masked_lm_loss = None + if labels is not None: + # print the ratio of not ignored tokens + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return VBertMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images, max_num_tokens=None): + config_vl_model = self.config + + lm_config = config_vl_model.text_config + + language_embed_size = lm_config.hidden_size + num_language_layers = lm_config.num_hidden_layers + ffn_inner_size = lm_config.intermediate_size + + vision_config = config_vl_model.vision_config + + # Get vision model blocks infos + vision_patch_size = vision_config.patch_size + vision_hidden_size = vision_config.embed_dim + num_vision_layers = vision_config.num_hidden_layers + # The +1 is for the CLS token + single_image_vision_encoder_seq_len = int(((vision_config.image_size // vision_patch_size) ** 2) // (self.config.pixel_shuffle_factor**2)) + vision_exp_factor = vision_config.intermediate_size // vision_hidden_size + + # Get language blocks infos + language_seq_len = max_num_tokens if max_num_tokens is not None else data_param.max_seq_len + language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4 + + # Get modality projection infos + vision_pipeline_output_seq_len = ( + self.config.perceiver_config.resampler_n_latents + if self.config.use_resampler + else single_image_vision_encoder_seq_len + ) + + language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( + num_layers=num_language_layers, + batch_size=hparams.batch_size_per_gpu, + q_seq_len=language_seq_len, + k_seq_len=language_seq_len, + hidden_size=language_embed_size, + kv_in_dim=language_embed_size, + ff_exp_factor=language_exp_factor, + grad_acc_size=hparams.grad_acc_size, + swiglu=True, + vocab_size=tokenizer.vocab_size, + count_backward=True, # Always True regardless of freezing, because gradients are computed for vision adaptor + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + modality_projection_tflops_per_batch_per_gpu = compute_linear_tflops_per_batch_per_gpu( + batch_size=hparams.batch_size_per_gpu * max_num_images, + seq_len=vision_pipeline_output_seq_len, + in_features=vision_hidden_size, + out_features=language_embed_size, + count_backward=True, + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + + vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( + num_layers=num_vision_layers, + batch_size=hparams.batch_size_per_gpu * max_num_images, + q_seq_len=single_image_vision_encoder_seq_len, + k_seq_len=single_image_vision_encoder_seq_len, + hidden_size=vision_hidden_size, + kv_in_dim=vision_hidden_size, + ff_exp_factor=vision_exp_factor, + grad_acc_size=hparams.grad_acc_size, + swiglu=False, + vocab_size=None, + count_backward=not hparams.model_config["freeze_config"]["freeze_vision_layers"], + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + if self.config.use_resampler: + perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu( + num_layers=self.config.perceiver_config.resampler_depth, + batch_size=hparams.batch_size_per_gpu * max_num_images, + q_seq_len=self.config.perceiver_config.resampler_n_latents, + vision_embed_seq_len=single_image_vision_encoder_seq_len, + q_k_v_input_dim=vision_hidden_size, + attention_hidden_size=self.config.perceiver_config.resampler_n_heads + * self.config.perceiver_config.resampler_head_dim, + ff_exp_factor=4, + count_backward=True, + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + tflop_count = ( + language_tflops_per_batch_per_gpu + + modality_projection_tflops_per_batch_per_gpu + + perceiver_tflops_per_batch_per_gpu + + vision_tflops_per_batch_per_gpu + ) + else: + tflop_count = ( + language_tflops_per_batch_per_gpu + + modality_projection_tflops_per_batch_per_gpu + + vision_tflops_per_batch_per_gpu + ) + return tflop_count + + @classmethod + def from_pretrained_models( + cls, + text_model_name, + vision_model_name, + vl_config, + *args, + **kwargs + ): + """ + Use this method when creating a new vloom model that hasn't been yet trained and it'll be + composed of 2 pre-trained models - hence `pretrained_models`. + """ + model = super().from_pretrained_models( + text_model_name=text_model_name, + vision_model_name=vision_model_name, + vl_config=vl_config, + *args, + **kwargs + ) + + # Load the lm_head + model.lm_head = VModernBertLMHead.from_pretrained( + text_model_name=text_model_name, + vl_config=vl_config, + *args, + **kwargs + ) + + return model diff --git a/colpali_engine/models/modernvbert/__init__.py b/colpali_engine/models/modernvbert/__init__.py new file mode 100644 index 00000000..d6626781 --- /dev/null +++ b/colpali_engine/models/modernvbert/__init__.py @@ -0,0 +1,2 @@ +from .bivbert import BiModernVBert, BiModernVBertProcessor +from .colvbert import ColModernVBert, ColModernVBertProcessor diff --git a/colpali_engine/models/modernvbert/bivbert/__init__.py b/colpali_engine/models/modernvbert/bivbert/__init__.py new file mode 100644 index 00000000..46514eda --- /dev/null +++ b/colpali_engine/models/modernvbert/bivbert/__init__.py @@ -0,0 +1,2 @@ +from .modeling_bivbert import BiModernVBert +from .processing_bivbert import BiModernVBertProcessor diff --git a/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py b/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py new file mode 100644 index 00000000..fb4d05d2 --- /dev/null +++ b/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py @@ -0,0 +1,65 @@ +from typing import Literal + +import torch + +from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel + + +class BiModernVBert(VBertPreTrainedModel): + """ + Initializes the BiIdefics3 model. + + Args: + config : The model configuration. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, pooling_strategy = "mean", **kwargs): + super().__init__(config=config) + self.model = VBertModel(config, **kwargs) + self.pooling_strategy = pooling_strategy + self.post_init() + + def forward( + self, + pooling_strategy: Literal["cls", "last", "mean"] = None, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass through model and pooling. + + Args: + - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean". + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + + pooling_strategy = pooling_strategy or self.pooling_strategy + + # Get CLS token embedding, last token, or mean pool over sequence + if pooling_strategy == "cls": + # Use CLS token (first token) embedding + pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) + elif pooling_strategy == "last": + # Use last token + last_unpadded_index = kwargs["attention_mask"].sum(dim=1) - 1 + pooled_output = last_hidden_states[:, last_unpadded_index.clamp(min=0)] # (batch_size, hidden_size) + elif pooling_strategy == "mean": + # Mean pooling over sequence length + mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) + pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) + else: + raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") + + # L2 normalization + pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True).clamp_min(1e-12) + return pooled_output diff --git a/colpali_engine/models/modernvbert/bivbert/processing_bivbert.py b/colpali_engine/models/modernvbert/bivbert/processing_bivbert.py new file mode 100644 index 00000000..0e7f27ec --- /dev/null +++ b/colpali_engine/models/modernvbert/bivbert/processing_bivbert.py @@ -0,0 +1,51 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.modernvbert.colvbert import ColModernVBertProcessor # noqa: N801 + + +class BiModernVBertProcessor(ColModernVBertProcessor): # noqa: N801 + """ + Processor for BiVBert. + """ + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiVBert. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + suffix = self.query_augmentation_token # we remove buffer tokens + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/modernvbert/colvbert/__init__.py b/colpali_engine/models/modernvbert/colvbert/__init__.py new file mode 100644 index 00000000..e8f041b5 --- /dev/null +++ b/colpali_engine/models/modernvbert/colvbert/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colmodernvbert import ColModernVBert +from .processing_colmodernvbert import ColModernVBertProcessor \ No newline at end of file diff --git a/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py new file mode 100644 index 00000000..457ecb50 --- /dev/null +++ b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py @@ -0,0 +1,53 @@ +from torch import nn +import torch + +from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel + + +class ColModernVBert(VBertPreTrainedModel): + """ + Initializes the ColVBert model. + + Args: + config : The model configuration. + mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings + except those of the image at inference. + Defaults to False --> Do not mask any embeddings during forward pass. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): + super().__init__(config=config) + self.model = VBertModel(config, **kwargs) + self.dim = 128 + self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + self.mask_non_image_embeddings = mask_non_image_embeddings + self.main_input_name = "doc_input_ids" + + def forward(self, *args, **kwargs): + """ + Forward pass through the model and the linear layer for dimensionality reduction + + Args: + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + proj = self.custom_text_proj(last_hidden_states) + # normalize l2 norm + # proj = torch.where(kwargs["attention_mask"].unsqueeze(-1).bool(), proj / proj.norm(dim=-1, keepdim=True), torch.zeros_like(proj)) + proj = proj / proj.norm(dim=-1, keepdim=True).clamp_min(1e-12) + proj = proj * kwargs["attention_mask"].unsqueeze(-1) + + if "pixel_values" in kwargs and self.mask_non_image_embeddings: + # Pools only the image embeddings + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + return proj diff --git a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py new file mode 100644 index 00000000..787112b7 --- /dev/null +++ b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py @@ -0,0 +1,97 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature, Idefics3Processor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class ColModernVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): + """ + Processor for ColIdefics3. + """ + + query_augmentation_token: ClassVar[str] = "" + image_token: ClassVar[str] = "" + visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def image_token_id(self) -> int: + return self.tokenizer.convert_tokens_to_ids(self.image_token) + + def process_images( + self, + images: List[Image.Image], + contexts: Optional[List[str]] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process images for ColVBert. + + Args: + images: List of PIL images. + contexts: List of optional context prompts, i.e. some text description of the context of the image. + """ + # if contexts is None: + # contexts = [self.visual_prompt_prefix] * len(images) + contexts = [self.visual_prompt_prefix] * len(images) + + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=contexts, + images=images, + padding="longest", + return_tensors="pt", + ) + return batch_doc + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for ColVBert. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + # suffix = self.query_augmentation_token * 10 + suffix = "" + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + patch_size: int, + ) -> Tuple[int, int]: + raise NotImplementedError("This method is not implemented for ColIdefics3.") diff --git a/colpali_engine/models/modernvbert/configuration_vbert.py b/colpali_engine/models/modernvbert/configuration_vbert.py new file mode 100644 index 00000000..504f333b --- /dev/null +++ b/colpali_engine/models/modernvbert/configuration_vbert.py @@ -0,0 +1,232 @@ +import copy +import os +from typing import Any, Dict, Union + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +def collect_arg_in_candidates(config, candidates, default = None) -> Any: + """ Gets the argument in a config given a list of candidates """ + for c in candidates: + if hasattr(config, c): + return getattr(config, c) + elif c in config: + return config[c] + if default is not None: + return default + raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config)) + +class VBertTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + """ + model_type = "vbert" + + def __init__( + self, + # Case for when vllama3 is from the hub with no vision_model_name + text_model_name="EuroBERT/EuroBERT-210m", + **kwargs, + ): + self.text_model_name = text_model_name + text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + if hasattr(text_config, "text_config"): + text_config = text_config.text_config + + self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) + self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) + self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) + self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False) + self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) + + super().__init__(text_model_name=text_model_name, **kwargs) + +class VBertVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + embed_dim (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + """ + model_type = "vbert" + attribute_map = { + "hidden_size": "embed_dim", + } + + def __init__( + self, + # Case for when vllama3 is from the hub with no vision_model_name + vision_model_name="google/siglip2-base-patch16-512", + **kwargs, + ): + self.vision_model_name = vision_model_name + vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + if hasattr(vision_config, "vision_config"): + vision_config = vision_config.vision_config + + self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) + self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) + self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) + self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) + self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) + + super().__init__(vision_model_name=vision_model_name, **kwargs) + +class VBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a + SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM + [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should cache the key/value pairs of the attention mechanism. Only + relevant if `config.is_decoder=True`. + image_token_id (`int`, *optional*, defaults to 128257): + The id of the "image" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the word embeddings with the token embeddings. + vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): + Custom vision config or dict for the vision tower + text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): + Custom text config or dict for the text model + scale_factor (`int`, *optional*, defaults to 2): + The scale factor for the image encoder. + pad_token_id (`int`, *optional*, defaults to 128002): + The id of the padding token. + + Example: + ```python + >>> from transformers import SmolVLMModel, SmolVLMConfig + >>> # Initializing configuration + >>> configuration = SmolVLMConfig() + >>> # Initializing a model from the configuration + >>> model = SmolVLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vbert" + is_composition = True + # sub_configs = {"text_config": VBertTextConfig, "vision_config": VBertVisionConfig} + + DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m" + DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" + + def __init__( + self, + text_config: Union[PretrainedConfig, Dict[str, Any]] = None, + vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, + image_token_id: int = 128_257, + vocab_size=128_256, + use_cache = True, + tie_word_embeddings = False, + freeze_config = None, + pad_token_id = None, + initializer_range = 0.02, + pixel_shuffle_factor = 4, + use_resampler = False, + additional_vocab_size = 0, + neftune_noise_alpha = 0.0, + **kwargs, + ): + self.image_token_id = image_token_id + self.use_cache = use_cache + self.tie_word_embeddings = tie_word_embeddings + self.scale_factor = pixel_shuffle_factor + self.additional_vocab_size = additional_vocab_size + + if text_config is None: + text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) + elif isinstance(text_config, dict): + text_config = VBertTextConfig(text_config["text_model_name"]) + self.text_config = text_config + + if vision_config is None: + vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) + elif isinstance(vision_config, dict): + vision_config = VBertVisionConfig(vision_config["vision_model_name"]) + self.vision_config = vision_config + + self.freeze_config = freeze_config + + # Pixel shuffle factor + self.pixel_shuffle_factor = pixel_shuffle_factor + self.use_resampler = use_resampler + + self.neftune_noise_alpha = neftune_noise_alpha + + self.initializer_range = initializer_range + + hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) + + super().__init__( + **kwargs, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + vocab_size=vocab_size, + hidden_size=hidden_size, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + output["model_type"] = self.__class__.model_type + output["vision_config"] = self.vision_config.to_dict() + output["text_config"] = self.text_config.to_dict() + # output["freeze_config"] = self.freeze_config.to_dict() + + return output + + # @classmethod + # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + # outputs = super(VBertConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs) + # return outputs + + @classmethod + def from_pretrained_models( + cls, + text_model_name: Union[str, os.PathLike], + vision_model_name: Union[str, os.PathLike], + **kwargs + ) -> "PretrainedConfig": + # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + text_model_config = VBertTextConfig(text_model_name) + vision_model_config = VBertVisionConfig(vision_model_name) + return cls( + text_config=text_model_config, + vision_config=vision_model_config, + **kwargs + ) diff --git a/colpali_engine/models/modernvbert/modeling_vbert.py b/colpali_engine/models/modernvbert/modeling_vbert.py new file mode 100644 index 00000000..828a35e6 --- /dev/null +++ b/colpali_engine/models/modernvbert/modeling_vbert.py @@ -0,0 +1,931 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging +from transformers.cache_utils import DynamicCache +from transformers.modeling_outputs import BaseModelOutput +from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput + +from .configuration_vbert import VBertConfig + +logger = logging.get_logger(__name__) + +torch.set_float32_matmul_precision('high') + +class DecoupledEmbedding(nn.Embedding): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. + In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained. + If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze=False, + device=None, + dtype=None, + padding_idx=None, + **kwargs, + ) -> None: + """ + num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`. + partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen. + + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these. + """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.weight.requires_grad_(False) + + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), + since the 2nd embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do + the padding, but then we have to create a new tensor and populate it with 2 tensors that are + spread out across various indices - i.e. not a simple concat - I haven't benchmarked the + complex case if it's any faster, given that seqlens are usually relatively short it's + probably not faster or if faster not by much - but might be a good idea to measure. + + """ + if self.num_additional_embeddings == 0: + return self.additional_embedding(input_ids) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.num_embeddings, + self.num_additional_embeddings, + self.embedding_dim, + self.partially_freeze, + ) + +@dataclass +class VBertBaseModelOutput(BaseModelOutput): + """ + Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class VBertMaskedLMOutput(MaskedLMOutput): + """ + Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + +class VBertSimpleMLP(nn.Module): + def __init__(self, input_size, output_size): + super().__init__() + self.proj = nn.Linear(input_size, output_size, bias=False) + + def forward(self, x): + return self.proj(x) + +class VBertConnector(nn.Module): + def __init__(self, config): + super().__init__() + self.scale_factor = config.pixel_shuffle_factor + self.modality_projection = VBertSimpleMLP( + input_size=config.vision_config.hidden_size * (config.scale_factor**2), + output_size=config.text_config.hidden_size + ) + + def pixel_shuffle(self, x, scale_factor): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + +class VBertPreTrainedModel(PreTrainedModel): + config_class = VBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["VBertDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + """Initialize the weights.""" + + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + +class VBertModel(VBertPreTrainedModel): + """ + A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger + in forward. Instead, we override inputs_merger here with custom logic. + """ + + def __init__(self, config: VBertConfig, **kwargs): + super().__init__(config) + + self.vision_model = VBertModel.init_vision_model(config, **kwargs) + self.connector = VBertConnector(config) + self.text_model = VBertModel.init_language_model(config, **kwargs) + + self.image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) + ) + self.image_token_id = self.config.image_token_id + + self.post_init() + + @staticmethod + def init_vision_model(config: VBertConfig, **kwargs): + vision_model_config = AutoConfig.from_pretrained( + config.vision_config.vision_model_name, + trust_remote_code=True, + **kwargs, + ) + + vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) + + if hasattr(vision_model, "vision_model"): + # If the model has a vision_model attribute, it means it's a wrapper around another model + vision_model = vision_model.vision_model + + return vision_model + + @staticmethod + def init_language_model(config: VBertConfig, **kwargs): + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + + text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) + # extractor = regex_lookup(language_model_name, language_model_name2model) + + embed_layer = DecoupledEmbedding( + num_embeddings=text_model_config.vocab_size, + num_additional_embeddings=config.additional_vocab_size, + embedding_dim=config.hidden_size, + partially_freeze=config.freeze_config["freeze_text_layers"], + padding_idx=config.pad_token_id, + ) + + text_model.set_input_embeddings(embed_layer) + + return text_model + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. + + This is useful for lora when using gradient checkpointing. + c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 + + Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. + """ + + def get_lowest_module(module): + if len(list(module.children())) == 0: + # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) + return module + else: + # Recursively call the function on each child module + return get_lowest_module(list(module.children())[0]) + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( + make_inputs_require_grads + ) + + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def inputs_merger( + self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor + ): + """ + This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. + The merging happens as follows: + - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. + - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space. + We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer. + - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. + - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. + """ + _, patch_size, _ = image_hidden_states.shape + + image_mask = input_ids == self.image_token_id + num_image_tokens = image_mask.sum(dim=1) + if not torch.all(num_image_tokens % patch_size == 0): + raise ValueError("At least one sample has tokens not divisible by patch_size.") + + blocks_per_sample = num_image_tokens // patch_size + + offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) + block_offset = offsets[:-1] + row_cum = image_mask.cumsum(dim=-1) + chunk_idx = (row_cum - 1) // patch_size + local_idx = (row_cum - 1) % patch_size + block_idx = block_offset.unsqueeze(1) + chunk_idx + + image_embeds = torch.zeros_like(inputs_embeds) + image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] + + merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) + return merged_embeds + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and self.text_model.gradient_checkpointing and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # retrieve input_ids and inputs_embeds + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_seen_tokens = 0 + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache() + past_seen_tokens = past_key_values.get_seq_length() + + if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: + raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + + if inputs_embeds is None: + inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + + if not any(real_images_inds): + # no images, leave one empty image. + real_images_inds[0] = True + + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=[pixel_values.shape[i] for i in (0, 2, 3)], + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + # patch_size = self.config.vision_config.patch_size + # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + # patch_attention_mask=patch_attention_mask, + ).last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if inputs_embeds is not None and image_hidden_states is not None: + # When we embed, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) + + outputs = self.text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # past_key_values=past_key_values, + # use_cache=use_cache, + # cache_position=cache_position, + ) + + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + + return VBertBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_hidden_states, + ) + +class VBertForMaskedLM(VBertPreTrainedModel): + # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.model = VBertModel(config, **kwargs) + self.lm_head = VBertForMaskedLM.init_lm_head(config, **kwargs) + if self.out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=self.in_features, + out_features=self.out_additional_features, + bias=False, + ) + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def init_lm_head(config, **kwargs): + # Get the pretrained model config + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) + # Get the lm head + lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None + if lm_head is None: + logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") + lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) + return lm_head + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, VBertMaskedLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + # Pass the inputs to VBertModel + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Pass the outputs to the MLM head + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + additional_features = self.additional_fc(hidden_states) + logits = torch.cat((logits, additional_features), -1) + logits = logits.float() + + masked_lm_loss = None + if labels is not None: + # print the ratio of not ignored tokens + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return VBertMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + # @classmethod + # def from_pretrained_models( + # cls, + # text_model_name, + # vision_model_name, + # vl_config, + # *args, + # **kwargs + # ): + # """ + # Use this method when creating a new vloom model that hasn't been yet trained and it'll be + # composed of 2 pre-trained models - hence `pretrained_models`. + # """ + # model = super().from_pretrained_models( + # text_model_name=text_model_name, + # vision_model_name=vision_model_name, + # vl_config=vl_config, + # *args, + # **kwargs + # ) + # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # # fetch the pretrained text model w/o zero.Init + # pretrained_lm_head = AutoModelForMaskedLM.from_pretrained( + # text_model_name, trust_remote_code=True, **kwargs + # ).lm_head + + # # Load the lm_head + # load_state_dict_into_model(model.lm_head, pretrained_lm_head.state_dict(), start_prefix="") + + # return model + +class VModernBertLMHead(nn.Module): + def __init__(self, config, **kwargs): + super().__init__() + pretrained_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + trust_remote_code=True, + **kwargs, + ) + pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs) + + self.head = pretrained_model.head + self.decoder = pretrained_model.decoder + + def forward(self, hidden_states): + hidden_states = self.head(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + # @classmethod + # def from_pretrained( + # cls, + # text_model_name, + # vl_config, + # *args, + # **kwargs + # ): + # """ + # Use this method when creating a new vloom model that hasn't been yet trained and it'll be + # composed of 2 pre-trained models - hence `pretrained_models`. + # """ + # lm_head = cls(vl_config, *args, **kwargs) + + # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # # fetch the pretrained text model w/o zero.Init + # pretrained_model = AutoModelForMaskedLM.from_pretrained( + # text_model_name, trust_remote_code=True, **kwargs + # ) + + # pretrained_head = pretrained_model.head + # pretrained_decoder = pretrained_model.decoder + + # # Load the head + # load_state_dict_into_model(lm_head.head, pretrained_head.state_dict(), start_prefix="") + # load_state_dict_into_model(lm_head.decoder, pretrained_decoder.state_dict(), start_prefix="") + + # return lm_head + +class VModernBertForMaskedLM(VBertPreTrainedModel): + # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.model = VBertModel(config, **kwargs) + self.lm_head = VModernBertLMHead(config, **kwargs) + + if self.out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=self.in_features, + out_features=self.out_additional_features, + bias=False, + ) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, VBertMaskedLMOutput]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + # Pass the inputs to VBertModel + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Pass the outputs to the MLM head + hidden_states = outputs[0] + + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + proj_states = self.lm_head.head(hidden_states) + additional_features = self.additional_fc(proj_states) + logits = torch.cat((logits, additional_features), -1) + logits = logits.float() + + masked_lm_loss = None + if labels is not None: + # print the ratio of not ignored tokens + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return VBertMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images, max_num_tokens=None): + config_vl_model = self.config + + lm_config = config_vl_model.text_config + + language_embed_size = lm_config.hidden_size + num_language_layers = lm_config.num_hidden_layers + ffn_inner_size = lm_config.intermediate_size + + vision_config = config_vl_model.vision_config + + # Get vision model blocks infos + vision_patch_size = vision_config.patch_size + vision_hidden_size = vision_config.embed_dim + num_vision_layers = vision_config.num_hidden_layers + # The +1 is for the CLS token + single_image_vision_encoder_seq_len = int(((vision_config.image_size // vision_patch_size) ** 2) // (self.config.pixel_shuffle_factor**2)) + vision_exp_factor = vision_config.intermediate_size // vision_hidden_size + + # Get language blocks infos + language_seq_len = max_num_tokens if max_num_tokens is not None else data_param.max_seq_len + language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4 + + # Get modality projection infos + vision_pipeline_output_seq_len = ( + self.config.perceiver_config.resampler_n_latents + if self.config.use_resampler + else single_image_vision_encoder_seq_len + ) + + language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( + num_layers=num_language_layers, + batch_size=hparams.batch_size_per_gpu, + q_seq_len=language_seq_len, + k_seq_len=language_seq_len, + hidden_size=language_embed_size, + kv_in_dim=language_embed_size, + ff_exp_factor=language_exp_factor, + grad_acc_size=hparams.grad_acc_size, + swiglu=True, + vocab_size=tokenizer.vocab_size, + count_backward=True, # Always True regardless of freezing, because gradients are computed for vision adaptor + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + modality_projection_tflops_per_batch_per_gpu = compute_linear_tflops_per_batch_per_gpu( + batch_size=hparams.batch_size_per_gpu * max_num_images, + seq_len=vision_pipeline_output_seq_len, + in_features=vision_hidden_size, + out_features=language_embed_size, + count_backward=True, + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + + vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( + num_layers=num_vision_layers, + batch_size=hparams.batch_size_per_gpu * max_num_images, + q_seq_len=single_image_vision_encoder_seq_len, + k_seq_len=single_image_vision_encoder_seq_len, + hidden_size=vision_hidden_size, + kv_in_dim=vision_hidden_size, + ff_exp_factor=vision_exp_factor, + grad_acc_size=hparams.grad_acc_size, + swiglu=False, + vocab_size=None, + count_backward=not hparams.model_config["freeze_config"]["freeze_vision_layers"], + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + if self.config.use_resampler: + perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu( + num_layers=self.config.perceiver_config.resampler_depth, + batch_size=hparams.batch_size_per_gpu * max_num_images, + q_seq_len=self.config.perceiver_config.resampler_n_latents, + vision_embed_seq_len=single_image_vision_encoder_seq_len, + q_k_v_input_dim=vision_hidden_size, + attention_hidden_size=self.config.perceiver_config.resampler_n_heads + * self.config.perceiver_config.resampler_head_dim, + ff_exp_factor=4, + count_backward=True, + use_grad_checkpointing=hparams.gradient_checkpointing, + ) + tflop_count = ( + language_tflops_per_batch_per_gpu + + modality_projection_tflops_per_batch_per_gpu + + perceiver_tflops_per_batch_per_gpu + + vision_tflops_per_batch_per_gpu + ) + else: + tflop_count = ( + language_tflops_per_batch_per_gpu + + modality_projection_tflops_per_batch_per_gpu + + vision_tflops_per_batch_per_gpu + ) + return tflop_count + + @classmethod + def from_pretrained_models( + cls, + text_model_name, + vision_model_name, + vl_config, + *args, + **kwargs + ): + """ + Use this method when creating a new vloom model that hasn't been yet trained and it'll be + composed of 2 pre-trained models - hence `pretrained_models`. + """ + model = super().from_pretrained_models( + text_model_name=text_model_name, + vision_model_name=vision_model_name, + vl_config=vl_config, + *args, + **kwargs + ) + + # Load the lm_head + model.lm_head = VModernBertLMHead.from_pretrained( + text_model_name=text_model_name, + vl_config=vl_config, + *args, + **kwargs + ) + + return model diff --git a/colpali_engine/models/vbert/bivbert/modeling_bivbert.py b/colpali_engine/models/vbert/bivbert/modeling_bivbert.py index fcf5eb60..33b7bd22 100644 --- a/colpali_engine/models/vbert/bivbert/modeling_bivbert.py +++ b/colpali_engine/models/vbert/bivbert/modeling_bivbert.py @@ -17,7 +17,7 @@ class BiVBert(VBertPreTrainedModel): _supports_sdpa = True _supports_cache_class = True - def __init__(self, config, pooling_strategy = "last", **kwargs): + def __init__(self, config, pooling_strategy = "mean", **kwargs): super().__init__(config=config) self.model = VBertModel(config, **kwargs) self.pooling_strategy = pooling_strategy @@ -50,8 +50,9 @@ def forward( # Use CLS token (first token) embedding pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) elif pooling_strategy == "last": - # use last token since we are left padding - pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) + # Use last token + last_unpadded_index = kwargs["attention_mask"].sum(dim=1) - 1 + pooled_output = last_hidden_states[:, last_unpadded_index.clamp(min=0)] # (batch_size, hidden_size) elif pooling_strategy == "mean": # Mean pooling over sequence length mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) diff --git a/colpali_engine/models/vbert/bivbert/processing_bivbert.py b/colpali_engine/models/vbert/bivbert/processing_bivbert.py index 3bbf7750..1ddc12bf 100644 --- a/colpali_engine/models/vbert/bivbert/processing_bivbert.py +++ b/colpali_engine/models/vbert/bivbert/processing_bivbert.py @@ -3,7 +3,7 @@ import torch from transformers import BatchEncoding, BatchFeature -from colpali_engine.models.vbert.colvbert import ColVBertProcessor +from colpali_engine.models.vbert.colvbert import ColVBertProcessor # noqa: N801 class BiVBertProcessor(ColVBertProcessor): # noqa: N801 diff --git a/colpali_engine/models/vbert/modeling_vbert.py b/colpali_engine/models/vbert/modeling_vbert.py index c2d6b380..3d681d69 100644 --- a/colpali_engine/models/vbert/modeling_vbert.py +++ b/colpali_engine/models/vbert/modeling_vbert.py @@ -393,7 +393,7 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.training and self.text_model.gradient_checkpointing and use_cache: From 4ddc4531ff27d43c6c2a44ba558edc913852faf4 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Tue, 8 Jul 2025 16:40:13 +0200 Subject: [PATCH 09/42] losses --- colpali_engine/loss/__init__.py | 1 + colpali_engine/loss/bi_encoder_losses.py | 123 +++++++++++++++++- .../loss/late_interaction_losses.py | 13 +- 3 files changed, 132 insertions(+), 5 deletions(-) diff --git a/colpali_engine/loss/__init__.py b/colpali_engine/loss/__init__.py index 0ad15237..0e3ecbc2 100644 --- a/colpali_engine/loss/__init__.py +++ b/colpali_engine/loss/__init__.py @@ -4,6 +4,7 @@ BiNegativeCELoss, BiPairwiseCELoss, BiPairwiseNegativeCELoss, + BiSigmoidLoss, ) from .late_interaction_losses import ( ColbertLoss, diff --git a/colpali_engine/loss/bi_encoder_losses.py b/colpali_engine/loss/bi_encoder_losses.py index b82423ff..92e4a657 100644 --- a/colpali_engine/loss/bi_encoder_losses.py +++ b/colpali_engine/loss/bi_encoder_losses.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F # noqa: N812 from torch.nn import CrossEntropyLoss @@ -111,6 +112,60 @@ def forward( return self.ce_loss(scores / self.temperature, pos_idx) +class BiPairedEncoderLoss(BiEncoderModule): + """ + InfoNCE loss for bi-encoders without explicit negatives. + + Args: + temperature (float): Scaling factor for logits. + pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True. + max_batch_size (int): Max batch size for index buffer caching. + filter_threshold (float): Threshold ratio for negative filtering. + filter_factor (float): Factor to down-weight filtered negatives. + """ + + def __init__( + self, + temperature: float = 0.02, + pos_aware_negative_filtering: bool = False, + max_batch_size: int = 1024, + filter_threshold: float = 0.95, + filter_factor: float = 0.5, + ): + super().__init__(max_batch_size, temperature, filter_threshold, filter_factor) + self.pos_aware_negative_filtering = pos_aware_negative_filtering + self.ce_loss = CrossEntropyLoss() + + def forward( + self, + query_embeddings: torch.Tensor, + doc_embeddings: torch.Tensor, + offset: int = 0, + ) -> torch.Tensor: + """ + Compute the InfoNCE loss over a batch of bi-encoder embeddings. + + Args: + query_embeddings (Tensor[B, D]): Query vectors. + doc_embeddings (Tensor[B, D]): Document vectors. + offset (int): Offset for positive indices (multi-GPU). + + Returns: + Tensor: Scalar cross-entropy loss. + """ + # Compute in-batch similarity matrix + scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) + batch_size = scores.size(0) + idx, pos_idx = self._get_idx(batch_size, offset, scores.device) + + if self.pos_aware_negative_filtering: + self._filter_high_negatives(scores, pos_idx) + + q2t = self.ce_loss(scores / self.temperature, pos_idx) + t2q = self.ce_loss(scores.T / self.temperature, ...) + + return (q2t + t2q) / 2.0 + class BiNegativeCELoss(BiEncoderModule): """ @@ -171,7 +226,7 @@ def forward( pos_scores = (query_embeddings * doc_embeddings).sum(dim=1) / self.temperature neg_scores = (query_embeddings * neg_doc_embeddings).sum(dim=1) / self.temperature - loss = torch.nn.functional.softplus(neg_scores - pos_scores).mean() + loss = F.softplus(neg_scores - pos_scores).mean() if self.in_batch_term_weight > 0: loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset) @@ -292,3 +347,69 @@ def forward( loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight return loss + +class BiSigmoidLoss(BiEncoderModule): + """ + Sigmoid loss for ColBERT with in-batch negatives. + + Args: + temperature (float): Scaling factor for logits. + pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True. + max_batch_size (int): Max batch size for index buffer caching. + filter_threshold (float): Threshold ratio for negative filtering. + filter_factor (float): Factor to down-weight filtered negatives. + """ + + def __init__( + self, + temperature: float = 0.02, + pos_aware_negative_filtering: bool = False, + max_batch_size: int = 1024, + filter_threshold: float = 0.95, + filter_factor: float = 0.5, + ): + super().__init__(max_batch_size, temperature, filter_threshold, filter_factor) + self.pos_aware_negative_filtering = pos_aware_negative_filtering + + def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor: + """ + Compute the sigmoid loss for a batch of bi-encoder embeddings. + + Args: + query_embeddings (Tensor[B, D]): Query vectors. + doc_embeddings (Tensor[B, D]): Document vectors. + offset (int): Offset for positive indices (multi-GPU). + + Returns: + Tensor: Scalar cross-entropy loss. + """ + + # Compute in-batch similarity matrix + scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) + + batch_size, num_targets = scores.shape + device = scores.device + + _, pos_idx = self._get_idx(batch_size, offset, device) + + if self.pos_aware_negative_filtering: + self._filter_high_negatives(scores, pos_idx) + + all_losses = [] + for k in range(num_targets // batch_size): + # mask equal to 1 on offset -> offset + batch_size + curr_idx = torch.arange(offset, offset + batch_size, device=device) + # keep only the scores for the current batch + curr_scores = scores[:, curr_idx].view(-1) / self.temperature + # compute the labels + labels = -torch.ones(batch_size * batch_size, device=device) + if k == 0: + flat_pos = (pos_idx - offset) * (batch_size + 1) + labels[flat_pos] = 1.0 + # compute the loss + block_loss = F.softplus(curr_scores * labels) + all_losses.append(block_loss) + # shift the offset for the next batch + offset = (offset + batch_size) % num_targets + + return torch.stack(all_losses, dim=0).mean() \ No newline at end of file diff --git a/colpali_engine/loss/late_interaction_losses.py b/colpali_engine/loss/late_interaction_losses.py index 03dcfd84..5d08c6ac 100644 --- a/colpali_engine/loss/late_interaction_losses.py +++ b/colpali_engine/loss/late_interaction_losses.py @@ -152,7 +152,6 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1) raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings) scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2) - if self.normalize_scores: scores = self._apply_normalization(scores, lengths) @@ -163,7 +162,6 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, self._filter_high_negatives(scores, pos_idx) # print(f"Scores shape: {scores.shape}, offset: {offset}") - return self.ce_loss(scores / self.temperature, pos_idx) @@ -452,6 +450,13 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, if self.pos_aware_negative_filtering: self._filter_high_negatives(scores, pos_idx) - loss = self.ce_loss(scores / self.temperature, pos_idx) + # for each idx in pos_idx, the 2D index (idx, idx) → flat index = idx * B + idx + # build a 1-D mask of length B*B with ones at those positions + flat_pos = pos_idx * (batch_size + 1) + pos_mask = -torch.ones(batch_size * batch_size, device=scores.device) + pos_mask[flat_pos] = 1.0 - return loss.mean() + # flatten the scores to [B * B] + scores = scores.view(-1) / self.temperature + + return F.softplus(scores * pos_mask).mean() From e54df499b7f8384af756f9b0e8f133a6969ac1ea Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Thu, 10 Jul 2025 11:18:44 +0200 Subject: [PATCH 10/42] symetric loss + flex biencodr score --- colpali_engine/trainer/contrastive_trainer.py | 49 +++++++++++++------ colpali_engine/utils/processing_utils.py | 15 +++--- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py index 78514eba..eb340f62 100644 --- a/colpali_engine/trainer/contrastive_trainer.py +++ b/colpali_engine/trainer/contrastive_trainer.py @@ -19,7 +19,7 @@ def concat_all_gather(t: torch.Tensor) -> torch.Tensor: class ContrastiveTrainer(Trainer): - def __init__(self, loss_func, is_vision_model, *args, **kwargs): + def __init__(self, loss_func, is_vision_model, compute_symetric_loss=False, *args, **kwargs): if isinstance(kwargs["train_dataset"], DatasetDict): dataset_list = list(kwargs["train_dataset"].values()) elif isinstance(kwargs["train_dataset"], list): @@ -43,6 +43,7 @@ def __init__(self, loss_func, is_vision_model, *args, **kwargs): self.is_vision_model = is_vision_model # Unused argument, will be removed in 0.4.0 self.args.remove_unused_columns = False # Safety, don't remove dataset columns from dataloader self.dataset_list = dataset_list + self.compute_symetric_loss = compute_symetric_loss def get_train_dataloader(self) -> DataLoader: """ @@ -118,6 +119,27 @@ def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optiona drop_last=self.args.dataloader_drop_last, generator=generator, ) + + def _compute_loss_from_outputs( + self, + query_outputs, + pos_target_outputs, + neg_target_outputs=None, + ): + offset = 0 + batch_size = query_outputs.size(0) + if self.accelerator.num_processes > 1 and self.accelerator.sync_gradients: + # gather docs across all processes + pos_target_outputs = concat_all_gather(pos_target_outputs) + rank = self.accelerator.process_index + offset = rank * batch_size + + if neg_target_outputs is not None: + loss = self.loss_func(query_outputs, pos_target_outputs, neg_target_outputs, offset=offset) + else: + loss = self.loss_func(query_outputs, pos_target_outputs, offset=offset) + + return loss def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) @@ -126,20 +148,17 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if "neg_doc_input_ids" in inputs: # Negative docs are not gathered across processes, so we can use them without offset neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) - loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) - return (loss, (query_outputs, doc_outputs, neg_doc_outputs)) if return_outputs else loss - - offset = 0 - if self.accelerator.num_processes > 1 and self.accelerator.sync_gradients: - # gather docs across all processes - if num_items_in_batch is None: - num_items_in_batch = inputs["doc_input_ids"].shape[0] - doc_outputs = self.accelerator.pad_across_processes(doc_outputs, dim=1, pad_index=0, pad_first=True) - doc_outputs = concat_all_gather(doc_outputs) - rank = self.accelerator.process_index - offset = rank * num_items_in_batch - - loss = self.loss_func(query_outputs, doc_outputs, offset=offset) + else: + neg_doc_outputs = None + + # query -> doc loss + loss = self._compute_loss_from_outputs(query_outputs, doc_outputs, neg_doc_outputs) + + if self.compute_symetric_loss: + assert neg_doc_outputs is None, "Symmetric loss is not compatible with negative documents." + # doc -> query loss + sym_loss = self._compute_loss_from_outputs(doc_outputs, query_outputs) + loss = (loss + sym_loss) / 2 return (loss, (query_outputs, doc_outputs)) if return_outputs else loss diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index d6bc4558..bb7e006e 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -22,7 +22,7 @@ class BaseVisualRetrieverProcessor(ABC): Base class for visual retriever processors. """ - query_prefix: ClassVar[str] = "Query: " # Default prefix for queries. Override in subclasses if needed. + query_prefix: ClassVar[str] = "" # Default prefix for queries. Override in subclasses if needed. @abstractmethod def process_images( @@ -134,14 +134,17 @@ def score_single_vector( """ device = device or get_torch_device("auto") - if len(qs) == 0: - raise ValueError("No queries provided") - if len(ps) == 0: - raise ValueError("No passages provided") + if isinstance(qs, list) and isinstance(ps, list): + if len(qs) == 0: + raise ValueError("No queries provided") + if len(ps) == 0: + raise ValueError("No passages provided") - if isinstance(qs, list): qs = torch.stack(qs).to(device) ps = torch.stack(ps).to(device) + else: + qs = qs.to(device) + ps = ps.to(device) scores = torch.einsum("bd,cd->bc", qs, ps) assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" From 0375d68395a42a4600b4a2b221c9b97523b2b60f Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Thu, 10 Jul 2025 12:21:42 +0200 Subject: [PATCH 11/42] process --- colpali_engine/utils/processing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index bb7e006e..e8a2a011 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -22,7 +22,7 @@ class BaseVisualRetrieverProcessor(ABC): Base class for visual retriever processors. """ - query_prefix: ClassVar[str] = "" # Default prefix for queries. Override in subclasses if needed. + query_prefix: ClassVar[str] = "Query: " # Default prefix for queries. Override in subclasses if needed. @abstractmethod def process_images( From 2ab0cb0911f87303012b5a0b03b98707dcfa1257 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Mon, 14 Jul 2025 16:56:39 +0200 Subject: [PATCH 12/42] merge --- colpali_engine/__init__.py | 4 +- .../collators/visual_retriever_collator.py | 3 +- colpali_engine/models/__init__.py | 3 +- .../colidefics3/processing_colidefics3.py | 1 + .../modernvbert/bivbert/modeling_bivbert.py | 3 +- .../modernvbert/bivbert/processing_bivbert.py | 29 +++---- .../colvbert/processing_colmodernvbert.py | 43 ++++------- colpali_engine/models/qwen_omni/__init__.py | 1 - .../models/qwen_omni/colqwen_omni/__init__.py | 2 - colpali_engine/models/siglip/__init__.py | 2 + .../models/siglip/modeling_bisiglip.py | 50 +++++++++++++ .../models/siglip/processing_bisiglip.py | 75 +++++++++++++++++++ colpali_engine/trainer/contrastive_trainer.py | 2 +- 13 files changed, 158 insertions(+), 60 deletions(-) delete mode 100644 colpali_engine/models/qwen_omni/__init__.py delete mode 100644 colpali_engine/models/qwen_omni/colqwen_omni/__init__.py create mode 100644 colpali_engine/models/siglip/__init__.py create mode 100644 colpali_engine/models/siglip/modeling_bisiglip.py create mode 100644 colpali_engine/models/siglip/processing_bisiglip.py diff --git a/colpali_engine/__init__.py b/colpali_engine/__init__.py index 67b4dbc1..3341df6b 100644 --- a/colpali_engine/__init__.py +++ b/colpali_engine/__init__.py @@ -12,7 +12,7 @@ ColQwen2, ColQwen2_5, ColQwen2_5_Processor, - ColQwen2_5Omni, - ColQwen2_5OmniProcessor, + # ColQwen2_5Omni, + # ColQwen2_5OmniProcessor, ColQwen2Processor, ) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 21a2f222..47955a95 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -78,7 +78,8 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]: ) # Process queries. - queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries] + # queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries] + queries = [q + self.processor.query_augmentation_token * 10 for q in queries] batch_query = self.auto_collate(queries, key_prefix=self.query_prefix) # Process targets. diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index 3546b276..7090bc32 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -2,6 +2,7 @@ from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor -from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor +# from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor from .eurovbert import ColEuroVBert, ColEuroVBertProcessor from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor +from .siglip import BiSiglip, BiSiglipProcessor diff --git a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py index afed014b..497fe12c 100644 --- a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py +++ b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py @@ -18,6 +18,7 @@ class ColIdefics3Processor(BaseVisualRetrieverProcessor, Idefics3Processor): def __init__(self, *args, image_seq_len=64, **kwargs): super().__init__(*args, image_seq_len=image_seq_len, **kwargs) + self.tokenizer.padding_side = "left" def process_images( self, diff --git a/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py b/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py index fb4d05d2..bf73c421 100644 --- a/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py +++ b/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py @@ -51,8 +51,7 @@ def forward( pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) elif pooling_strategy == "last": # Use last token - last_unpadded_index = kwargs["attention_mask"].sum(dim=1) - 1 - pooled_output = last_hidden_states[:, last_unpadded_index.clamp(min=0)] # (batch_size, hidden_size) + pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) elif pooling_strategy == "mean": # Mean pooling over sequence length mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) diff --git a/colpali_engine/models/modernvbert/bivbert/processing_bivbert.py b/colpali_engine/models/modernvbert/bivbert/processing_bivbert.py index 0e7f27ec..c6e72fdc 100644 --- a/colpali_engine/models/modernvbert/bivbert/processing_bivbert.py +++ b/colpali_engine/models/modernvbert/bivbert/processing_bivbert.py @@ -11,33 +11,22 @@ class BiModernVBertProcessor(ColModernVBertProcessor): # noqa: N801 Processor for BiVBert. """ - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for BiVBert. - - NOTE: `max_length` is not used and kept only for trainer compatibility. + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: """ - if suffix is None: - suffix = self.query_augmentation_token # we remove buffer tokens - if contexts is None: - contexts = [""] * len(texts) + Process texts for BiModernVBert. - prompts = [context + text + suffix for context, text in zip(contexts, texts)] + Args: + texts: List of input texts. - batch_texts = self( - text=prompts, + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, return_tensors="pt", padding="longest", ) - return batch_texts - def score( self, qs: List[torch.Tensor], diff --git a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py index 787112b7..99f143e7 100644 --- a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py +++ b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py @@ -18,6 +18,7 @@ class ColModernVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.tokenizer.padding_side = "left" @property def image_token_id(self) -> int: @@ -26,57 +27,39 @@ def image_token_id(self) -> int: def process_images( self, images: List[Image.Image], - contexts: Optional[List[str]] = None, ) -> Union[BatchFeature, BatchEncoding]: """ - Process images for ColVBert. + Process images for ColModernVBert. Args: images: List of PIL images. - contexts: List of optional context prompts, i.e. some text description of the context of the image. """ - # if contexts is None: - # contexts = [self.visual_prompt_prefix] * len(images) - contexts = [self.visual_prompt_prefix] * len(images) - images = [image.convert("RGB") for image in images] - + batch_doc = self( - text=contexts, + text=[self.visual_prompt_prefix] * len(images), images=images, padding="longest", return_tensors="pt", ) return batch_doc - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: """ - Process texts for ColVBert. + Process texts for ColModernVBert. - NOTE: `max_length` is not used and kept only for trainer compatibility. - """ - if suffix is None: - # suffix = self.query_augmentation_token * 10 - suffix = "" - if contexts is None: - contexts = [""] * len(texts) - - prompts = [context + text + suffix for context, text in zip(contexts, texts)] + Args: + texts: List of input texts. - batch_texts = self( - text=prompts, + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, return_tensors="pt", padding="longest", ) - return batch_texts - def score( self, qs: List[torch.Tensor], diff --git a/colpali_engine/models/qwen_omni/__init__.py b/colpali_engine/models/qwen_omni/__init__.py deleted file mode 100644 index 7dd08129..00000000 --- a/colpali_engine/models/qwen_omni/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .colqwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor diff --git a/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py b/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py deleted file mode 100644 index b754b552..00000000 --- a/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_colqwen_omni import ColQwen2_5Omni -from .processing_colqwen_omni import ColQwen2_5OmniProcessor diff --git a/colpali_engine/models/siglip/__init__.py b/colpali_engine/models/siglip/__init__.py new file mode 100644 index 00000000..f1bb314b --- /dev/null +++ b/colpali_engine/models/siglip/__init__.py @@ -0,0 +1,2 @@ +from .modeling_bisiglip import BiSiglip +from .processing_bisiglip import BiSiglipProcessor \ No newline at end of file diff --git a/colpali_engine/models/siglip/modeling_bisiglip.py b/colpali_engine/models/siglip/modeling_bisiglip.py new file mode 100644 index 00000000..97f65a6f --- /dev/null +++ b/colpali_engine/models/siglip/modeling_bisiglip.py @@ -0,0 +1,50 @@ +from typing import ClassVar + +from transformers import SiglipModel + + +class BiSiglip(SiglipModel): + main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related + + def forward(self, *args, **kwargs): + """ + Forward pass through Llama and the linear layer for dimensionality reduction + + Args: + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) + """ + + output_attentions = kwargs.pop("output_attentions", None) + output_hidden_states = kwargs.pop("output_hidden_states", None) + return_dict = kwargs.pop("return_dict", None) + interpolate_pos_encoding = kwargs.pop("interpolate_pos_encoding", None) + + if "pixel_values" in kwargs: + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + pixel_values = kwargs.pop("pixel_values") + + embeds = self.vision_model( + pixel_values=pixel_values.to(dtype=self.dtype), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ).pooler_output + + else: + embeds = self.text_model( + input_ids=kwargs.pop("input_ids", None), + attention_mask=kwargs.pop("attention_mask", None), + position_ids=kwargs.pop("position_ids", None), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ).pooler_output + + # normalized features + embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True) + return embeds \ No newline at end of file diff --git a/colpali_engine/models/siglip/processing_bisiglip.py b/colpali_engine/models/siglip/processing_bisiglip.py new file mode 100644 index 00000000..073b6372 --- /dev/null +++ b/colpali_engine/models/siglip/processing_bisiglip.py @@ -0,0 +1,75 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature +from transformers.models.siglip import SiglipProcessor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class BiSiglipProcessor(BaseVisualRetrieverProcessor, SiglipProcessor): # noqa: N801 + """ + Processor for BiSiglip + """ + + query_augmentation_token: ClassVar[str] = "" + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + """ + Args: + texts: List of input texts. + + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, + return_tensors="pt", + padding="max_length", # the model was trained with max_length padding + max_length=64, + truncation=True, + ) + + def process_images( + self, + images: List[Image.Image], + ) -> Union[BatchFeature, BatchEncoding]: + """ + Args: + images: List of PIL images. + """ + images = [image.convert("RGB") for image in images] + + batch_doc = self( + images=images, + return_tensors="pt", + padding="longest", # the model was trained with max_length padding + ) + return batch_doc + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) + + def get_n_patches( + self, + image_size: Tuple[int, int], + spatial_merge_size: int, + ) -> Tuple[int, int]: + """ + Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of + size (height, width) with the given patch size. + + The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored in + as a `Qwen2VLForConditionalGeneration` attribute under `model.spatial_merge_size`. + """ + raise NotImplementedError("BiSiglip does not support the `get_n_patches` method. ") \ No newline at end of file diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py index eb340f62..5d46a067 100644 --- a/colpali_engine/trainer/contrastive_trainer.py +++ b/colpali_engine/trainer/contrastive_trainer.py @@ -142,7 +142,7 @@ def _compute_loss_from_outputs( return loss def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) + query_outputs = model(**{k[6:]: v for k, v in inputs.items() if k.startswith("query")}) # feed only kwargs with 'doc_' prefix doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) if "neg_doc_input_ids" in inputs: From 00337b1605113416d0d61264df9e5223dcecd4d4 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Tue, 15 Jul 2025 12:38:46 +0200 Subject: [PATCH 13/42] fix dup --- colpali_engine/utils/processing_utils.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index e8a2a011..46dd6857 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -89,30 +89,6 @@ def process_queries( return self.process_texts(texts=texts) - def process_queries( - self, - texts: List[str], - max_length: int = 50, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process a list of queries into a format suitable for the model. - Args: - texts (List[str]): List of texts to process. - max_length (int, optional): Maximum length of the texts. Defaults to 50. - suffix (Optional[str], optional): Optional suffix to append to each text. - Returns: - Union[BatchFeature, BatchEncoding]: Processed texts. - - NOTE: This function maintains back-compatibility, use `process_texts` for better control on context. - """ - return self.process_texts( - texts=texts, - contexts=[self.query_prefix] * len(texts), - max_length=max_length, - suffix=suffix, - ) - @abstractmethod def score( self, From ec4d4dd6bf6931c03946d05edb4f68c34f58ef73 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Wed, 13 Aug 2025 10:54:17 +0200 Subject: [PATCH 14/42] latest --- colpali_engine/models/__init__.py | 3 +- colpali_engine/models/eurovbert/__init__.py | 2 +- .../models/eurovbert/bivbert/__init__.py | 4 +- ...ing_bivbert.py => modeling_bieurovbert.py} | 14 +- .../bivbert/processing_bieurovbert.py | 40 + .../eurovbert/bivbert/processing_bivbert.py | 51 - .../colvbert/modeling_coleurovbert.py | 2 +- .../colvbert/processing_coleurovbert.py | 44 +- .../models/eurovbert/configuration_vbert.py | 24 +- .../models/eurovbert/modeling_vbert.py | 348 +------ .../models/modernvbert/bivbert/__init__.py | 4 +- ...g_bivbert.py => modeling_bimodernvbert.py} | 2 +- ...bivbert.py => processing_bimodernvbert.py} | 0 .../colvbert/modeling_colmodernvbert.py | 2 +- .../models/modernvbert/modeling_vbert.py | 8 +- colpali_engine/models/vbert/__init__.py | 2 - .../models/vbert/bivbert/__init__.py | 2 - .../models/vbert/bivbert/modeling_bivbert.py | 65 -- .../vbert/bivbert/processing_bivbert.py | 51 - .../models/vbert/colvbert/__init__.py | 2 - .../vbert/colvbert/modeling_colvbert.py | 51 - .../vbert/colvbert/processing_colvbert.py | 96 -- .../models/vbert/configuration_vbert.py | 232 ----- colpali_engine/models/vbert/modeling_vbert.py | 930 ------------------ .../vllama/colvllama/processing_colvllama.py | 7 +- .../models/vllama/modeling_vllama.py | 8 +- colpali_engine/trainer/contrastive_trainer.py | 58 +- pyproject.toml | 2 +- 28 files changed, 134 insertions(+), 1920 deletions(-) rename colpali_engine/models/eurovbert/bivbert/{modeling_bivbert.py => modeling_bieurovbert.py} (84%) create mode 100644 colpali_engine/models/eurovbert/bivbert/processing_bieurovbert.py delete mode 100644 colpali_engine/models/eurovbert/bivbert/processing_bivbert.py rename colpali_engine/models/modernvbert/bivbert/{modeling_bivbert.py => modeling_bimodernvbert.py} (96%) rename colpali_engine/models/modernvbert/bivbert/{processing_bivbert.py => processing_bimodernvbert.py} (100%) delete mode 100644 colpali_engine/models/vbert/__init__.py delete mode 100644 colpali_engine/models/vbert/bivbert/__init__.py delete mode 100644 colpali_engine/models/vbert/bivbert/modeling_bivbert.py delete mode 100644 colpali_engine/models/vbert/bivbert/processing_bivbert.py delete mode 100644 colpali_engine/models/vbert/colvbert/__init__.py delete mode 100644 colpali_engine/models/vbert/colvbert/modeling_colvbert.py delete mode 100644 colpali_engine/models/vbert/colvbert/processing_colvbert.py delete mode 100644 colpali_engine/models/vbert/configuration_vbert.py delete mode 100644 colpali_engine/models/vbert/modeling_vbert.py diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index 7090bc32..356220ae 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -3,6 +3,7 @@ from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor # from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor -from .eurovbert import ColEuroVBert, ColEuroVBertProcessor +from .eurovbert import BiEuroVBert, BiEuroVBertProcessor, ColEuroVBert, ColEuroVBertProcessor from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor from .siglip import BiSiglip, BiSiglipProcessor +from .vllama import BiVLlama, BiVLlamaProcessor diff --git a/colpali_engine/models/eurovbert/__init__.py b/colpali_engine/models/eurovbert/__init__.py index 84ab5f61..dc492dc5 100644 --- a/colpali_engine/models/eurovbert/__init__.py +++ b/colpali_engine/models/eurovbert/__init__.py @@ -1,2 +1,2 @@ -from .bivbert import BiVBert, BiVBertProcessor +from .bivbert import BiEuroVBert, BiEuroVBertProcessor from .colvbert import ColEuroVBert, ColEuroVBertProcessor diff --git a/colpali_engine/models/eurovbert/bivbert/__init__.py b/colpali_engine/models/eurovbert/bivbert/__init__.py index 23bc11e3..3d04309f 100644 --- a/colpali_engine/models/eurovbert/bivbert/__init__.py +++ b/colpali_engine/models/eurovbert/bivbert/__init__.py @@ -1,2 +1,2 @@ -from .modeling_bivbert import BiVBert -from .processing_bivbert import BiVBertProcessor +from .modeling_bieurovbert import BiEuroVBert +from .processing_bieurovbert import BiEuroVBertProcessor diff --git a/colpali_engine/models/eurovbert/bivbert/modeling_bivbert.py b/colpali_engine/models/eurovbert/bivbert/modeling_bieurovbert.py similarity index 84% rename from colpali_engine/models/eurovbert/bivbert/modeling_bivbert.py rename to colpali_engine/models/eurovbert/bivbert/modeling_bieurovbert.py index 33b7bd22..03e5dc1d 100644 --- a/colpali_engine/models/eurovbert/bivbert/modeling_bivbert.py +++ b/colpali_engine/models/eurovbert/bivbert/modeling_bieurovbert.py @@ -1,11 +1,13 @@ -from typing import Literal - +import os import torch -from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel +from typing import Literal, Union + +from colpali_engine.models.eurovbert.modeling_vbert import VBertModel, VBertPreTrainedModel +from colpali_engine.models.eurovbert.configuration_vbert import VBertConfig -class BiVBert(VBertPreTrainedModel): +class BiEuroVBert(VBertPreTrainedModel): """ Initializes the BiIdefics3 model. @@ -15,7 +17,6 @@ class BiVBert(VBertPreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True def __init__(self, config, pooling_strategy = "mean", **kwargs): super().__init__(config=config) @@ -51,8 +52,7 @@ def forward( pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) elif pooling_strategy == "last": # Use last token - last_unpadded_index = kwargs["attention_mask"].sum(dim=1) - 1 - pooled_output = last_hidden_states[:, last_unpadded_index.clamp(min=0)] # (batch_size, hidden_size) + pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) elif pooling_strategy == "mean": # Mean pooling over sequence length mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) diff --git a/colpali_engine/models/eurovbert/bivbert/processing_bieurovbert.py b/colpali_engine/models/eurovbert/bivbert/processing_bieurovbert.py new file mode 100644 index 00000000..7208b68f --- /dev/null +++ b/colpali_engine/models/eurovbert/bivbert/processing_bieurovbert.py @@ -0,0 +1,40 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.eurovbert.colvbert import ColEuroVBertProcessor + + +class BiEuroVBertProcessor(ColEuroVBertProcessor): # noqa: N801 + """ + Processor for BiVBert. + """ + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiModernVBert. + + Args: + texts: List of input texts. + + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py b/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py deleted file mode 100644 index b1606f94..00000000 --- a/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import List, Optional, Union - -import torch -from transformers import BatchEncoding, BatchFeature - -from colpali_engine.models.eurovbert.colvbert import ColEuroVBertProcessor - - -class BiVBertProcessor(ColEuroVBertProcessor): # noqa: N801 - """ - Processor for BiVBert. - """ - - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for BiVBert. - - NOTE: `max_length` is not used and kept only for trainer compatibility. - """ - if suffix is None: - suffix = self.query_augmentation_token # we remove buffer tokens - if contexts is None: - contexts = [""] * len(texts) - - prompts = [context + text + suffix for context, text in zip(contexts, texts)] - - batch_texts = self( - text=prompts, - return_tensors="pt", - padding="longest", - ) - - return batch_texts - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the cosine similarity for the given query and passage embeddings. - """ - return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py b/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py index d7e14bcb..9cfb7709 100644 --- a/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py +++ b/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py @@ -1,6 +1,6 @@ from torch import nn -from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel +from colpali_engine.models.eurovbert.modeling_vbert import VBertModel, VBertPreTrainedModel class ColEuroVBert(VBertPreTrainedModel): diff --git a/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py b/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py index c4c78e7f..88e655a4 100644 --- a/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py +++ b/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py @@ -12,12 +12,13 @@ class ColEuroVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): Processor for ColIdefics3. """ - query_augmentation_token: ClassVar[str] = "<|end_of_text|>" + query_augmentation_token: ClassVar[str] = "" image_token: ClassVar[str] = "" visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.tokenizer.padding_side = "left" @property def image_token_id(self) -> int: @@ -26,56 +27,39 @@ def image_token_id(self) -> int: def process_images( self, images: List[Image.Image], - contexts: Optional[List[str]] = None, ) -> Union[BatchFeature, BatchEncoding]: """ - Process images for ColVBert. + Process images for ColEuroVBert. Args: images: List of PIL images. - contexts: List of optional context prompts, i.e. some text description of the context of the image. """ - # if contexts is None: - # contexts = [self.visual_prompt_prefix] * len(images) - contexts = [self.visual_prompt_prefix] * len(images) - images = [image.convert("RGB") for image in images] - + batch_doc = self( - text=contexts, + text=[self.visual_prompt_prefix] * len(images), images=images, padding="longest", return_tensors="pt", ) return batch_doc - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: """ - Process texts for ColVBert. + Process texts for ColEuroVBert. - NOTE: `max_length` is not used and kept only for trainer compatibility. - """ - if suffix is None: - suffix = self.query_augmentation_token * 10 - if contexts is None: - contexts = [""] * len(texts) - - prompts = [context + text + suffix for context, text in zip(contexts, texts)] + Args: + texts: List of input texts. - batch_texts = self( - text=prompts, + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, return_tensors="pt", padding="longest", ) - return batch_texts - def score( self, qs: List[torch.Tensor], diff --git a/colpali_engine/models/eurovbert/configuration_vbert.py b/colpali_engine/models/eurovbert/configuration_vbert.py index 504f333b..f67eb54f 100644 --- a/colpali_engine/models/eurovbert/configuration_vbert.py +++ b/colpali_engine/models/eurovbert/configuration_vbert.py @@ -207,26 +207,4 @@ def to_dict(self): output["text_config"] = self.text_config.to_dict() # output["freeze_config"] = self.freeze_config.to_dict() - return output - - # @classmethod - # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - # outputs = super(VBertConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs) - # return outputs - - @classmethod - def from_pretrained_models( - cls, - text_model_name: Union[str, os.PathLike], - vision_model_name: Union[str, os.PathLike], - **kwargs - ) -> "PretrainedConfig": - # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) - # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) - text_model_config = VBertTextConfig(text_model_name) - vision_model_config = VBertVisionConfig(vision_model_name) - return cls( - text_config=text_model_config, - vision_config=vision_model_config, - **kwargs - ) + return output \ No newline at end of file diff --git a/colpali_engine/models/eurovbert/modeling_vbert.py b/colpali_engine/models/eurovbert/modeling_vbert.py index 3d681d69..5a8ca5e1 100644 --- a/colpali_engine/models/eurovbert/modeling_vbert.py +++ b/colpali_engine/models/eurovbert/modeling_vbert.py @@ -262,17 +262,20 @@ def __init__(self, config: VBertConfig, **kwargs): ) self.image_token_id = self.config.image_token_id + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.post_init() @staticmethod def init_vision_model(config: VBertConfig, **kwargs): vision_model_config = AutoConfig.from_pretrained( config.vision_config.vision_model_name, - trust_remote_code=True, + _attn_implementation=config._attn_implementation, + torch_dtype=config.torch_dtype, **kwargs, ) - vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) + vision_model = AutoModel.from_config(vision_model_config,**kwargs) if hasattr(vision_model, "vision_model"): # If the model has a vision_model attribute, it means it's a wrapper around another model @@ -284,12 +287,17 @@ def init_vision_model(config: VBertConfig, **kwargs): def init_language_model(config: VBertConfig, **kwargs): text_model_config = AutoConfig.from_pretrained( config.text_config.text_model_name, + _attn_implementation=config._attn_implementation, + torch_dtype=config.torch_dtype, trust_remote_code=True, **kwargs, ) - text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) - # extractor = regex_lookup(language_model_name, language_model_name2model) + text_model = AutoModel.from_config( + text_model_config, + trust_remote_code=True, + **kwargs + ) embed_layer = DecoupledEmbedding( num_embeddings=text_model_config.vocab_size, @@ -383,25 +391,16 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, image_hidden_states: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - use_cache = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.training and self.text_model.gradient_checkpointing and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # retrieve input_ids and inputs_embeds if input_ids is not None: batch_size, seq_length = input_ids.shape @@ -410,15 +409,6 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_seen_tokens = 0 - if use_cache: - if past_key_values is None: - past_key_values = DynamicCache() - past_seen_tokens = past_key_values.get_seq_length() - - if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: - raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") - if inputs_embeds is None: inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) @@ -487,9 +477,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - # past_key_values=past_key_values, - # use_cache=use_cache, - # cache_position=cache_position, ) if not return_dict: @@ -588,7 +575,6 @@ def forward( pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask, image_hidden_states=image_hidden_states, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -619,312 +605,4 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, - ) - - # @classmethod - # def from_pretrained_models( - # cls, - # text_model_name, - # vision_model_name, - # vl_config, - # *args, - # **kwargs - # ): - # """ - # Use this method when creating a new vloom model that hasn't been yet trained and it'll be - # composed of 2 pre-trained models - hence `pretrained_models`. - # """ - # model = super().from_pretrained_models( - # text_model_name=text_model_name, - # vision_model_name=vision_model_name, - # vl_config=vl_config, - # *args, - # **kwargs - # ) - # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - # # fetch the pretrained text model w/o zero.Init - # pretrained_lm_head = AutoModelForMaskedLM.from_pretrained( - # text_model_name, trust_remote_code=True, **kwargs - # ).lm_head - - # # Load the lm_head - # load_state_dict_into_model(model.lm_head, pretrained_lm_head.state_dict(), start_prefix="") - - # return model - -class VModernBertLMHead(nn.Module): - def __init__(self, config, **kwargs): - super().__init__() - pretrained_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs) - - self.head = pretrained_model.head - self.decoder = pretrained_model.decoder - - def forward(self, hidden_states): - hidden_states = self.head(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - # @classmethod - # def from_pretrained( - # cls, - # text_model_name, - # vl_config, - # *args, - # **kwargs - # ): - # """ - # Use this method when creating a new vloom model that hasn't been yet trained and it'll be - # composed of 2 pre-trained models - hence `pretrained_models`. - # """ - # lm_head = cls(vl_config, *args, **kwargs) - - # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - # # fetch the pretrained text model w/o zero.Init - # pretrained_model = AutoModelForMaskedLM.from_pretrained( - # text_model_name, trust_remote_code=True, **kwargs - # ) - - # pretrained_head = pretrained_model.head - # pretrained_decoder = pretrained_model.decoder - - # # Load the head - # load_state_dict_into_model(lm_head.head, pretrained_head.state_dict(), start_prefix="") - # load_state_dict_into_model(lm_head.decoder, pretrained_decoder.state_dict(), start_prefix="") - - # return lm_head - -class VModernBertForMaskedLM(VBertPreTrainedModel): - # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] - - def __init__(self, config, **kwargs): - super().__init__(config) - - self.image_token_id = config.image_token_id - self.in_features = config.hidden_size - self.out_additional_features = config.additional_vocab_size - self.vocab_size = config.vocab_size - - if config.is_decoder: - logger.warning( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.model = VBertModel(config, **kwargs) - self.lm_head = VModernBertLMHead(config, **kwargs) - - if self.out_additional_features > 0: - self.additional_fc = nn.Linear( - in_features=self.in_features, - out_features=self.out_additional_features, - bias=False, - ) - - # Initialize weights and apply final processing - self.post_init() - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, VBertMaskedLMOutput]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). - Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only - computed for the tokens with labels in `[0, ..., config.vocab_size]`. - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - # Pass the inputs to VBertModel - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # Pass the outputs to the MLM head - hidden_states = outputs[0] - - logits = self.lm_head(hidden_states) - if self.out_additional_features > 0: - proj_states = self.lm_head.head(hidden_states) - additional_features = self.additional_fc(proj_states) - logits = torch.cat((logits, additional_features), -1) - logits = logits.float() - - masked_lm_loss = None - if labels is not None: - # print the ratio of not ignored tokens - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return VBertMaskedLMOutput( - loss=masked_lm_loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images, max_num_tokens=None): - config_vl_model = self.config - - lm_config = config_vl_model.text_config - - language_embed_size = lm_config.hidden_size - num_language_layers = lm_config.num_hidden_layers - ffn_inner_size = lm_config.intermediate_size - - vision_config = config_vl_model.vision_config - - # Get vision model blocks infos - vision_patch_size = vision_config.patch_size - vision_hidden_size = vision_config.embed_dim - num_vision_layers = vision_config.num_hidden_layers - # The +1 is for the CLS token - single_image_vision_encoder_seq_len = int(((vision_config.image_size // vision_patch_size) ** 2) // (self.config.pixel_shuffle_factor**2)) - vision_exp_factor = vision_config.intermediate_size // vision_hidden_size - - # Get language blocks infos - language_seq_len = max_num_tokens if max_num_tokens is not None else data_param.max_seq_len - language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4 - - # Get modality projection infos - vision_pipeline_output_seq_len = ( - self.config.perceiver_config.resampler_n_latents - if self.config.use_resampler - else single_image_vision_encoder_seq_len - ) - - language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( - num_layers=num_language_layers, - batch_size=hparams.batch_size_per_gpu, - q_seq_len=language_seq_len, - k_seq_len=language_seq_len, - hidden_size=language_embed_size, - kv_in_dim=language_embed_size, - ff_exp_factor=language_exp_factor, - grad_acc_size=hparams.grad_acc_size, - swiglu=True, - vocab_size=tokenizer.vocab_size, - count_backward=True, # Always True regardless of freezing, because gradients are computed for vision adaptor - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - modality_projection_tflops_per_batch_per_gpu = compute_linear_tflops_per_batch_per_gpu( - batch_size=hparams.batch_size_per_gpu * max_num_images, - seq_len=vision_pipeline_output_seq_len, - in_features=vision_hidden_size, - out_features=language_embed_size, - count_backward=True, - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - - vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( - num_layers=num_vision_layers, - batch_size=hparams.batch_size_per_gpu * max_num_images, - q_seq_len=single_image_vision_encoder_seq_len, - k_seq_len=single_image_vision_encoder_seq_len, - hidden_size=vision_hidden_size, - kv_in_dim=vision_hidden_size, - ff_exp_factor=vision_exp_factor, - grad_acc_size=hparams.grad_acc_size, - swiglu=False, - vocab_size=None, - count_backward=not hparams.model_config["freeze_config"]["freeze_vision_layers"], - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - if self.config.use_resampler: - perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu( - num_layers=self.config.perceiver_config.resampler_depth, - batch_size=hparams.batch_size_per_gpu * max_num_images, - q_seq_len=self.config.perceiver_config.resampler_n_latents, - vision_embed_seq_len=single_image_vision_encoder_seq_len, - q_k_v_input_dim=vision_hidden_size, - attention_hidden_size=self.config.perceiver_config.resampler_n_heads - * self.config.perceiver_config.resampler_head_dim, - ff_exp_factor=4, - count_backward=True, - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - tflop_count = ( - language_tflops_per_batch_per_gpu - + modality_projection_tflops_per_batch_per_gpu - + perceiver_tflops_per_batch_per_gpu - + vision_tflops_per_batch_per_gpu - ) - else: - tflop_count = ( - language_tflops_per_batch_per_gpu - + modality_projection_tflops_per_batch_per_gpu - + vision_tflops_per_batch_per_gpu - ) - return tflop_count - - @classmethod - def from_pretrained_models( - cls, - text_model_name, - vision_model_name, - vl_config, - *args, - **kwargs - ): - """ - Use this method when creating a new vloom model that hasn't been yet trained and it'll be - composed of 2 pre-trained models - hence `pretrained_models`. - """ - model = super().from_pretrained_models( - text_model_name=text_model_name, - vision_model_name=vision_model_name, - vl_config=vl_config, - *args, - **kwargs - ) - - # Load the lm_head - model.lm_head = VModernBertLMHead.from_pretrained( - text_model_name=text_model_name, - vl_config=vl_config, - *args, - **kwargs - ) - - return model + ) \ No newline at end of file diff --git a/colpali_engine/models/modernvbert/bivbert/__init__.py b/colpali_engine/models/modernvbert/bivbert/__init__.py index 46514eda..e6098099 100644 --- a/colpali_engine/models/modernvbert/bivbert/__init__.py +++ b/colpali_engine/models/modernvbert/bivbert/__init__.py @@ -1,2 +1,2 @@ -from .modeling_bivbert import BiModernVBert -from .processing_bivbert import BiModernVBertProcessor +from .modeling_bimodernvbert import BiModernVBert +from .processing_bimodernvbert import BiModernVBertProcessor diff --git a/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py b/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py similarity index 96% rename from colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py rename to colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py index bf73c421..30a6f86f 100644 --- a/colpali_engine/models/modernvbert/bivbert/modeling_bivbert.py +++ b/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py @@ -2,7 +2,7 @@ import torch -from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel +from colpali_engine.models.modernvbert.modeling_vbert import VBertModel, VBertPreTrainedModel class BiModernVBert(VBertPreTrainedModel): diff --git a/colpali_engine/models/modernvbert/bivbert/processing_bivbert.py b/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py similarity index 100% rename from colpali_engine/models/modernvbert/bivbert/processing_bivbert.py rename to colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py diff --git a/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py index 457ecb50..a1dd11a1 100644 --- a/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py +++ b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py @@ -1,7 +1,7 @@ from torch import nn import torch -from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel +from colpali_engine.models.modernvbert.modeling_vbert import VBertModel, VBertPreTrainedModel class ColModernVBert(VBertPreTrainedModel): diff --git a/colpali_engine/models/modernvbert/modeling_vbert.py b/colpali_engine/models/modernvbert/modeling_vbert.py index 828a35e6..fe7903e0 100644 --- a/colpali_engine/models/modernvbert/modeling_vbert.py +++ b/colpali_engine/models/modernvbert/modeling_vbert.py @@ -263,13 +263,16 @@ def __init__(self, config: VBertConfig, **kwargs): ) self.image_token_id = self.config.image_token_id + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.post_init() @staticmethod def init_vision_model(config: VBertConfig, **kwargs): vision_model_config = AutoConfig.from_pretrained( config.vision_config.vision_model_name, - trust_remote_code=True, + _attn_implementation=config._attn_implementation, + torch_dtype=config.torch_dtype, **kwargs, ) @@ -285,12 +288,13 @@ def init_vision_model(config: VBertConfig, **kwargs): def init_language_model(config: VBertConfig, **kwargs): text_model_config = AutoConfig.from_pretrained( config.text_config.text_model_name, + _attn_implementation=config._attn_implementation, + torch_dtype=config.torch_dtype, trust_remote_code=True, **kwargs, ) text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) - # extractor = regex_lookup(language_model_name, language_model_name2model) embed_layer = DecoupledEmbedding( num_embeddings=text_model_config.vocab_size, diff --git a/colpali_engine/models/vbert/__init__.py b/colpali_engine/models/vbert/__init__.py deleted file mode 100644 index 064334ea..00000000 --- a/colpali_engine/models/vbert/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .bivbert import BiVBert, BiVBertProcessor -from .colvbert import ColVBert, ColVBertProcessor diff --git a/colpali_engine/models/vbert/bivbert/__init__.py b/colpali_engine/models/vbert/bivbert/__init__.py deleted file mode 100644 index 23bc11e3..00000000 --- a/colpali_engine/models/vbert/bivbert/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_bivbert import BiVBert -from .processing_bivbert import BiVBertProcessor diff --git a/colpali_engine/models/vbert/bivbert/modeling_bivbert.py b/colpali_engine/models/vbert/bivbert/modeling_bivbert.py deleted file mode 100644 index 33b7bd22..00000000 --- a/colpali_engine/models/vbert/bivbert/modeling_bivbert.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Literal - -import torch - -from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel - - -class BiVBert(VBertPreTrainedModel): - """ - Initializes the BiIdefics3 model. - - Args: - config : The model configuration. - """ - supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def __init__(self, config, pooling_strategy = "mean", **kwargs): - super().__init__(config=config) - self.model = VBertModel(config, **kwargs) - self.pooling_strategy = pooling_strategy - self.post_init() - - def forward( - self, - pooling_strategy: Literal["cls", "last", "mean"] = None, - *args, - **kwargs, - ) -> torch.Tensor: - """ - Forward pass through model and pooling. - - Args: - - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean". - - input_ids (torch.LongTensor): The input tokens tensor. - - attention_mask (torch.LongTensor): The attention mask tensor. - - Returns: - - torch.Tensor: Embeddings of shape (batch_size, dim) - """ - outputs = self.model(*args, **kwargs) - last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) - - pooling_strategy = pooling_strategy or self.pooling_strategy - - # Get CLS token embedding, last token, or mean pool over sequence - if pooling_strategy == "cls": - # Use CLS token (first token) embedding - pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) - elif pooling_strategy == "last": - # Use last token - last_unpadded_index = kwargs["attention_mask"].sum(dim=1) - 1 - pooled_output = last_hidden_states[:, last_unpadded_index.clamp(min=0)] # (batch_size, hidden_size) - elif pooling_strategy == "mean": - # Mean pooling over sequence length - mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) - pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) - else: - raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") - - # L2 normalization - pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True) - return pooled_output diff --git a/colpali_engine/models/vbert/bivbert/processing_bivbert.py b/colpali_engine/models/vbert/bivbert/processing_bivbert.py deleted file mode 100644 index 1ddc12bf..00000000 --- a/colpali_engine/models/vbert/bivbert/processing_bivbert.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import List, Optional, Union - -import torch -from transformers import BatchEncoding, BatchFeature - -from colpali_engine.models.vbert.colvbert import ColVBertProcessor # noqa: N801 - - -class BiVBertProcessor(ColVBertProcessor): # noqa: N801 - """ - Processor for BiVBert. - """ - - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for BiVBert. - - NOTE: `max_length` is not used and kept only for trainer compatibility. - """ - if suffix is None: - suffix = self.query_augmentation_token # we remove buffer tokens - if contexts is None: - contexts = [""] * len(texts) - - prompts = [context + text + suffix for context, text in zip(contexts, texts)] - - batch_texts = self( - text=prompts, - return_tensors="pt", - padding="longest", - ) - - return batch_texts - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the cosine similarity for the given query and passage embeddings. - """ - return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/vbert/colvbert/__init__.py b/colpali_engine/models/vbert/colvbert/__init__.py deleted file mode 100644 index 2d05a989..00000000 --- a/colpali_engine/models/vbert/colvbert/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_colvbert import ColVBert -from .processing_colvbert import ColVBertProcessor diff --git a/colpali_engine/models/vbert/colvbert/modeling_colvbert.py b/colpali_engine/models/vbert/colvbert/modeling_colvbert.py deleted file mode 100644 index dd0c68c7..00000000 --- a/colpali_engine/models/vbert/colvbert/modeling_colvbert.py +++ /dev/null @@ -1,51 +0,0 @@ -from torch import nn - -from colpali_engine.models.vbert.modeling_vbert import VBertModel, VBertPreTrainedModel - - -class ColVBert(VBertPreTrainedModel): - """ - Initializes the ColVBert model. - - Args: - config : The model configuration. - mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings - except those of the image at inference. - Defaults to False --> Do not mask any embeddings during forward pass. - """ - supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): - super().__init__(config=config) - self.model = VBertModel(config, **kwargs) - self.dim = 128 - self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) - self.mask_non_image_embeddings = mask_non_image_embeddings - self.main_input_name = "doc_input_ids" - - def forward(self, *args, **kwargs): - """ - Forward pass through the model and the linear layer for dimensionality reduction - - Args: - - input_ids (torch.LongTensor): The input tokens tensor. - - attention_mask (torch.LongTensor): The attention mask tensor. - - Returns: - - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) - """ - outputs = self.model(*args, **kwargs) - last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) - proj = self.custom_text_proj(last_hidden_states) - # normalize l2 norm - proj = proj / proj.norm(dim=-1, keepdim=True) - proj = proj * kwargs["attention_mask"].unsqueeze(-1) - - if "pixel_values" in kwargs and self.mask_non_image_embeddings: - # Pools only the image embeddings - image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) - proj = proj * image_mask - return proj diff --git a/colpali_engine/models/vbert/colvbert/processing_colvbert.py b/colpali_engine/models/vbert/colvbert/processing_colvbert.py deleted file mode 100644 index cb9c96f2..00000000 --- a/colpali_engine/models/vbert/colvbert/processing_colvbert.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import ClassVar, List, Optional, Tuple, Union - -import torch -from PIL import Image -from transformers import BatchEncoding, BatchFeature, Idefics3Processor - -from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor - - -class ColVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): - """ - Processor for ColIdefics3. - """ - - query_augmentation_token: ClassVar[str] = "<|end_of_text|>" - image_token: ClassVar[str] = "" - visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @property - def image_token_id(self) -> int: - return self.tokenizer.convert_tokens_to_ids(self.image_token) - - def process_images( - self, - images: List[Image.Image], - contexts: Optional[List[str]] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process images for ColVBert. - - Args: - images: List of PIL images. - contexts: List of optional context prompts, i.e. some text description of the context of the image. - """ - # if contexts is None: - # contexts = [self.visual_prompt_prefix] * len(images) - contexts = [self.visual_prompt_prefix] * len(images) - - images = [image.convert("RGB") for image in images] - - batch_doc = self( - text=contexts, - images=images, - padding="longest", - return_tensors="pt", - ) - return batch_doc - - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for ColVBert. - - NOTE: `max_length` is not used and kept only for trainer compatibility. - """ - if suffix is None: - suffix = self.query_augmentation_token * 10 - if contexts is None: - contexts = [""] * len(texts) - - prompts = [context + text + suffix for context, text in zip(contexts, texts)] - - batch_texts = self( - text=prompts, - return_tensors="pt", - padding="longest", - ) - - return batch_texts - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. - """ - return self.score_multi_vector(qs, ps, device=device, **kwargs) - - def get_n_patches( - self, - image_size: Tuple[int, int], - patch_size: int, - ) -> Tuple[int, int]: - raise NotImplementedError("This method is not implemented for ColIdefics3.") diff --git a/colpali_engine/models/vbert/configuration_vbert.py b/colpali_engine/models/vbert/configuration_vbert.py deleted file mode 100644 index 504f333b..00000000 --- a/colpali_engine/models/vbert/configuration_vbert.py +++ /dev/null @@ -1,232 +0,0 @@ -import copy -import os -from typing import Any, Dict, Union - -from transformers import AutoConfig -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -def collect_arg_in_candidates(config, candidates, default = None) -> Any: - """ Gets the argument in a config given a list of candidates """ - for c in candidates: - if hasattr(config, c): - return getattr(config, c) - elif c in config: - return config[c] - if default is not None: - return default - raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config)) - -class VBertTextConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - embed_dim (`int`, *optional*, defaults to 1152): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) - image_size (`int`, *optional*, defaults to 384): - The size (resolution) of each image. - """ - model_type = "vbert" - - def __init__( - self, - # Case for when vllama3 is from the hub with no vision_model_name - text_model_name="EuroBERT/EuroBERT-210m", - **kwargs, - ): - self.text_model_name = text_model_name - text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) - if hasattr(text_config, "text_config"): - text_config = text_config.text_config - - self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) - self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) - self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) - self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False) - self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) - - super().__init__(text_model_name=text_model_name, **kwargs) - -class VBertVisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - embed_dim (`int`, *optional*, defaults to 1152): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) - image_size (`int`, *optional*, defaults to 384): - The size (resolution) of each image. - """ - model_type = "vbert" - attribute_map = { - "hidden_size": "embed_dim", - } - - def __init__( - self, - # Case for when vllama3 is from the hub with no vision_model_name - vision_model_name="google/siglip2-base-patch16-512", - **kwargs, - ): - self.vision_model_name = vision_model_name - vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) - if hasattr(vision_config, "vision_config"): - vision_config = vision_config.vision_config - - self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) - self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) - self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) - self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) - self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) - - super().__init__(vision_model_name=vision_model_name, **kwargs) - -class VBertConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a - SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM - [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should cache the key/value pairs of the attention mechanism. Only - relevant if `config.is_decoder=True`. - image_token_id (`int`, *optional*, defaults to 128257): - The id of the "image" token. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether or not to tie the word embeddings with the token embeddings. - vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): - Custom vision config or dict for the vision tower - text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): - Custom text config or dict for the text model - scale_factor (`int`, *optional*, defaults to 2): - The scale factor for the image encoder. - pad_token_id (`int`, *optional*, defaults to 128002): - The id of the padding token. - - Example: - ```python - >>> from transformers import SmolVLMModel, SmolVLMConfig - >>> # Initializing configuration - >>> configuration = SmolVLMConfig() - >>> # Initializing a model from the configuration - >>> model = SmolVLMModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "vbert" - is_composition = True - # sub_configs = {"text_config": VBertTextConfig, "vision_config": VBertVisionConfig} - - DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m" - DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" - - def __init__( - self, - text_config: Union[PretrainedConfig, Dict[str, Any]] = None, - vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, - image_token_id: int = 128_257, - vocab_size=128_256, - use_cache = True, - tie_word_embeddings = False, - freeze_config = None, - pad_token_id = None, - initializer_range = 0.02, - pixel_shuffle_factor = 4, - use_resampler = False, - additional_vocab_size = 0, - neftune_noise_alpha = 0.0, - **kwargs, - ): - self.image_token_id = image_token_id - self.use_cache = use_cache - self.tie_word_embeddings = tie_word_embeddings - self.scale_factor = pixel_shuffle_factor - self.additional_vocab_size = additional_vocab_size - - if text_config is None: - text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) - elif isinstance(text_config, dict): - text_config = VBertTextConfig(text_config["text_model_name"]) - self.text_config = text_config - - if vision_config is None: - vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) - elif isinstance(vision_config, dict): - vision_config = VBertVisionConfig(vision_config["vision_model_name"]) - self.vision_config = vision_config - - self.freeze_config = freeze_config - - # Pixel shuffle factor - self.pixel_shuffle_factor = pixel_shuffle_factor - self.use_resampler = use_resampler - - self.neftune_noise_alpha = neftune_noise_alpha - - self.initializer_range = initializer_range - - hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) - - super().__init__( - **kwargs, - pad_token_id=pad_token_id, - tie_word_embeddings=tie_word_embeddings, - vocab_size=vocab_size, - hidden_size=hidden_size, - ) - - def to_dict(self): - """ - Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. - Returns: - `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, - """ - output = copy.deepcopy(self.__dict__) - - output["model_type"] = self.__class__.model_type - output["vision_config"] = self.vision_config.to_dict() - output["text_config"] = self.text_config.to_dict() - # output["freeze_config"] = self.freeze_config.to_dict() - - return output - - # @classmethod - # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - # outputs = super(VBertConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs) - # return outputs - - @classmethod - def from_pretrained_models( - cls, - text_model_name: Union[str, os.PathLike], - vision_model_name: Union[str, os.PathLike], - **kwargs - ) -> "PretrainedConfig": - # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) - # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) - text_model_config = VBertTextConfig(text_model_name) - vision_model_config = VBertVisionConfig(vision_model_name) - return cls( - text_config=text_model_config, - vision_config=vision_model_config, - **kwargs - ) diff --git a/colpali_engine/models/vbert/modeling_vbert.py b/colpali_engine/models/vbert/modeling_vbert.py deleted file mode 100644 index 3d681d69..00000000 --- a/colpali_engine/models/vbert/modeling_vbert.py +++ /dev/null @@ -1,930 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss -from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging -from transformers.cache_utils import DynamicCache -from transformers.modeling_outputs import BaseModelOutput -from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput - -from .configuration_vbert import VBertConfig - -logger = logging.get_logger(__name__) - - -class DecoupledEmbedding(nn.Embedding): - # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding - """ - Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. - In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained. - If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. - """ - - def __init__( - self, - num_embeddings, - num_additional_embeddings, - embedding_dim, - partially_freeze=False, - device=None, - dtype=None, - padding_idx=None, - **kwargs, - ) -> None: - """ - num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`. - partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen. - - Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these. - """ - if padding_idx is not None and padding_idx > num_embeddings: - raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") - super().__init__( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - device=device, - dtype=dtype, - padding_idx=padding_idx, - **kwargs, - ) - self.num_embeddings = num_embeddings - self.padding_idx = padding_idx - self.num_additional_embeddings = num_additional_embeddings - self.partially_freeze = partially_freeze - - if partially_freeze: - self.weight.requires_grad_(False) - - if self.num_additional_embeddings > 0: - self.additional_embedding = nn.Embedding( - num_embeddings=self.num_additional_embeddings, - embedding_dim=embedding_dim, - device=device, - dtype=dtype, - ) - - def forward(self, input_ids): - """ - we have 2 embeddings, with different indices - one pretrained self.weight and another - self.additional_embedding.weight that is being trained. - - in order to make a lookup of the input ids, we: - 1. find out the indices of the entries belonging to the 2nd embedding - 2. extract those values while subtracting the size of the first embedding (num_embeddings), - since the 2nd embedding starts from 0 and not num_embeddings - 3. perform the 2nd embedding lookup - 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index - 5. perform the 1st embedding lookup - 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup - - note: for the 1st embedding lookup we could have looked up only the low indices and not do - the padding, but then we have to create a new tensor and populate it with 2 tensors that are - spread out across various indices - i.e. not a simple concat - I haven't benchmarked the - complex case if it's any faster, given that seqlens are usually relatively short it's - probably not faster or if faster not by much - but might be a good idea to measure. - - """ - if self.num_additional_embeddings == 0: - return self.additional_embedding(input_ids) - - # Clone so that we don't modify the original input_ids later on - input_ids = input_ids.clone() - additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) - input_ids_additional_vocab = input_ids[additional_vocab_indices] - additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) - - # for successful lookup replace input_ids with 0, the results of these will be discarded anyway - input_ids[additional_vocab_indices] = 0 - full_vector = F.embedding(input_ids, self.weight) - - # overwrite the records with high indices - full_vector[additional_vocab_indices] = additional_embeddings - - return full_vector - - def extra_repr(self) -> str: - return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( - self.num_embeddings, - self.num_additional_embeddings, - self.embedding_dim, - self.partially_freeze, - ) - -@dataclass -class VBertBaseModelOutput(BaseModelOutput): - """ - Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - -@dataclass -class VBertMaskedLMOutput(MaskedLMOutput): - """ - Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). - Args: - loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): - Masked language modeling (MLM) loss. - logits (`torch.FloatTensor`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder - """ - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - image_hidden_states: Optional[torch.FloatTensor] = None - -class VBertSimpleMLP(nn.Module): - def __init__(self, input_size, output_size): - super().__init__() - self.proj = nn.Linear(input_size, output_size, bias=False) - - def forward(self, x): - return self.proj(x) - -class VBertConnector(nn.Module): - def __init__(self, config): - super().__init__() - self.scale_factor = config.pixel_shuffle_factor - self.modality_projection = VBertSimpleMLP( - input_size=config.vision_config.hidden_size * (config.scale_factor**2), - output_size=config.text_config.hidden_size - ) - - def pixel_shuffle(self, x, scale_factor): - bsz, seq, embed_dim = x.size() - height = width = int(seq**0.5) - x = x.view(bsz, height, width, embed_dim) - x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) - x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) - x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) - return x - - def forward(self, image_hidden_states): - image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) - image_hidden_states = self.modality_projection(image_hidden_states) - return image_hidden_states - -class VBertPreTrainedModel(PreTrainedModel): - config_class = VBertConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["VBertDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - """Initialize the weights.""" - - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - -class VBertModel(VBertPreTrainedModel): - """ - A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger - in forward. Instead, we override inputs_merger here with custom logic. - """ - - def __init__(self, config: VBertConfig, **kwargs): - super().__init__(config) - - self.vision_model = VBertModel.init_vision_model(config, **kwargs) - self.connector = VBertConnector(config) - self.text_model = VBertModel.init_language_model(config, **kwargs) - - self.image_seq_len = int( - ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) - ) - self.image_token_id = self.config.image_token_id - - self.post_init() - - @staticmethod - def init_vision_model(config: VBertConfig, **kwargs): - vision_model_config = AutoConfig.from_pretrained( - config.vision_config.vision_model_name, - trust_remote_code=True, - **kwargs, - ) - - vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) - - if hasattr(vision_model, "vision_model"): - # If the model has a vision_model attribute, it means it's a wrapper around another model - vision_model = vision_model.vision_model - - return vision_model - - @staticmethod - def init_language_model(config: VBertConfig, **kwargs): - text_model_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - - text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) - # extractor = regex_lookup(language_model_name, language_model_name2model) - - embed_layer = DecoupledEmbedding( - num_embeddings=text_model_config.vocab_size, - num_additional_embeddings=config.additional_vocab_size, - embedding_dim=config.hidden_size, - partially_freeze=config.freeze_config["freeze_text_layers"], - padding_idx=config.pad_token_id, - ) - - text_model.set_input_embeddings(embed_layer) - - return text_model - - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. - - This is useful for lora when using gradient checkpointing. - c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 - - Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. - """ - - def get_lowest_module(module): - if len(list(module.children())) == 0: - # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) - return module - else: - # Recursively call the function on each child module - return get_lowest_module(list(module.children())[0]) - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( - make_inputs_require_grads - ) - - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - - def get_input_embeddings(self): - return self.text_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.text_model.set_input_embeddings(value) - - def inputs_merger( - self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor - ): - """ - This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. - The merging happens as follows: - - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. - - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space. - We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer. - - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. - """ - _, patch_size, _ = image_hidden_states.shape - - image_mask = input_ids == self.image_token_id - num_image_tokens = image_mask.sum(dim=1) - if not torch.all(num_image_tokens % patch_size == 0): - raise ValueError("At least one sample has tokens not divisible by patch_size.") - - blocks_per_sample = num_image_tokens // patch_size - - offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) - block_offset = offsets[:-1] - row_cum = image_mask.cumsum(dim=-1) - chunk_idx = (row_cum - 1) // patch_size - local_idx = (row_cum - 1) % patch_size - block_idx = block_offset.unsqueeze(1) + chunk_idx - - image_embeds = torch.zeros_like(inputs_embeds) - image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] - - merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) - return merged_embeds - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.training and self.text_model.gradient_checkpointing and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # retrieve input_ids and inputs_embeds - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_seen_tokens = 0 - if use_cache: - if past_key_values is None: - past_key_values = DynamicCache() - past_seen_tokens = past_key_values.get_seq_length() - - if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: - raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") - - if inputs_embeds is None: - inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) - - # START VISUAL INPUTS INTEGRATION - if pixel_values is not None and image_hidden_states is not None: - raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") - elif pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values - pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - - if not any(real_images_inds): - # no images, leave one empty image. - real_images_inds[0] = True - - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=[pixel_values.shape[i] for i in (0, 2, 3)], - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - - # patch_size = self.config.vision_config.patch_size - # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - # patch_attention_mask=patch_attention_mask, - ).last_hidden_state - - # Modality projection & resampling - image_hidden_states = self.connector(image_hidden_states) - - elif image_hidden_states is not None: - image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - - if inputs_embeds is not None and image_hidden_states is not None: - # When we embed, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self.inputs_merger( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - image_hidden_states=image_hidden_states, - ) - - outputs = self.text_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - # past_key_values=past_key_values, - # use_cache=use_cache, - # cache_position=cache_position, - ) - - if not return_dict: - return tuple(v for v in [*outputs, image_hidden_states] if v is not None) - - return VBertBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_hidden_states, - ) - -class VBertForMaskedLM(VBertPreTrainedModel): - # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] - - def __init__(self, config, **kwargs): - super().__init__(config) - - self.image_token_id = config.image_token_id - self.in_features = config.hidden_size - self.out_additional_features = config.additional_vocab_size - self.vocab_size = config.vocab_size - - if config.is_decoder: - logger.warning( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.model = VBertModel(config, **kwargs) - self.lm_head = VBertForMaskedLM.init_lm_head(config, **kwargs) - if self.out_additional_features > 0: - self.additional_fc = nn.Linear( - in_features=self.in_features, - out_features=self.out_additional_features, - bias=False, - ) - - # Initialize weights and apply final processing - self.post_init() - - @staticmethod - def init_lm_head(config, **kwargs): - # Get the pretrained model config - text_model_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) - # Get the lm head - lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None - if lm_head is None: - logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") - lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) - return lm_head - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, VBertMaskedLMOutput]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). - Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only - computed for the tokens with labels in `[0, ..., config.vocab_size]`. - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - # Pass the inputs to VBertModel - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # Pass the outputs to the MLM head - hidden_states = outputs[0] - - logits = self.lm_head(hidden_states) - if self.out_additional_features > 0: - additional_features = self.additional_fc(hidden_states) - logits = torch.cat((logits, additional_features), -1) - logits = logits.float() - - masked_lm_loss = None - if labels is not None: - # print the ratio of not ignored tokens - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return VBertMaskedLMOutput( - loss=masked_lm_loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - # @classmethod - # def from_pretrained_models( - # cls, - # text_model_name, - # vision_model_name, - # vl_config, - # *args, - # **kwargs - # ): - # """ - # Use this method when creating a new vloom model that hasn't been yet trained and it'll be - # composed of 2 pre-trained models - hence `pretrained_models`. - # """ - # model = super().from_pretrained_models( - # text_model_name=text_model_name, - # vision_model_name=vision_model_name, - # vl_config=vl_config, - # *args, - # **kwargs - # ) - # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - # # fetch the pretrained text model w/o zero.Init - # pretrained_lm_head = AutoModelForMaskedLM.from_pretrained( - # text_model_name, trust_remote_code=True, **kwargs - # ).lm_head - - # # Load the lm_head - # load_state_dict_into_model(model.lm_head, pretrained_lm_head.state_dict(), start_prefix="") - - # return model - -class VModernBertLMHead(nn.Module): - def __init__(self, config, **kwargs): - super().__init__() - pretrained_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs) - - self.head = pretrained_model.head - self.decoder = pretrained_model.decoder - - def forward(self, hidden_states): - hidden_states = self.head(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - # @classmethod - # def from_pretrained( - # cls, - # text_model_name, - # vl_config, - # *args, - # **kwargs - # ): - # """ - # Use this method when creating a new vloom model that hasn't been yet trained and it'll be - # composed of 2 pre-trained models - hence `pretrained_models`. - # """ - # lm_head = cls(vl_config, *args, **kwargs) - - # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - # # fetch the pretrained text model w/o zero.Init - # pretrained_model = AutoModelForMaskedLM.from_pretrained( - # text_model_name, trust_remote_code=True, **kwargs - # ) - - # pretrained_head = pretrained_model.head - # pretrained_decoder = pretrained_model.decoder - - # # Load the head - # load_state_dict_into_model(lm_head.head, pretrained_head.state_dict(), start_prefix="") - # load_state_dict_into_model(lm_head.decoder, pretrained_decoder.state_dict(), start_prefix="") - - # return lm_head - -class VModernBertForMaskedLM(VBertPreTrainedModel): - # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] - - def __init__(self, config, **kwargs): - super().__init__(config) - - self.image_token_id = config.image_token_id - self.in_features = config.hidden_size - self.out_additional_features = config.additional_vocab_size - self.vocab_size = config.vocab_size - - if config.is_decoder: - logger.warning( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.model = VBertModel(config, **kwargs) - self.lm_head = VModernBertLMHead(config, **kwargs) - - if self.out_additional_features > 0: - self.additional_fc = nn.Linear( - in_features=self.in_features, - out_features=self.out_additional_features, - bias=False, - ) - - # Initialize weights and apply final processing - self.post_init() - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, VBertMaskedLMOutput]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). - Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only - computed for the tokens with labels in `[0, ..., config.vocab_size]`. - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - # Pass the inputs to VBertModel - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # Pass the outputs to the MLM head - hidden_states = outputs[0] - - logits = self.lm_head(hidden_states) - if self.out_additional_features > 0: - proj_states = self.lm_head.head(hidden_states) - additional_features = self.additional_fc(proj_states) - logits = torch.cat((logits, additional_features), -1) - logits = logits.float() - - masked_lm_loss = None - if labels is not None: - # print the ratio of not ignored tokens - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return VBertMaskedLMOutput( - loss=masked_lm_loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images, max_num_tokens=None): - config_vl_model = self.config - - lm_config = config_vl_model.text_config - - language_embed_size = lm_config.hidden_size - num_language_layers = lm_config.num_hidden_layers - ffn_inner_size = lm_config.intermediate_size - - vision_config = config_vl_model.vision_config - - # Get vision model blocks infos - vision_patch_size = vision_config.patch_size - vision_hidden_size = vision_config.embed_dim - num_vision_layers = vision_config.num_hidden_layers - # The +1 is for the CLS token - single_image_vision_encoder_seq_len = int(((vision_config.image_size // vision_patch_size) ** 2) // (self.config.pixel_shuffle_factor**2)) - vision_exp_factor = vision_config.intermediate_size // vision_hidden_size - - # Get language blocks infos - language_seq_len = max_num_tokens if max_num_tokens is not None else data_param.max_seq_len - language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4 - - # Get modality projection infos - vision_pipeline_output_seq_len = ( - self.config.perceiver_config.resampler_n_latents - if self.config.use_resampler - else single_image_vision_encoder_seq_len - ) - - language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( - num_layers=num_language_layers, - batch_size=hparams.batch_size_per_gpu, - q_seq_len=language_seq_len, - k_seq_len=language_seq_len, - hidden_size=language_embed_size, - kv_in_dim=language_embed_size, - ff_exp_factor=language_exp_factor, - grad_acc_size=hparams.grad_acc_size, - swiglu=True, - vocab_size=tokenizer.vocab_size, - count_backward=True, # Always True regardless of freezing, because gradients are computed for vision adaptor - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - modality_projection_tflops_per_batch_per_gpu = compute_linear_tflops_per_batch_per_gpu( - batch_size=hparams.batch_size_per_gpu * max_num_images, - seq_len=vision_pipeline_output_seq_len, - in_features=vision_hidden_size, - out_features=language_embed_size, - count_backward=True, - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - - vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( - num_layers=num_vision_layers, - batch_size=hparams.batch_size_per_gpu * max_num_images, - q_seq_len=single_image_vision_encoder_seq_len, - k_seq_len=single_image_vision_encoder_seq_len, - hidden_size=vision_hidden_size, - kv_in_dim=vision_hidden_size, - ff_exp_factor=vision_exp_factor, - grad_acc_size=hparams.grad_acc_size, - swiglu=False, - vocab_size=None, - count_backward=not hparams.model_config["freeze_config"]["freeze_vision_layers"], - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - if self.config.use_resampler: - perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu( - num_layers=self.config.perceiver_config.resampler_depth, - batch_size=hparams.batch_size_per_gpu * max_num_images, - q_seq_len=self.config.perceiver_config.resampler_n_latents, - vision_embed_seq_len=single_image_vision_encoder_seq_len, - q_k_v_input_dim=vision_hidden_size, - attention_hidden_size=self.config.perceiver_config.resampler_n_heads - * self.config.perceiver_config.resampler_head_dim, - ff_exp_factor=4, - count_backward=True, - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - tflop_count = ( - language_tflops_per_batch_per_gpu - + modality_projection_tflops_per_batch_per_gpu - + perceiver_tflops_per_batch_per_gpu - + vision_tflops_per_batch_per_gpu - ) - else: - tflop_count = ( - language_tflops_per_batch_per_gpu - + modality_projection_tflops_per_batch_per_gpu - + vision_tflops_per_batch_per_gpu - ) - return tflop_count - - @classmethod - def from_pretrained_models( - cls, - text_model_name, - vision_model_name, - vl_config, - *args, - **kwargs - ): - """ - Use this method when creating a new vloom model that hasn't been yet trained and it'll be - composed of 2 pre-trained models - hence `pretrained_models`. - """ - model = super().from_pretrained_models( - text_model_name=text_model_name, - vision_model_name=vision_model_name, - vl_config=vl_config, - *args, - **kwargs - ) - - # Load the lm_head - model.lm_head = VModernBertLMHead.from_pretrained( - text_model_name=text_model_name, - vl_config=vl_config, - *args, - **kwargs - ) - - return model diff --git a/colpali_engine/models/vllama/colvllama/processing_colvllama.py b/colpali_engine/models/vllama/colvllama/processing_colvllama.py index 0be8d5f1..d983e075 100644 --- a/colpali_engine/models/vllama/colvllama/processing_colvllama.py +++ b/colpali_engine/models/vllama/colvllama/processing_colvllama.py @@ -12,16 +12,13 @@ class ColVLlamaProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): Processor for ColVLlama. """ - query_augmentation_token: ClassVar[str] = "<|end_of_text|>" + query_augmentation_token: ClassVar[str] = "" image_token: ClassVar[str] = "" visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:Describe the image.\nAssistant:" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - @property - def image_token_id(self) -> int: - return self.tokenizer.convert_tokens_to_ids(self.image_token) + self.tokenizer.padding_side = "left" def process_images( self, diff --git a/colpali_engine/models/vllama/modeling_vllama.py b/colpali_engine/models/vllama/modeling_vllama.py index b1fa576f..e1d9793e 100644 --- a/colpali_engine/models/vllama/modeling_vllama.py +++ b/colpali_engine/models/vllama/modeling_vllama.py @@ -271,13 +271,16 @@ def __init__(self, config: VLlamaConfig, **kwargs): ) self.image_token_id = self.config.image_token_id + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.post_init() @staticmethod def init_vision_model(config: VLlamaConfig, **kwargs): vision_model_config = AutoConfig.from_pretrained( config.vision_config.vision_model_name, - trust_remote_code=True, + _attn_implementation=config._attn_implementation, + torch_dtype=config.torch_dtype, **kwargs, ) @@ -293,12 +296,13 @@ def init_vision_model(config: VLlamaConfig, **kwargs): def init_language_model(config: VLlamaConfig, **kwargs): text_model_config = AutoConfig.from_pretrained( config.text_config.text_model_name, + _attn_implementation=config._attn_implementation, + torch_dtype=config.torch_dtype, trust_remote_code=True, **kwargs, ) text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) - # extractor = regex_lookup(language_model_name, language_model_name2model) embed_layer = DecoupledEmbedding( num_embeddings=text_model_config.vocab_size, diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py index 5d46a067..68a6c608 100644 --- a/colpali_engine/trainer/contrastive_trainer.py +++ b/colpali_engine/trainer/contrastive_trainer.py @@ -18,31 +18,40 @@ def concat_all_gather(t: torch.Tensor) -> torch.Tensor: return t +def concat_datasets(datasets: list[Dataset], batch_size: int) -> Dataset: + """ + Concatenates a list of datasets into a single dataset. + This is a utility function to handle the case where multiple datasets are provided. + """ + # round down each dataset if not divible by global batch size + for i in range(len(datasets)): + if len(datasets[i]) % batch_size != 0: + total_samples = (len(datasets[i]) // batch_size) * batch_size + datasets[i] = datasets[i].take(total_samples) + + return ConcatDataset(datasets) + + class ContrastiveTrainer(Trainer): def __init__(self, loss_func, is_vision_model, compute_symetric_loss=False, *args, **kwargs): - if isinstance(kwargs["train_dataset"], DatasetDict): - dataset_list = list(kwargs["train_dataset"].values()) - elif isinstance(kwargs["train_dataset"], list): - dataset_list = kwargs["train_dataset"] + if isinstance(kwargs["train_dataset"], list): + train_dataset_list = kwargs["train_dataset"] + kwargs["train_dataset"] = concat_datasets(train_dataset_list, batch_size=kwargs["args"].train_batch_size) else: - dataset_list = None - - if isinstance(dataset_list, list): - # round down each dataset if not divible by global batch size - batch_size = kwargs["args"].train_batch_size - for i in range(len(dataset_list)): - if len(dataset_list[i]) % batch_size != 0: - total_samples = (len(dataset_list[i]) // batch_size) * batch_size - dataset_list[i] = dataset_list[i].take(total_samples) + train_dataset_list = None - if dataset_list is not None: - kwargs["train_dataset"] = ConcatDataset(dataset_list) + if isinstance(kwargs["eval_dataset"], list): + eval_dataset_list = kwargs["eval_dataset"] + kwargs["eval_dataset"] = concat_datasets(eval_dataset_list) + else: + eval_dataset_list = None super().__init__(*args, **kwargs) self.loss_func = loss_func self.is_vision_model = is_vision_model # Unused argument, will be removed in 0.4.0 self.args.remove_unused_columns = False # Safety, don't remove dataset columns from dataloader - self.dataset_list = dataset_list + self.train_dataset_list = train_dataset_list + self.eval_dataset_list = eval_dataset_list self.compute_symetric_loss = compute_symetric_loss def get_train_dataloader(self) -> DataLoader: @@ -56,6 +65,10 @@ def get_train_dataloader(self) -> DataLoader: """ if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") + + if self.train_dataset_list is None: + # If no dataset list, use the default behavior + return super().get_train_dataloader() dataset = self.train_dataset description = "Training" @@ -64,9 +77,6 @@ def get_train_dataloader(self) -> DataLoader: is_training = True dataloader_key = None - if self.dataset_list is None: - return super()._get_dataloader(dataset, description, batch_size, sampler_fn, is_training, dataloader_key) - data_collator = self.data_collator if is_datasets_available() and isinstance(dataset, datasets.Dataset): dataset = self._remove_unused_columns(dataset, description=description) @@ -84,7 +94,7 @@ def get_train_dataloader(self) -> DataLoader: if not isinstance(dataset, torch.utils.data.IterableDataset): if sampler_fn is not None: ###### batch_sampler set instead of sampler in trainer code ####### - dataloader_params["batch_sampler"] = sampler_fn(dataset) + dataloader_params["batch_sampler"] = sampler_fn() dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor if is_training: @@ -104,9 +114,9 @@ def get_train_dataloader(self) -> DataLoader: return self.accelerator.prepare(dataloader) - def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]: - if self.dataset_list is None: - return super()._get_train_sampler(train_dataset=train_dataset) + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset_list is None: + return super()._get_train_sampler() # Use SingleDatasetBatchSampler to ensure that each dataset in the list is sampled independently # Note: Surely breaks in distributed training @@ -114,7 +124,7 @@ def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optiona generator = torch.Generator() generator.manual_seed(self.args.seed) return SingleDatasetBatchSampler( - self.dataset_list, + self.train_dataset_list, self.args.train_batch_size, drop_last=self.args.dataloader_drop_last, generator=generator, diff --git a/pyproject.toml b/pyproject.toml index 74b5a372..7850ff53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "scipy", "torch>=2.5.0,<2.8.0", "torchvision", - "transformers>=4.53.1,<4.54.0", + "transformers>=4.51.1,<4.52.0" ] [project.optional-dependencies] From 91e9f363416b7b6c68771e87bd466a8a042c3642 Mon Sep 17 00:00:00 2001 From: Paul Teiletche Date: Tue, 19 Aug 2025 15:14:50 +0200 Subject: [PATCH 15/42] modeling --- colpali_engine/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index 356220ae..a9511d37 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -6,4 +6,4 @@ from .eurovbert import BiEuroVBert, BiEuroVBertProcessor, ColEuroVBert, ColEuroVBertProcessor from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor from .siglip import BiSiglip, BiSiglipProcessor -from .vllama import BiVLlama, BiVLlamaProcessor +from .vllama import BiVLlama, BiVLlamaProcessor, ColVLlama, ColVLlamaProcessor From 91ba4be300ec7fd0c1e947132d617363cfa13831 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Fri, 8 Aug 2025 15:21:41 +0200 Subject: [PATCH 16/42] f From 81eef805e7d74a2e5ed2d17ca5ef49903629f3ea Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Fri, 8 Aug 2025 15:24:20 +0200 Subject: [PATCH 17/42] rebase From 9a82c1f0463516e444b613b17c72dd51c15f3113 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Fri, 8 Aug 2025 15:25:25 +0200 Subject: [PATCH 18/42] rebase From 245bb3385abfe6a7b76a5b809109b19e83ef114b Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Fri, 8 Aug 2025 15:26:47 +0200 Subject: [PATCH 19/42] rebase From 2ebe2ab5c810d7b8f1cd48130d5c250f784188b7 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Wed, 30 Jul 2025 10:42:54 +0200 Subject: [PATCH 20/42] symetric loss + flex biencodr score --- colpali_engine/utils/processing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index 46dd6857..4c1d9617 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -22,7 +22,7 @@ class BaseVisualRetrieverProcessor(ABC): Base class for visual retriever processors. """ - query_prefix: ClassVar[str] = "Query: " # Default prefix for queries. Override in subclasses if needed. + query_prefix: ClassVar[str] = "" # Default prefix for queries. Override in subclasses if needed. @abstractmethod def process_images( From 44fe1e642ca53a91135483f07c03b3dcd5deb3a2 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 11:19:48 +0200 Subject: [PATCH 21/42] f --- colpali_engine/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index a9511d37..e2289faa 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -2,7 +2,7 @@ from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor -# from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor +from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor from .eurovbert import BiEuroVBert, BiEuroVBertProcessor, ColEuroVBert, ColEuroVBertProcessor from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor from .siglip import BiSiglip, BiSiglipProcessor From 1ec65fcc0da1da51c1efc802e5f4a7ec54119b22 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 11:30:05 +0200 Subject: [PATCH 22/42] f --- colpali_engine/models/eurovbert/__init__.py | 5 ++ .../eurovbert/bivbert/processing_bivbert.py | 53 +++++++++++++++++++ .../colvbert/processing_coleurovbert.py | 14 +++-- .../bivbert/processing_bimodernvbert.py | 2 + .../models/modernvbert/colvbert/__init__.py | 2 +- .../colvbert/modeling_colmodernvbert.py | 1 - .../colvbert/processing_colmodernvbert.py | 14 +++-- 7 files changed, 79 insertions(+), 12 deletions(-) create mode 100644 colpali_engine/models/eurovbert/bivbert/processing_bivbert.py diff --git a/colpali_engine/models/eurovbert/__init__.py b/colpali_engine/models/eurovbert/__init__.py index dc492dc5..76a932b7 100644 --- a/colpali_engine/models/eurovbert/__init__.py +++ b/colpali_engine/models/eurovbert/__init__.py @@ -1,2 +1,7 @@ +<<<<<<< HEAD from .bivbert import BiEuroVBert, BiEuroVBertProcessor +======= +from .bivbert import BiVBert as BiEuroVBert +from .bivbert import BiVBertProcessor as BiEuroVBertProcessor +>>>>>>> 2c87c8a (vbert processor fixes) from .colvbert import ColEuroVBert, ColEuroVBertProcessor diff --git a/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py b/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py new file mode 100644 index 00000000..19c0e54e --- /dev/null +++ b/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py @@ -0,0 +1,53 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.eurovbert.colvbert import ColEuroVBertProcessor + + +class BiVBertProcessor(ColEuroVBertProcessor): # noqa: N801 + """ + Processor for BiVBert. + """ + + def process_texts( + self, + texts: List[str], + max_length: int = 50, + contexts: Optional[List[str]] = None, + suffix: Optional[str] = None, + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiVBert. + + NOTE: `max_length` is not used and kept only for trainer compatibility. + """ + if suffix is None: + suffix = self.query_augmentation_token # we remove buffer tokens + if contexts is None: + contexts = [""] * len(texts) + + prompts = [context + text + suffix for context, text in zip(contexts, texts)] + + batch_texts = self( + text=prompts, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=4096, + ) + + return batch_texts + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py b/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py index 88e655a4..5bd476f1 100644 --- a/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py +++ b/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py @@ -16,13 +16,13 @@ class ColEuroVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): image_token: ClassVar[str] = "" visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *args, image_seq_len=64, **kwargs): + super().__init__(*args, image_seq_len=image_seq_len, **kwargs) self.tokenizer.padding_side = "left" - @property - def image_token_id(self) -> int: - return self.tokenizer.convert_tokens_to_ids(self.image_token) + # @property + # def image_token_id(self) -> int: + # return self.tokenizer.convert_tokens_to_ids(self.image_token) def process_images( self, @@ -41,6 +41,8 @@ def process_images( images=images, padding="longest", return_tensors="pt", + truncation=True, + max_length=8192, ) return batch_doc @@ -58,6 +60,8 @@ def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: text=texts, return_tensors="pt", padding="longest", + truncation=True, + max_length=4096, ) def score( diff --git a/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py b/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py index c6e72fdc..80a961ac 100644 --- a/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py +++ b/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py @@ -25,6 +25,8 @@ def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: text=texts, return_tensors="pt", padding="longest", + truncation=True, + max_length=4096 ) def score( diff --git a/colpali_engine/models/modernvbert/colvbert/__init__.py b/colpali_engine/models/modernvbert/colvbert/__init__.py index e8f041b5..1b073552 100644 --- a/colpali_engine/models/modernvbert/colvbert/__init__.py +++ b/colpali_engine/models/modernvbert/colvbert/__init__.py @@ -1,2 +1,2 @@ from .modeling_colmodernvbert import ColModernVBert -from .processing_colmodernvbert import ColModernVBertProcessor \ No newline at end of file +from .processing_colmodernvbert import ColModernVBertProcessor diff --git a/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py index a1dd11a1..89c6d5be 100644 --- a/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py +++ b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py @@ -1,5 +1,4 @@ from torch import nn -import torch from colpali_engine.models.modernvbert.modeling_vbert import VBertModel, VBertPreTrainedModel diff --git a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py index 99f143e7..f9e5515c 100644 --- a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py +++ b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py @@ -16,13 +16,13 @@ class ColModernVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): image_token: ClassVar[str] = "" visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *args, image_seq_len=64, **kwargs): + super().__init__(*args, image_seq_len=image_seq_len, **kwargs) self.tokenizer.padding_side = "left" - @property - def image_token_id(self) -> int: - return self.tokenizer.convert_tokens_to_ids(self.image_token) + # @property + # def image_token_id(self) -> int: + # return self.tokenizer.convert_tokens_to_ids(self.image_token) def process_images( self, @@ -41,6 +41,8 @@ def process_images( images=images, padding="longest", return_tensors="pt", + truncation=True, + max_length=8192, ) return batch_doc @@ -58,6 +60,8 @@ def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: text=texts, return_tensors="pt", padding="longest", + truncation=True, + max_length=4096, ) def score( From 1b8510f11c1192c6a504217d009ad64b7443ca0d Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 11:31:48 +0200 Subject: [PATCH 23/42] f --- colpali_engine/models/eurovbert/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/colpali_engine/models/eurovbert/__init__.py b/colpali_engine/models/eurovbert/__init__.py index 76a932b7..dc492dc5 100644 --- a/colpali_engine/models/eurovbert/__init__.py +++ b/colpali_engine/models/eurovbert/__init__.py @@ -1,7 +1,2 @@ -<<<<<<< HEAD from .bivbert import BiEuroVBert, BiEuroVBertProcessor -======= -from .bivbert import BiVBert as BiEuroVBert -from .bivbert import BiVBertProcessor as BiEuroVBertProcessor ->>>>>>> 2c87c8a (vbert processor fixes) from .colvbert import ColEuroVBert, ColEuroVBertProcessor From 3dfbe4bac58a3d03f04197c31ea5afb92d1074e9 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 11:34:26 +0200 Subject: [PATCH 24/42] remove bvbert file --- .../eurovbert/bivbert/processing_bivbert.py | 53 ------------------- 1 file changed, 53 deletions(-) delete mode 100644 colpali_engine/models/eurovbert/bivbert/processing_bivbert.py diff --git a/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py b/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py deleted file mode 100644 index 19c0e54e..00000000 --- a/colpali_engine/models/eurovbert/bivbert/processing_bivbert.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import List, Optional, Union - -import torch -from transformers import BatchEncoding, BatchFeature - -from colpali_engine.models.eurovbert.colvbert import ColEuroVBertProcessor - - -class BiVBertProcessor(ColEuroVBertProcessor): # noqa: N801 - """ - Processor for BiVBert. - """ - - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for BiVBert. - - NOTE: `max_length` is not used and kept only for trainer compatibility. - """ - if suffix is None: - suffix = self.query_augmentation_token # we remove buffer tokens - if contexts is None: - contexts = [""] * len(texts) - - prompts = [context + text + suffix for context, text in zip(contexts, texts)] - - batch_texts = self( - text=prompts, - return_tensors="pt", - padding="longest", - truncation=True, - max_length=4096, - ) - - return batch_texts - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the cosine similarity for the given query and passage embeddings. - """ - return self.score_single_vector(qs, ps, device=device) From d748aa158f949523a0324bf3171212d989b2104c Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 11:01:38 +0200 Subject: [PATCH 25/42] negatives loss --- colpali_engine/loss/bi_encoder_losses.py | 15 +++++++------- .../loss/late_interaction_losses.py | 20 +++++++++++-------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/colpali_engine/loss/bi_encoder_losses.py b/colpali_engine/loss/bi_encoder_losses.py index 92e4a657..a5ea142f 100644 --- a/colpali_engine/loss/bi_encoder_losses.py +++ b/colpali_engine/loss/bi_encoder_losses.py @@ -162,7 +162,7 @@ def forward( self._filter_high_negatives(scores, pos_idx) q2t = self.ce_loss(scores / self.temperature, pos_idx) - t2q = self.ce_loss(scores.T / self.temperature, ...) + t2q = self.ce_loss(scores.T / self.temperature, ...) return (q2t + t2q) / 2.0 @@ -216,17 +216,18 @@ def forward( Args: query_embeddings (Tensor[B, D]): Query vectors. doc_embeddings (Tensor[B, D]): Positive document vectors. - neg_doc_embeddings (Tensor[B, D]): Negative document vectors. + neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors. offset (int): Offset for in-batch CE positives. Returns: Tensor: Scalar loss value. """ # Dot-product only for matching pairs - pos_scores = (query_embeddings * doc_embeddings).sum(dim=1) / self.temperature - neg_scores = (query_embeddings * neg_doc_embeddings).sum(dim=1) / self.temperature + pos_scores = (query_embeddings * doc_embeddings[offset:offset + neg_doc_embeddings.size(0)]).sum(dim=1) + pos_scores /= self.temperature + neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / self.temperature - loss = F.softplus(neg_scores - pos_scores).mean() + loss = F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean() if self.in_batch_term_weight > 0: loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset) @@ -411,5 +412,5 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, all_losses.append(block_loss) # shift the offset for the next batch offset = (offset + batch_size) % num_targets - - return torch.stack(all_losses, dim=0).mean() \ No newline at end of file + + return torch.stack(all_losses, dim=0).mean() diff --git a/colpali_engine/loss/late_interaction_losses.py b/colpali_engine/loss/late_interaction_losses.py index 5d08c6ac..2219abf3 100644 --- a/colpali_engine/loss/late_interaction_losses.py +++ b/colpali_engine/loss/late_interaction_losses.py @@ -224,25 +224,29 @@ def forward( Compute InfoNCE loss with explicit negatives and optional in-batch term. Args: - query_embeddings (Tensor): [B, Nq, D] - doc_embeddings (Tensor): [B, Nd, D] positive docs - neg_doc_embeddings (Tensor): [B, Nneg, D] negative docs + query_embeddings (Tensor): [B, Lq, D] + doc_embeddings (Tensor): [B, Ld, D] positive docs + neg_doc_embeddings (Tensor): [B, Nneg, Lneg, D] negative docs offset (int): Positional offset for in-batch CE. Returns: Tensor: Scalar loss. """ lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1) - pos_raw = torch.einsum("bnd,bsd->bns", query_embeddings, doc_embeddings) - neg_raw = torch.einsum("bnd,bsd->bns", query_embeddings, neg_doc_embeddings) + pos_raw = torch.einsum( + "bnd,bsd->bns", + query_embeddings, + doc_embeddings[offset:offset + neg_doc_embeddings.size(0)] + ) + neg_raw = torch.einsum("bnd,blsd->blns", query_embeddings, neg_doc_embeddings) pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1) - neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=2, dim_sum=1) + neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2) if self.normalize_scores: pos_scores = self._apply_normalization(pos_scores, lengths) neg_scores = self._apply_normalization(neg_scores, lengths) - loss = F.softplus((neg_scores - pos_scores) / self.temperature).mean() + loss = F.softplus((neg_scores - pos_scores.unsqueeze(1)) / self.temperature).mean() if self.in_batch_term_weight > 0: loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset) @@ -458,5 +462,5 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, # flatten the scores to [B * B] scores = scores.view(-1) / self.temperature - + return F.softplus(scores * pos_mask).mean() From afd0e9554caa776cc063de2797e05b056dac48cd Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 11:09:00 +0200 Subject: [PATCH 26/42] prepare collators for multi-hardnegs --- colpali_engine/collators/collator_copy.py | 142 ++++++++++++++++++ .../collators/visual_retriever_collator.py | 28 +++- 2 files changed, 167 insertions(+), 3 deletions(-) create mode 100644 colpali_engine/collators/collator_copy.py diff --git a/colpali_engine/collators/collator_copy.py b/colpali_engine/collators/collator_copy.py new file mode 100644 index 00000000..f131c233 --- /dev/null +++ b/colpali_engine/collators/collator_copy.py @@ -0,0 +1,142 @@ +import random +import torch +from typing import Any, Dict, List, Union + +from PIL.Image import Image + +from colpali_engine.data.dataset import ColPaliEngineDataset +from colpali_engine.models.paligemma import ColPaliProcessor +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]: + """ + Prefix all keys in a dictionary with the given prefix. + """ + return {f"{prefix}{k}": v for k, v in data.items()} + + +class VisualRetrieverCollator: + """ + Collator for training vision retrieval models. + """ + + # Prefixes + query_prefix = "query_" + pos_doc_prefix = "doc_" + neg_doc_prefix = "neg_doc_" + + def __init__( + self, + processor: BaseVisualRetrieverProcessor, + max_length: int = 2048, + ): + self.processor = processor + self.max_length = max_length + self.image_token_id = None + + # If processor is one of the supported types, extract the token id. + if isinstance(self.processor, (ColPaliProcessor,)): + image_token = "" + try: + idx = self.processor.tokenizer.additional_special_tokens.index(image_token) + self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[idx] + except ValueError: + self.image_token_id = None + + # Force padding to be on the right for ColPaliProcessor. + if isinstance(self.processor, ColPaliProcessor) and self.processor.tokenizer.padding_side != "right": + print("Setting padding side to right") + self.processor.tokenizer.padding_side = "right" + + def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]: + queries: List[Union[None, str, Image]] = [] + pos_targets: List[Union[str, Image]] = [] + neg_targets: List[Union[str, Image]] = [] + selected_ids: List[int] = [] + + # Parse the examples. + positive_ids_tensor = -torch.ones((len(examples), 100), dtype=torch.long) + for i, example in enumerate(examples): + assert ColPaliEngineDataset.QUERY_KEY in example, f"Missing {ColPaliEngineDataset.QUERY_KEY} in example." + query = example[ColPaliEngineDataset.QUERY_KEY] + sampled_query = random.choice(query) if isinstance(query, list) else query + queries.append(sampled_query) + + assert ColPaliEngineDataset.POS_TARGET_KEY in example, ( + f"Missing {ColPaliEngineDataset.POS_TARGET_KEY} in example." + ) + pos_tgt = example[ColPaliEngineDataset.POS_TARGET_KEY] + positive_ids = example.get("positive_ids", None) + if isinstance(pos_tgt, list): + sample_tuple = random.choice([(t, id_) for t, id_ in zip(pos_tgt, positive_ids)]) + sample_pos = sample_tuple[0] + selected_ids.append(sample_tuple[1]) + else: + sample_pos = pos_tgt + pos_targets.append(sample_pos) + if positive_ids is not None: + positive_ids_tensor[i, :len(positive_ids)] = torch.tensor(positive_ids) + + neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None) + if neg_tgt is not None: + # sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt + # neg_targets.append(random.choice(neg_tgt)) #neg_tgts) + neg_targets.append(neg_tgt) + + # Ensure all queries are strings or images. + assert all(isinstance(q, str) for q in queries), ( + "All queries must be strings, this collator does not support images in queries." + ) + + # Process queries. + queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries] + batch_query = self.auto_collate(queries, key_prefix=self.query_prefix) + + # Process targets. + batch_pos_target = self.auto_collate(pos_targets, key_prefix=self.pos_doc_prefix) + batch_neg_target = self.auto_collate(neg_targets, key_prefix=self.neg_doc_prefix) if neg_targets else {} + + return { + **batch_query, + **batch_pos_target, + **batch_neg_target, + "selected_ids": torch.Tensor(selected_ids), + "positive_ids_tensor": positive_ids_tensor, + } + + def auto_collate(self, batch: List[Union[str, Image, List[str], List[Image]]], key_prefix: str = "") -> Dict[str, Any]: + """Automatically collate a batch of documents.""" + # Convert Document objects to their underlying data. + # if type is mixed across the batch, raise an error. + all_types = set(type(item) for item in batch) + if str in all_types and Image in all_types: + raise ValueError(f"Batch contains mixed types: {all_types}. Expected all items to be of the same type.") + if isinstance(batch[0], str): + proc_batch = self.processor.process_texts(texts=batch) + elif isinstance(batch[0], Image): + proc_batch = self.processor.process_images(images=batch) + elif isinstance(batch[0], list): + if isinstance(batch[0][0], str): + proc_texts_batch = [] + batch_size = len(batch) + all_texts = [text for texts in batch for text in texts] + num_negatives = len(all_texts) // batch_size + proc_batch = self.processor.process_texts(texts=all_texts) + elif isinstance(batch[0][0], Image): + proc_imgs_batch = [] + batch_size = len(batch) + all_imgs = [img for imgs in batch for img in imgs] + num_negatives = len(all_imgs) // batch_size + proc_batch = self.processor.process_images(images=all_imgs) + else: + raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.") + for k, v in proc_batch.items(): + if isinstance(v, torch.Tensor): + proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:]) + else: + proc_batch[k] = v + else: + raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.") + + return prefix_keys(proc_batch, key_prefix) \ No newline at end of file diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 47955a95..bae3066e 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -1,5 +1,6 @@ import random from typing import Any, Dict, List, Union +import torch from PIL.Image import Image @@ -69,17 +70,18 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]: neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None) if neg_tgt is not None: - sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt - neg_targets.append(sampled_neg) + neg_targets.append(neg_tgt) # Ensure all queries are strings or images. assert all(isinstance(q, str) for q in queries), ( "All queries must be strings, this collator does not support images in queries." ) + is_str = isinstance(queries[0], str) + # Process queries. # queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries] - queries = [q + self.processor.query_augmentation_token * 10 for q in queries] + queries = [q + self.processor.query_augmentation_token * 10 for q in queries] if is_str else queries batch_query = self.auto_collate(queries, key_prefix=self.query_prefix) # Process targets. @@ -103,6 +105,26 @@ def auto_collate(self, batch: List[Union[str, Image]], key_prefix: str = "") -> proc_batch = self.processor.process_texts(texts=batch) elif isinstance(batch[0], Image): proc_batch = self.processor.process_images(images=batch) + elif isinstance(batch[0], list): + if isinstance(batch[0][0], str): + proc_texts_batch = [] + batch_size = len(batch) + all_texts = [text for texts in batch for text in texts] + num_negatives = len(all_texts) // batch_size + proc_batch = self.processor.process_texts(texts=all_texts) + elif isinstance(batch[0][0], Image): + proc_imgs_batch = [] + batch_size = len(batch) + all_imgs = [img for imgs in batch for img in imgs] + num_negatives = len(all_imgs) // batch_size + proc_batch = self.processor.process_images(images=all_imgs) + else: + raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.") + for k, v in proc_batch.items(): + if isinstance(v, torch.Tensor): + proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:]) + else: + proc_batch[k] = v else: raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.") return prefix_keys(proc_batch, key_prefix) From dcbbe15ad0e4f31998d5c2e7c05385ab6073869b Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 11:16:43 +0200 Subject: [PATCH 27/42] multiple hard negs training --- colpali_engine/data/dataset.py | 6 ++- colpali_engine/trainer/contrastive_trainer.py | 41 +++++++++++++++++-- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/colpali_engine/data/dataset.py b/colpali_engine/data/dataset.py index 311d7421..8eec842e 100644 --- a/colpali_engine/data/dataset.py +++ b/colpali_engine/data/dataset.py @@ -77,6 +77,7 @@ def __init__( query_column_name: str = "query", pos_target_column_name: str = "pos_target", neg_target_column_name: str = None, + num_negatives: int = 3, ): """ Initialize the dataset with the provided data and external document corpus. @@ -94,6 +95,7 @@ def __init__( self.pos_target_column_name = pos_target_column_name self.neg_target_column_name = neg_target_column_name + self.num_negatives = num_negatives assert isinstance( self.data, (list, Dataset, HFDataset), @@ -131,8 +133,8 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: pos_targets = [self.corpus.retrieve(doc_id) for doc_id in pos_targets] if neg_targets is not None: # to avoid oveflowing CPU memory - if len(neg_targets) > 5: - neg_targets = random.sample(neg_targets, 5) + if len(neg_targets) > self.num_negatives: + neg_targets = random.sample(neg_targets, self.num_negatives) neg_targets = [self.corpus.retrieve(doc_id) for doc_id in neg_targets] return { diff --git a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py index 68a6c608..ccd699c3 100644 --- a/colpali_engine/trainer/contrastive_trainer.py +++ b/colpali_engine/trainer/contrastive_trainer.py @@ -140,16 +140,46 @@ def _compute_loss_from_outputs( batch_size = query_outputs.size(0) if self.accelerator.num_processes > 1 and self.accelerator.sync_gradients: # gather docs across all processes - pos_target_outputs = concat_all_gather(pos_target_outputs) + pos_target_outputs = self.accelerator.pad_across_processes(pos_target_outputs, dim=1, pad_index=0, pad_first=True) + pos_target_outputs = concat_all_gather(pos_target_outputs) rank = self.accelerator.process_index offset = rank * batch_size if neg_target_outputs is not None: - loss = self.loss_func(query_outputs, pos_target_outputs, neg_target_outputs, offset=offset) + loss = self.loss_func( + query_embeddings=query_outputs, + doc_embeddings=pos_target_outputs, + neg_doc_embeddings=neg_target_outputs, + offset=offset + ) else: - loss = self.loss_func(query_outputs, pos_target_outputs, offset=offset) + loss = self.loss_func( + query_embeddings=query_outputs, + doc_embeddings=pos_target_outputs, + offset=offset + ) return loss + + def _reshape_neg_doc_inputs(self, inputs): + """ + Helper function to reshape negative doc inputs to (batch_size * num_neg_docs, ...) + """ + neg_doc_inputs = {k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")} + + for k in neg_doc_inputs: + # go from (batch_size, num_neg_docs, ...) to (batch_size * num_neg_docs, ...) + neg_doc_inputs[k] = neg_doc_inputs[k].view(-1, *neg_doc_inputs[k].shape[2:]) + + return neg_doc_inputs + + def _reshape_neg_doc_outputs(self, neg_doc_outputs, num_neg_docs): + """ + Helper function to reshape negative doc outputs to (batch_size, num_neg_docs, ...) + """ + neg_doc_outputs = neg_doc_outputs.view(-1, num_neg_docs, *neg_doc_outputs.shape[1:]) + + return neg_doc_outputs def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): query_outputs = model(**{k[6:]: v for k, v in inputs.items() if k.startswith("query")}) @@ -157,7 +187,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) if "neg_doc_input_ids" in inputs: # Negative docs are not gathered across processes, so we can use them without offset - neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) + num_negs = inputs["neg_doc_input_ids"].size(1) + neg_doc_inputs = self._reshape_neg_doc_inputs(inputs) + neg_doc_outputs = model(**neg_doc_inputs) + neg_doc_outputs = self._reshape_neg_doc_outputs(neg_doc_outputs, num_negs) else: neg_doc_outputs = None From 24cd010b8d8ade5e80dc7fcc6478fbe5eadab0cb Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 12:03:36 +0200 Subject: [PATCH 28/42] f --- colpali_engine/utils/processing_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index 4c1d9617..4eaf2d6b 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -122,8 +122,8 @@ def score_single_vector( qs = qs.to(device) ps = ps.to(device) - scores = torch.einsum("bd,cd->bc", qs, ps) - assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" + scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked) + assert scores.shape[0] == len(qs_stacked), f"Expected {len(qs_stacked)} scores, got {scores.shape[0]}" scores = scores.to(torch.float32) return scores From 5c11cd3df47204959e48bc997b268ca5d932ad4f Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Mon, 29 Sep 2025 13:47:46 +0200 Subject: [PATCH 29/42] rm colqwen_omni init --- colpali_engine/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index e2289faa..ef994eea 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -2,7 +2,6 @@ from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor -from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor from .eurovbert import BiEuroVBert, BiEuroVBertProcessor, ColEuroVBert, ColEuroVBertProcessor from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor from .siglip import BiSiglip, BiSiglipProcessor From da868ae3c21881d01e32cf17fc5704c8b9d7de20 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Tue, 30 Sep 2025 09:53:47 +0200 Subject: [PATCH 30/42] f --- colpali_engine/utils/processing_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index 4eaf2d6b..4c1d9617 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -122,8 +122,8 @@ def score_single_vector( qs = qs.to(device) ps = ps.to(device) - scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked) - assert scores.shape[0] == len(qs_stacked), f"Expected {len(qs_stacked)} scores, got {scores.shape[0]}" + scores = torch.einsum("bd,cd->bc", qs, ps) + assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" scores = scores.to(torch.float32) return scores From fa1ea7659e563d378e9f3447ff506bd231d6e10e Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Tue, 30 Sep 2025 10:09:19 +0200 Subject: [PATCH 31/42] modif tests --- tests/loss/test_bi_losses.py | 8 ++++---- tests/loss/test_li_losses.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/loss/test_bi_losses.py b/tests/loss/test_bi_losses.py index ea4cf5c3..d4c36ad1 100644 --- a/tests/loss/test_bi_losses.py +++ b/tests/loss/test_bi_losses.py @@ -64,10 +64,10 @@ def test_forward_with_filtering(self): class TestBiNegativeCELoss: def test_forward_no_inbatch(self): loss_fn = BiNegativeCELoss(temperature=1.0, in_batch_term_weight=0, pos_aware_negative_filtering=False) - B, D = 3, 4 + B, D, Nneg = 3, 4, 1 query = torch.zeros(B, D) pos = torch.zeros(B, D) - neg = torch.zeros(B, D) + neg = torch.zeros(B, Nneg, D) loss = loss_fn(query, pos, neg) # softplus(0 - 0) = ln(2) expected = F.softplus(torch.tensor(0.0)) @@ -75,10 +75,10 @@ def test_forward_no_inbatch(self): def test_forward_with_inbatch(self): loss_fn = BiNegativeCELoss(temperature=1.0, in_batch_term_weight=0.5, pos_aware_negative_filtering=False) - B, D = 2, 3 + B, D, Nneg = 2, 3, 1 query = torch.zeros(B, D) pos = torch.zeros(B, D) - neg = torch.zeros(B, D) + neg = torch.zeros(B, Nneg, D) loss = loss_fn(query, pos, neg) # in-batch CE on zeros: log(B) ce = torch.log(torch.tensor(float(B))) diff --git a/tests/loss/test_li_losses.py b/tests/loss/test_li_losses.py index 77faf0f1..b363baaf 100644 --- a/tests/loss/test_li_losses.py +++ b/tests/loss/test_li_losses.py @@ -109,10 +109,10 @@ def test_no_inbatch(self): pos_aware_negative_filtering=False, in_batch_term_weight=0, ) - B, Nq, D, Nneg = 2, 1, 3, 1 - query = torch.zeros(B, Nq, D) - doc = torch.zeros(B, Nq, D) - neg = torch.zeros(B, Nneg, D) + B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1 + query = torch.zeros(B, Lq, D) + doc = torch.zeros(B, Lq, D) + neg = torch.zeros(B, Nneg, Lneg, D) loss = loss_fn(query, doc, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) @@ -125,10 +125,10 @@ def test_with_inbatch(self): pos_aware_negative_filtering=False, in_batch_term_weight=0.5, ) - B, Nq, D = 2, 1, 3 - query = torch.zeros(B, Nq, D) - doc = torch.zeros(B, Nq, D) - neg = torch.zeros(B, 1, D) + B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1 + query = torch.zeros(B, Lq, D) + doc = torch.zeros(B, Lq, D) + neg = torch.zeros(B, Nneg, Lneg, D) loss = loss_fn(query, doc, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) From 3fb3df4568f224f66f8519fb0e05d13b27d9f7b9 Mon Sep 17 00:00:00 2001 From: Manuel Faysse <43467008+ManuelFay@users.noreply.github.com> Date: Fri, 3 Oct 2025 09:19:17 +0200 Subject: [PATCH 32/42] Change default model --- colpali_engine/models/modernvbert/configuration_vbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/models/modernvbert/configuration_vbert.py b/colpali_engine/models/modernvbert/configuration_vbert.py index 504f333b..659e1678 100644 --- a/colpali_engine/models/modernvbert/configuration_vbert.py +++ b/colpali_engine/models/modernvbert/configuration_vbert.py @@ -136,7 +136,7 @@ class VBertConfig(PretrainedConfig): is_composition = True # sub_configs = {"text_config": VBertTextConfig, "vision_config": VBertVisionConfig} - DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m" + DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m" DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" def __init__( From 31630d16ab0c0cfbfce70c6c8b8e93b43258ba2a Mon Sep 17 00:00:00 2001 From: QuentinJGMace <95310069+QuentinJGMace@users.noreply.github.com> Date: Fri, 3 Oct 2025 09:20:25 +0200 Subject: [PATCH 33/42] Change default text model name in configuration --- colpali_engine/models/modernvbert/configuration_vbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/models/modernvbert/configuration_vbert.py b/colpali_engine/models/modernvbert/configuration_vbert.py index 659e1678..4f0a8daf 100644 --- a/colpali_engine/models/modernvbert/configuration_vbert.py +++ b/colpali_engine/models/modernvbert/configuration_vbert.py @@ -39,7 +39,7 @@ class VBertTextConfig(PretrainedConfig): def __init__( self, # Case for when vllama3 is from the hub with no vision_model_name - text_model_name="EuroBERT/EuroBERT-210m", + text_model_name="jhu-clsp/ettin-encoder-150m", **kwargs, ): self.text_model_name = text_model_name From 9ce2871560e252a19e2ca82033ab44981f17310f Mon Sep 17 00:00:00 2001 From: Paul Teiletche <73120933+paultltc@users.noreply.github.com> Date: Thu, 16 Oct 2025 14:23:43 -0700 Subject: [PATCH 34/42] fix: `ModernVBERT` modeling (#348) * modeling * update modeling * update token id default * init files * remove vllama + update torch lower bound for cpu * back to normal transformer bound * clean * Update colpali_engine/models/__init__.py --------- Co-authored-by: QuentinJGMace <95310069+QuentinJGMace@users.noreply.github.com> --- colpali_engine/__init__.py | 4 + colpali_engine/models/__init__.py | 5 +- colpali_engine/models/eurovbert/__init__.py | 2 - .../models/eurovbert/bivbert/__init__.py | 2 - .../eurovbert/bivbert/modeling_bieurovbert.py | 65 -- .../bivbert/processing_bieurovbert.py | 40 - .../models/eurovbert/colvbert/__init__.py | 2 - .../colvbert/modeling_coleurovbert.py | 51 - .../colvbert/processing_coleurovbert.py | 84 -- .../models/eurovbert/configuration_vbert.py | 210 ---- .../bivbert/modeling_bimodernvbert.py | 8 +- .../colvbert/modeling_colmodernvbert.py | 8 +- .../modernvbert/configuration_modernvbert.py | 273 +++++ .../models/modernvbert/configuration_vbert.py | 232 ----- .../modeling_modernvbert.py} | 333 ++----- .../models/modernvbert/modeling_vbert.py | 935 ------------------ colpali_engine/models/siglip/__init__.py | 2 - .../models/siglip/modeling_bisiglip.py | 50 - .../models/siglip/processing_bisiglip.py | 75 -- colpali_engine/models/vllama/__init__.py | 2 - .../models/vllama/bivllama/__init__.py | 2 - .../vllama/bivllama/modeling_bivllama.py | 64 -- .../vllama/bivllama/processing_bivllama.py | 51 - .../models/vllama/colvllama/__init__.py | 2 - .../vllama/colvllama/modeling_colvllama.py | 51 - .../vllama/colvllama/processing_colvllama.py | 93 -- .../models/vllama/configuration_vllama.py | 232 ----- .../models/vllama/modeling_vllama.py | 887 ----------------- pyproject.toml | 4 +- 29 files changed, 377 insertions(+), 3392 deletions(-) delete mode 100644 colpali_engine/models/eurovbert/__init__.py delete mode 100644 colpali_engine/models/eurovbert/bivbert/__init__.py delete mode 100644 colpali_engine/models/eurovbert/bivbert/modeling_bieurovbert.py delete mode 100644 colpali_engine/models/eurovbert/bivbert/processing_bieurovbert.py delete mode 100644 colpali_engine/models/eurovbert/colvbert/__init__.py delete mode 100644 colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py delete mode 100644 colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py delete mode 100644 colpali_engine/models/eurovbert/configuration_vbert.py create mode 100644 colpali_engine/models/modernvbert/configuration_modernvbert.py delete mode 100644 colpali_engine/models/modernvbert/configuration_vbert.py rename colpali_engine/models/{eurovbert/modeling_vbert.py => modernvbert/modeling_modernvbert.py} (64%) delete mode 100644 colpali_engine/models/modernvbert/modeling_vbert.py delete mode 100644 colpali_engine/models/siglip/__init__.py delete mode 100644 colpali_engine/models/siglip/modeling_bisiglip.py delete mode 100644 colpali_engine/models/siglip/processing_bisiglip.py delete mode 100644 colpali_engine/models/vllama/__init__.py delete mode 100644 colpali_engine/models/vllama/bivllama/__init__.py delete mode 100644 colpali_engine/models/vllama/bivllama/modeling_bivllama.py delete mode 100644 colpali_engine/models/vllama/bivllama/processing_bivllama.py delete mode 100644 colpali_engine/models/vllama/colvllama/__init__.py delete mode 100644 colpali_engine/models/vllama/colvllama/modeling_colvllama.py delete mode 100644 colpali_engine/models/vllama/colvllama/processing_colvllama.py delete mode 100644 colpali_engine/models/vllama/configuration_vllama.py delete mode 100644 colpali_engine/models/vllama/modeling_vllama.py diff --git a/colpali_engine/__init__.py b/colpali_engine/__init__.py index 3341df6b..6934e80f 100644 --- a/colpali_engine/__init__.py +++ b/colpali_engine/__init__.py @@ -5,6 +5,8 @@ BiQwen2_5, BiQwen2_5_Processor, BiQwen2Processor, + BiModernVBert, + BiModernVBertProcessor, ColIdefics3, ColIdefics3Processor, ColPali, @@ -15,4 +17,6 @@ # ColQwen2_5Omni, # ColQwen2_5OmniProcessor, ColQwen2Processor, + ColModernVBert, + ColModernVBertProcessor, ) diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index ef994eea..ae92178d 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -2,7 +2,4 @@ from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor -from .eurovbert import BiEuroVBert, BiEuroVBertProcessor, ColEuroVBert, ColEuroVBertProcessor -from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor -from .siglip import BiSiglip, BiSiglipProcessor -from .vllama import BiVLlama, BiVLlamaProcessor, ColVLlama, ColVLlamaProcessor +from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor \ No newline at end of file diff --git a/colpali_engine/models/eurovbert/__init__.py b/colpali_engine/models/eurovbert/__init__.py deleted file mode 100644 index dc492dc5..00000000 --- a/colpali_engine/models/eurovbert/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .bivbert import BiEuroVBert, BiEuroVBertProcessor -from .colvbert import ColEuroVBert, ColEuroVBertProcessor diff --git a/colpali_engine/models/eurovbert/bivbert/__init__.py b/colpali_engine/models/eurovbert/bivbert/__init__.py deleted file mode 100644 index 3d04309f..00000000 --- a/colpali_engine/models/eurovbert/bivbert/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_bieurovbert import BiEuroVBert -from .processing_bieurovbert import BiEuroVBertProcessor diff --git a/colpali_engine/models/eurovbert/bivbert/modeling_bieurovbert.py b/colpali_engine/models/eurovbert/bivbert/modeling_bieurovbert.py deleted file mode 100644 index 03e5dc1d..00000000 --- a/colpali_engine/models/eurovbert/bivbert/modeling_bieurovbert.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import torch - -from typing import Literal, Union - -from colpali_engine.models.eurovbert.modeling_vbert import VBertModel, VBertPreTrainedModel -from colpali_engine.models.eurovbert.configuration_vbert import VBertConfig - - -class BiEuroVBert(VBertPreTrainedModel): - """ - Initializes the BiIdefics3 model. - - Args: - config : The model configuration. - """ - supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - - def __init__(self, config, pooling_strategy = "mean", **kwargs): - super().__init__(config=config) - self.model = VBertModel(config, **kwargs) - self.pooling_strategy = pooling_strategy - self.post_init() - - def forward( - self, - pooling_strategy: Literal["cls", "last", "mean"] = None, - *args, - **kwargs, - ) -> torch.Tensor: - """ - Forward pass through model and pooling. - - Args: - - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean". - - input_ids (torch.LongTensor): The input tokens tensor. - - attention_mask (torch.LongTensor): The attention mask tensor. - - Returns: - - torch.Tensor: Embeddings of shape (batch_size, dim) - """ - outputs = self.model(*args, **kwargs) - last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) - - pooling_strategy = pooling_strategy or self.pooling_strategy - - # Get CLS token embedding, last token, or mean pool over sequence - if pooling_strategy == "cls": - # Use CLS token (first token) embedding - pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) - elif pooling_strategy == "last": - # Use last token - pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) - elif pooling_strategy == "mean": - # Mean pooling over sequence length - mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) - pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) - else: - raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") - - # L2 normalization - pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True) - return pooled_output diff --git a/colpali_engine/models/eurovbert/bivbert/processing_bieurovbert.py b/colpali_engine/models/eurovbert/bivbert/processing_bieurovbert.py deleted file mode 100644 index 7208b68f..00000000 --- a/colpali_engine/models/eurovbert/bivbert/processing_bieurovbert.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import List, Optional, Union - -import torch -from transformers import BatchEncoding, BatchFeature - -from colpali_engine.models.eurovbert.colvbert import ColEuroVBertProcessor - - -class BiEuroVBertProcessor(ColEuroVBertProcessor): # noqa: N801 - """ - Processor for BiVBert. - """ - - def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for BiModernVBert. - - Args: - texts: List of input texts. - - Returns: - Union[BatchFeature, BatchEncoding]: Processed texts. - """ - return self( - text=texts, - return_tensors="pt", - padding="longest", - ) - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the cosine similarity for the given query and passage embeddings. - """ - return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/eurovbert/colvbert/__init__.py b/colpali_engine/models/eurovbert/colvbert/__init__.py deleted file mode 100644 index 4e0b32a9..00000000 --- a/colpali_engine/models/eurovbert/colvbert/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_coleurovbert import ColEuroVBert -from .processing_coleurovbert import ColEuroVBertProcessor diff --git a/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py b/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py deleted file mode 100644 index 9cfb7709..00000000 --- a/colpali_engine/models/eurovbert/colvbert/modeling_coleurovbert.py +++ /dev/null @@ -1,51 +0,0 @@ -from torch import nn - -from colpali_engine.models.eurovbert.modeling_vbert import VBertModel, VBertPreTrainedModel - - -class ColEuroVBert(VBertPreTrainedModel): - """ - Initializes the ColVBert model. - - Args: - config : The model configuration. - mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings - except those of the image at inference. - Defaults to False --> Do not mask any embeddings during forward pass. - """ - supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): - super().__init__(config=config) - self.model = VBertModel(config, **kwargs) - self.dim = 128 - self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) - self.mask_non_image_embeddings = mask_non_image_embeddings - self.main_input_name = "doc_input_ids" - - def forward(self, *args, **kwargs): - """ - Forward pass through the model and the linear layer for dimensionality reduction - - Args: - - input_ids (torch.LongTensor): The input tokens tensor. - - attention_mask (torch.LongTensor): The attention mask tensor. - - Returns: - - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) - """ - outputs = self.model(*args, **kwargs) - last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) - proj = self.custom_text_proj(last_hidden_states) - # normalize l2 norm - proj = proj / proj.norm(dim=-1, keepdim=True) - proj = proj * kwargs["attention_mask"].unsqueeze(-1) - - if "pixel_values" in kwargs and self.mask_non_image_embeddings: - # Pools only the image embeddings - image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) - proj = proj * image_mask - return proj diff --git a/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py b/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py deleted file mode 100644 index 5bd476f1..00000000 --- a/colpali_engine/models/eurovbert/colvbert/processing_coleurovbert.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import ClassVar, List, Optional, Tuple, Union - -import torch -from PIL import Image -from transformers import BatchEncoding, BatchFeature, Idefics3Processor - -from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor - - -class ColEuroVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): - """ - Processor for ColIdefics3. - """ - - query_augmentation_token: ClassVar[str] = "" - image_token: ClassVar[str] = "" - visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" - - def __init__(self, *args, image_seq_len=64, **kwargs): - super().__init__(*args, image_seq_len=image_seq_len, **kwargs) - self.tokenizer.padding_side = "left" - - # @property - # def image_token_id(self) -> int: - # return self.tokenizer.convert_tokens_to_ids(self.image_token) - - def process_images( - self, - images: List[Image.Image], - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process images for ColEuroVBert. - - Args: - images: List of PIL images. - """ - images = [image.convert("RGB") for image in images] - - batch_doc = self( - text=[self.visual_prompt_prefix] * len(images), - images=images, - padding="longest", - return_tensors="pt", - truncation=True, - max_length=8192, - ) - return batch_doc - - def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for ColEuroVBert. - - Args: - texts: List of input texts. - - Returns: - Union[BatchFeature, BatchEncoding]: Processed texts. - """ - return self( - text=texts, - return_tensors="pt", - padding="longest", - truncation=True, - max_length=4096, - ) - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. - """ - return self.score_multi_vector(qs, ps, device=device, **kwargs) - - def get_n_patches( - self, - image_size: Tuple[int, int], - patch_size: int, - ) -> Tuple[int, int]: - raise NotImplementedError("This method is not implemented for ColIdefics3.") diff --git a/colpali_engine/models/eurovbert/configuration_vbert.py b/colpali_engine/models/eurovbert/configuration_vbert.py deleted file mode 100644 index f67eb54f..00000000 --- a/colpali_engine/models/eurovbert/configuration_vbert.py +++ /dev/null @@ -1,210 +0,0 @@ -import copy -import os -from typing import Any, Dict, Union - -from transformers import AutoConfig -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -def collect_arg_in_candidates(config, candidates, default = None) -> Any: - """ Gets the argument in a config given a list of candidates """ - for c in candidates: - if hasattr(config, c): - return getattr(config, c) - elif c in config: - return config[c] - if default is not None: - return default - raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config)) - -class VBertTextConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - embed_dim (`int`, *optional*, defaults to 1152): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) - image_size (`int`, *optional*, defaults to 384): - The size (resolution) of each image. - """ - model_type = "vbert" - - def __init__( - self, - # Case for when vllama3 is from the hub with no vision_model_name - text_model_name="EuroBERT/EuroBERT-210m", - **kwargs, - ): - self.text_model_name = text_model_name - text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) - if hasattr(text_config, "text_config"): - text_config = text_config.text_config - - self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) - self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) - self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) - self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False) - self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) - - super().__init__(text_model_name=text_model_name, **kwargs) - -class VBertVisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - embed_dim (`int`, *optional*, defaults to 1152): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) - image_size (`int`, *optional*, defaults to 384): - The size (resolution) of each image. - """ - model_type = "vbert" - attribute_map = { - "hidden_size": "embed_dim", - } - - def __init__( - self, - # Case for when vllama3 is from the hub with no vision_model_name - vision_model_name="google/siglip2-base-patch16-512", - **kwargs, - ): - self.vision_model_name = vision_model_name - vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) - if hasattr(vision_config, "vision_config"): - vision_config = vision_config.vision_config - - self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) - self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) - self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) - self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) - self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) - - super().__init__(vision_model_name=vision_model_name, **kwargs) - -class VBertConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a - SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM - [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should cache the key/value pairs of the attention mechanism. Only - relevant if `config.is_decoder=True`. - image_token_id (`int`, *optional*, defaults to 128257): - The id of the "image" token. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether or not to tie the word embeddings with the token embeddings. - vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): - Custom vision config or dict for the vision tower - text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): - Custom text config or dict for the text model - scale_factor (`int`, *optional*, defaults to 2): - The scale factor for the image encoder. - pad_token_id (`int`, *optional*, defaults to 128002): - The id of the padding token. - - Example: - ```python - >>> from transformers import SmolVLMModel, SmolVLMConfig - >>> # Initializing configuration - >>> configuration = SmolVLMConfig() - >>> # Initializing a model from the configuration - >>> model = SmolVLMModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "vbert" - is_composition = True - # sub_configs = {"text_config": VBertTextConfig, "vision_config": VBertVisionConfig} - - DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m" - DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" - - def __init__( - self, - text_config: Union[PretrainedConfig, Dict[str, Any]] = None, - vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, - image_token_id: int = 128_257, - vocab_size=128_256, - use_cache = True, - tie_word_embeddings = False, - freeze_config = None, - pad_token_id = None, - initializer_range = 0.02, - pixel_shuffle_factor = 4, - use_resampler = False, - additional_vocab_size = 0, - neftune_noise_alpha = 0.0, - **kwargs, - ): - self.image_token_id = image_token_id - self.use_cache = use_cache - self.tie_word_embeddings = tie_word_embeddings - self.scale_factor = pixel_shuffle_factor - self.additional_vocab_size = additional_vocab_size - - if text_config is None: - text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) - elif isinstance(text_config, dict): - text_config = VBertTextConfig(text_config["text_model_name"]) - self.text_config = text_config - - if vision_config is None: - vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) - elif isinstance(vision_config, dict): - vision_config = VBertVisionConfig(vision_config["vision_model_name"]) - self.vision_config = vision_config - - self.freeze_config = freeze_config - - # Pixel shuffle factor - self.pixel_shuffle_factor = pixel_shuffle_factor - self.use_resampler = use_resampler - - self.neftune_noise_alpha = neftune_noise_alpha - - self.initializer_range = initializer_range - - hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) - - super().__init__( - **kwargs, - pad_token_id=pad_token_id, - tie_word_embeddings=tie_word_embeddings, - vocab_size=vocab_size, - hidden_size=hidden_size, - ) - - def to_dict(self): - """ - Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. - Returns: - `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, - """ - output = copy.deepcopy(self.__dict__) - - output["model_type"] = self.__class__.model_type - output["vision_config"] = self.vision_config.to_dict() - output["text_config"] = self.text_config.to_dict() - # output["freeze_config"] = self.freeze_config.to_dict() - - return output \ No newline at end of file diff --git a/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py b/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py index 30a6f86f..cf60cb1c 100644 --- a/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py +++ b/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py @@ -2,12 +2,12 @@ import torch -from colpali_engine.models.modernvbert.modeling_vbert import VBertModel, VBertPreTrainedModel +from colpali_engine.models.modernvbert.modeling_modernvbert import ModernVBertModel, ModernVBertPreTrainedModel -class BiModernVBert(VBertPreTrainedModel): +class BiModernVBert(ModernVBertPreTrainedModel): """ - Initializes the BiIdefics3 model. + Initializes the BiModernVBert model. Args: config : The model configuration. @@ -19,7 +19,7 @@ class BiModernVBert(VBertPreTrainedModel): def __init__(self, config, pooling_strategy = "mean", **kwargs): super().__init__(config=config) - self.model = VBertModel(config, **kwargs) + self.model = ModernVBertModel(config, **kwargs) self.pooling_strategy = pooling_strategy self.post_init() diff --git a/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py index 89c6d5be..7db8bc9b 100644 --- a/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py +++ b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py @@ -1,11 +1,11 @@ from torch import nn -from colpali_engine.models.modernvbert.modeling_vbert import VBertModel, VBertPreTrainedModel +from colpali_engine.models.modernvbert.modeling_modernvbert import ModernVBertModel, ModernVBertPreTrainedModel -class ColModernVBert(VBertPreTrainedModel): +class ColModernVBert(ModernVBertPreTrainedModel): """ - Initializes the ColVBert model. + Initializes the ColModernVBert model. Args: config : The model configuration. @@ -20,7 +20,7 @@ class ColModernVBert(VBertPreTrainedModel): def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): super().__init__(config=config) - self.model = VBertModel(config, **kwargs) + self.model = ModernVBertModel(config, **kwargs) self.dim = 128 self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) self.mask_non_image_embeddings = mask_non_image_embeddings diff --git a/colpali_engine/models/modernvbert/configuration_modernvbert.py b/colpali_engine/models/modernvbert/configuration_modernvbert.py new file mode 100644 index 00000000..d5225f8b --- /dev/null +++ b/colpali_engine/models/modernvbert/configuration_modernvbert.py @@ -0,0 +1,273 @@ +import copy +import os +from typing import Any, Dict, Union + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m" +DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" + +def collect_arg_in_candidates(config, candidates, default=None) -> Any: + """Gets the first available argument in a config given a list of candidate names.""" + for c in candidates: + if hasattr(config, c): + return getattr(config, c) + elif c in config: + return config[c] + if default is not None: + return default + raise ValueError( + f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}" + ) + +class ModernVBertTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBERT`]. It is used to instantiate an ModernBERT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + """ + model_type = "modernvbert_text" + + def __init__( + self, + text_model_name=DEFAULT_TEXT_MODEL_NAME, + hidden_size=768, + num_hidden_layers=22, + intermediate_size=1152, + mlp_bias=False, + vocab_size=50368, + **kwargs, + ): + super().__init__( + text_model_name=text_model_name, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + mlp_bias=mlp_bias, + vocab_size=vocab_size, + **kwargs, + ) + + @classmethod + def from_base_model( + cls, + text_model_name=DEFAULT_TEXT_MODEL_NAME, + **kwargs, + ): + text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + if hasattr(text_config, "text_config"): + text_config = text_config.text_config + + hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) + num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) + intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) + mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default=False) + vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) + + return cls( + text_model_name=text_model_name, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + mlp_bias=mlp_bias, + vocab_size=vocab_size, + **kwargs, + ) + +class ModernVBertVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SigLIP`]. It is used to instantiate the vision encoder part of the ModernVBERT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the SigLIP. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + """ + model_type = "modernvbert_vision" + + attribute_map = { + "hidden_size": "embed_dim", + } + + def __init__( + self, + vision_model_name=DEFAULT_VISION_MODEL_NAME, + embed_dim=768, + image_size=512, + patch_size=16, + num_hidden_layers=12, + intermediate_size=3072, + **kwargs, + ): + super().__init__( + vision_model_name=vision_model_name, + embed_dim=embed_dim, + image_size=image_size, + patch_size=patch_size, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + **kwargs, + ) + + @classmethod + def from_base_model( + cls, + vision_model_name=DEFAULT_VISION_MODEL_NAME, + **kwargs, + ): + vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + if hasattr(vision_config, "vision_config"): + vision_config = vision_config.vision_config + + embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) + image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) + patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) + num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) + intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) + + return cls( + vision_model_name=vision_model_name, + embed_dim=embed_dim, + image_size=image_size, + patch_size=patch_size, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + **kwargs, + ) + + +class ModernVBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a `ModernVBert` model. It is used to + instantiate a ModernVBert model according to the specified arguments and defines the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. + See the documentation for [`PretrainedConfig`] for more details. + + Args: + text_config (`PretrainedConfig` or `dict`, optional): + Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the + default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used. + vision_config (`PretrainedConfig` or `dict`, optional): + Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the + default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used. + image_token_id (`int`, optional, defaults to 128257): + Token id reserved for image tokens inserted into the text stream. + vocab_size (`int`, optional, defaults to 128256): + Vocabulary size used by the text embeddings. + use_cache (`bool`, optional, defaults to `True`): + Whether to cache key/value tensors for attention (relevant for decoder architectures). + tie_word_embeddings (`bool`, optional, defaults to `False`): + Whether to tie input token embeddings and output token embeddings. + pixel_shuffle_factor (`int`, optional, defaults to 4): + Scale factor used by any pixel-shuffle / upsampling operations in the vision head. + additional_vocab_size (`int`, optional, defaults to 0): + Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens). + pad_token_id (`int`, optional): + Padding token id. + initializer_range (`float`, optional, defaults to 0.02): + Stddev used for weight initialization. + freeze_config (`Any`, optional): + Optional config describing which submodules to freeze during training. + use_resampler (`bool`, optional, defaults to `False`): + Whether to enable an additional resampler on visual features. + neftune_noise_alpha (`float`, optional, defaults to 0.0): + Alpha parameter for neftune noise injection. + + Example: + ```python + >>> from modernvbert import ModernVBertConfig + >>> # Initializing configuration + >>> configuration = ModernVBertConfig() + >>> # Initializing a model from the configuration (model class is implemented in + >>> # `modernvbert.modeling_modernvbert`) + >>> # from modernvbert import ModernVBertModel + >>> # model = ModernVBertModel(configuration) + >>> # Accessing the model configuration + >>> # cfg = model.config + ```""" + + model_type = "modernvbert" + is_composition = True + + def __init__( + self, + text_config: Union[PretrainedConfig, Dict[str, Any]] = None, + vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, + image_token_id: int = 50407, + vocab_size=50368, + use_cache=True, + tie_word_embeddings=False, + freeze_config=None, + pad_token_id=None, + initializer_range=0.02, + pixel_shuffle_factor=4, + use_resampler=False, + additional_vocab_size=0, + neftune_noise_alpha=0.0, + **kwargs, + ): + self.image_token_id = image_token_id + self.use_cache = use_cache + self.tie_word_embeddings = tie_word_embeddings + self.scale_factor = pixel_shuffle_factor + self.additional_vocab_size = additional_vocab_size + + if text_config is None: + base_text_config = AutoConfig.from_pretrained(DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) + text_config = ModernVBertTextConfig(base_text_config) + elif isinstance(text_config, dict): + text_config = ModernVBertTextConfig.from_dict(text_config) + self.text_config = text_config + + if vision_config is None: + base_vision_config = AutoConfig.from_pretrained(DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) + vision_config = ModernVBertVisionConfig(base_vision_config) + elif isinstance(vision_config, dict): + vision_config = ModernVBertVisionConfig.from_dict(vision_config) + self.vision_config = vision_config + + self.freeze_config = freeze_config + self.pixel_shuffle_factor = pixel_shuffle_factor + self.use_resampler = use_resampler + self.neftune_noise_alpha = neftune_noise_alpha + self.initializer_range = initializer_range + + hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) + + super().__init__( + **kwargs, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + vocab_size=vocab_size, + hidden_size=hidden_size, + ) + + def to_dict(self): + output = copy.deepcopy(self.__dict__) + output["model_type"] = self.__class__.model_type + output["vision_config"] = self.vision_config.to_dict() + output["text_config"] = self.text_config.to_dict() + return output + + @classmethod + def from_pretrained_models( + cls, + text_model_name: Union[str, os.PathLike], + vision_model_name: Union[str, os.PathLike], + **kwargs, + ) -> "PretrainedConfig": + text_model_config = ModernVBertTextConfig.from_base_model(text_model_name) + vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name) + return cls( + text_config=text_model_config, + vision_config=vision_model_config, + **kwargs, + ) \ No newline at end of file diff --git a/colpali_engine/models/modernvbert/configuration_vbert.py b/colpali_engine/models/modernvbert/configuration_vbert.py deleted file mode 100644 index 4f0a8daf..00000000 --- a/colpali_engine/models/modernvbert/configuration_vbert.py +++ /dev/null @@ -1,232 +0,0 @@ -import copy -import os -from typing import Any, Dict, Union - -from transformers import AutoConfig -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -def collect_arg_in_candidates(config, candidates, default = None) -> Any: - """ Gets the argument in a config given a list of candidates """ - for c in candidates: - if hasattr(config, c): - return getattr(config, c) - elif c in config: - return config[c] - if default is not None: - return default - raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config)) - -class VBertTextConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - embed_dim (`int`, *optional*, defaults to 1152): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) - image_size (`int`, *optional*, defaults to 384): - The size (resolution) of each image. - """ - model_type = "vbert" - - def __init__( - self, - # Case for when vllama3 is from the hub with no vision_model_name - text_model_name="jhu-clsp/ettin-encoder-150m", - **kwargs, - ): - self.text_model_name = text_model_name - text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) - if hasattr(text_config, "text_config"): - text_config = text_config.text_config - - self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) - self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) - self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) - self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False) - self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) - - super().__init__(text_model_name=text_model_name, **kwargs) - -class VBertVisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - embed_dim (`int`, *optional*, defaults to 1152): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) - image_size (`int`, *optional*, defaults to 384): - The size (resolution) of each image. - """ - model_type = "vbert" - attribute_map = { - "hidden_size": "embed_dim", - } - - def __init__( - self, - # Case for when vllama3 is from the hub with no vision_model_name - vision_model_name="google/siglip2-base-patch16-512", - **kwargs, - ): - self.vision_model_name = vision_model_name - vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) - if hasattr(vision_config, "vision_config"): - vision_config = vision_config.vision_config - - self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) - self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) - self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) - self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) - self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) - - super().__init__(vision_model_name=vision_model_name, **kwargs) - -class VBertConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a - SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM - [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should cache the key/value pairs of the attention mechanism. Only - relevant if `config.is_decoder=True`. - image_token_id (`int`, *optional*, defaults to 128257): - The id of the "image" token. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether or not to tie the word embeddings with the token embeddings. - vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): - Custom vision config or dict for the vision tower - text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): - Custom text config or dict for the text model - scale_factor (`int`, *optional*, defaults to 2): - The scale factor for the image encoder. - pad_token_id (`int`, *optional*, defaults to 128002): - The id of the padding token. - - Example: - ```python - >>> from transformers import SmolVLMModel, SmolVLMConfig - >>> # Initializing configuration - >>> configuration = SmolVLMConfig() - >>> # Initializing a model from the configuration - >>> model = SmolVLMModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "vbert" - is_composition = True - # sub_configs = {"text_config": VBertTextConfig, "vision_config": VBertVisionConfig} - - DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m" - DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" - - def __init__( - self, - text_config: Union[PretrainedConfig, Dict[str, Any]] = None, - vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, - image_token_id: int = 128_257, - vocab_size=128_256, - use_cache = True, - tie_word_embeddings = False, - freeze_config = None, - pad_token_id = None, - initializer_range = 0.02, - pixel_shuffle_factor = 4, - use_resampler = False, - additional_vocab_size = 0, - neftune_noise_alpha = 0.0, - **kwargs, - ): - self.image_token_id = image_token_id - self.use_cache = use_cache - self.tie_word_embeddings = tie_word_embeddings - self.scale_factor = pixel_shuffle_factor - self.additional_vocab_size = additional_vocab_size - - if text_config is None: - text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) - elif isinstance(text_config, dict): - text_config = VBertTextConfig(text_config["text_model_name"]) - self.text_config = text_config - - if vision_config is None: - vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) - elif isinstance(vision_config, dict): - vision_config = VBertVisionConfig(vision_config["vision_model_name"]) - self.vision_config = vision_config - - self.freeze_config = freeze_config - - # Pixel shuffle factor - self.pixel_shuffle_factor = pixel_shuffle_factor - self.use_resampler = use_resampler - - self.neftune_noise_alpha = neftune_noise_alpha - - self.initializer_range = initializer_range - - hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) - - super().__init__( - **kwargs, - pad_token_id=pad_token_id, - tie_word_embeddings=tie_word_embeddings, - vocab_size=vocab_size, - hidden_size=hidden_size, - ) - - def to_dict(self): - """ - Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. - Returns: - `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, - """ - output = copy.deepcopy(self.__dict__) - - output["model_type"] = self.__class__.model_type - output["vision_config"] = self.vision_config.to_dict() - output["text_config"] = self.text_config.to_dict() - # output["freeze_config"] = self.freeze_config.to_dict() - - return output - - # @classmethod - # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - # outputs = super(VBertConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs) - # return outputs - - @classmethod - def from_pretrained_models( - cls, - text_model_name: Union[str, os.PathLike], - vision_model_name: Union[str, os.PathLike], - **kwargs - ) -> "PretrainedConfig": - # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) - # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) - text_model_config = VBertTextConfig(text_model_name) - vision_model_config = VBertVisionConfig(vision_model_name) - return cls( - text_config=text_model_config, - vision_config=vision_model_config, - **kwargs - ) diff --git a/colpali_engine/models/eurovbert/modeling_vbert.py b/colpali_engine/models/modernvbert/modeling_modernvbert.py similarity index 64% rename from colpali_engine/models/eurovbert/modeling_vbert.py rename to colpali_engine/models/modernvbert/modeling_modernvbert.py index 5a8ca5e1..94736d17 100644 --- a/colpali_engine/models/eurovbert/modeling_vbert.py +++ b/colpali_engine/models/modernvbert/modeling_modernvbert.py @@ -4,14 +4,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.checkpoint from torch.nn import CrossEntropyLoss from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging -from transformers.cache_utils import DynamicCache from transformers.modeling_outputs import BaseModelOutput from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput -from .configuration_vbert import VBertConfig +from .configuration_modernvbert import ModernVBertConfig logger = logging.get_logger(__name__) @@ -43,6 +41,7 @@ def __init__( """ if padding_idx is not None and padding_idx > num_embeddings: raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") + super().__init__( num_embeddings=num_embeddings, embedding_dim=embedding_dim, @@ -52,7 +51,6 @@ def __init__( **kwargs, ) self.num_embeddings = num_embeddings - self.padding_idx = padding_idx self.num_additional_embeddings = num_additional_embeddings self.partially_freeze = partially_freeze @@ -61,7 +59,7 @@ def __init__( if self.num_additional_embeddings > 0: self.additional_embedding = nn.Embedding( - num_embeddings=self.num_additional_embeddings, + num_embeddings=num_additional_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype, @@ -89,9 +87,8 @@ def forward(self, input_ids): """ if self.num_additional_embeddings == 0: - return self.additional_embedding(input_ids) + return super().forward(input_ids) - # Clone so that we don't modify the original input_ids later on input_ids = input_ids.clone() additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) input_ids_additional_vocab = input_ids[additional_vocab_indices] @@ -100,37 +97,19 @@ def forward(self, input_ids): # for successful lookup replace input_ids with 0, the results of these will be discarded anyway input_ids[additional_vocab_indices] = 0 full_vector = F.embedding(input_ids, self.weight) - - # overwrite the records with high indices - full_vector[additional_vocab_indices] = additional_embeddings - + full_vector[additional_vocab_indices] = additional_embeddings # overwrite the records with high indices return full_vector - def extra_repr(self) -> str: - return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( - self.num_embeddings, - self.num_additional_embeddings, - self.embedding_dim, - self.partially_freeze, - ) @dataclass -class VBertBaseModelOutput(BaseModelOutput): +class ModernVBertBaseModelOutput(BaseModelOutput): """ - Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential decoding). Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. @@ -145,16 +124,16 @@ class VBertBaseModelOutput(BaseModelOutput): sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder """ - last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + @dataclass -class VBertMaskedLMOutput(MaskedLMOutput): +class ModernVBertMaskedLMOutput(MaskedLMOutput): """ - Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential decoding). Args: loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): Masked language modeling (MLM) loss. @@ -180,7 +159,9 @@ class VBertMaskedLMOutput(MaskedLMOutput): attentions: Optional[Tuple[torch.FloatTensor, ...]] = None image_hidden_states: Optional[torch.FloatTensor] = None -class VBertSimpleMLP(nn.Module): + +class ModernVBertSimpleMLP(nn.Module): + """A simple linear projection layer to project the vision hidden states to the text hidden states.""" def __init__(self, input_size, output_size): super().__init__() self.proj = nn.Linear(input_size, output_size, bias=False) @@ -188,13 +169,18 @@ def __init__(self, input_size, output_size): def forward(self, x): return self.proj(x) -class VBertConnector(nn.Module): + +class ModernVBertConnector(nn.Module): + """ + Connector module for ModernVBERT. It performs a pixel shuffle operation followed by a linear projection to match the text model's hidden size. + Based on https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html + """ def __init__(self, config): super().__init__() self.scale_factor = config.pixel_shuffle_factor - self.modality_projection = VBertSimpleMLP( + self.modality_projection = ModernVBertSimpleMLP( input_size=config.vision_config.hidden_size * (config.scale_factor**2), - output_size=config.text_config.hidden_size + output_size=config.text_config.hidden_size, ) def pixel_shuffle(self, x, scale_factor): @@ -205,36 +191,25 @@ def pixel_shuffle(self, x, scale_factor): x = x.permute(0, 2, 1, 3) x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) - return x + return x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) def forward(self, image_hidden_states): image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) - image_hidden_states = self.modality_projection(image_hidden_states) - return image_hidden_states + return self.modality_projection(image_hidden_states) + -class VBertPreTrainedModel(PreTrainedModel): - config_class = VBertConfig +class ModernVBertPreTrainedModel(PreTrainedModel): + config_class = ModernVBertConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["VBertDecoderLayer"] + _no_split_modules = ["ModernVBertDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True def _init_weights(self, module): - """Initialize the weights.""" - - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - + std = getattr(self.config, "initializer_range", 0.02) if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: @@ -244,61 +219,41 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() -class VBertModel(VBertPreTrainedModel): - """ - A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger - in forward. Instead, we override inputs_merger here with custom logic. - """ - def __init__(self, config: VBertConfig, **kwargs): +class ModernVBertModel(ModernVBertPreTrainedModel): + def __init__(self, config: ModernVBertConfig, **kwargs): super().__init__(config) - - self.vision_model = VBertModel.init_vision_model(config, **kwargs) - self.connector = VBertConnector(config) - self.text_model = VBertModel.init_language_model(config, **kwargs) - + self.vision_model = ModernVBertModel.init_vision_model(config, **kwargs) + self.connector = ModernVBertConnector(config) + self.text_model = ModernVBertModel.init_language_model(config, **kwargs) self.image_seq_len = int( ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) ) - self.image_token_id = self.config.image_token_id - + self.image_token_id = config.image_token_id self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self.post_init() @staticmethod - def init_vision_model(config: VBertConfig, **kwargs): + def init_vision_model(config: ModernVBertConfig, **kwargs): vision_model_config = AutoConfig.from_pretrained( config.vision_config.vision_model_name, _attn_implementation=config._attn_implementation, - torch_dtype=config.torch_dtype, + dtype=config.torch_dtype, **kwargs, ) - - vision_model = AutoModel.from_config(vision_model_config,**kwargs) - - if hasattr(vision_model, "vision_model"): - # If the model has a vision_model attribute, it means it's a wrapper around another model - vision_model = vision_model.vision_model - - return vision_model + vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) + return getattr(vision_model, "vision_model", vision_model) @staticmethod - def init_language_model(config: VBertConfig, **kwargs): + def init_language_model(config: ModernVBertConfig, **kwargs): text_model_config = AutoConfig.from_pretrained( config.text_config.text_model_name, _attn_implementation=config._attn_implementation, - torch_dtype=config.torch_dtype, + dtype=config.torch_dtype, trust_remote_code=True, **kwargs, ) - - text_model = AutoModel.from_config( - text_model_config, - trust_remote_code=True, - **kwargs - ) - + text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) embed_layer = DecoupledEmbedding( num_embeddings=text_model_config.vocab_size, num_additional_embeddings=config.additional_vocab_size, @@ -306,11 +261,9 @@ def init_language_model(config: VBertConfig, **kwargs): partially_freeze=config.freeze_config["freeze_text_layers"], padding_idx=config.pad_token_id, ) - text_model.set_input_embeddings(embed_layer) - return text_model - + def enable_input_require_grads(self): """ Enables the gradients for the input embeddings. @@ -337,20 +290,15 @@ def make_inputs_require_grads(module, input, output): make_inputs_require_grads ) - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - def get_input_embeddings(self): return self.text_model.get_input_embeddings() def set_input_embeddings(self, value): self.text_model.set_input_embeddings(value) - def inputs_merger( - self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor - ): - """ + def inputs_merger(self, input_ids, inputs_embeds, image_hidden_states): + """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py + This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. The merging happens as follows: - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. @@ -359,34 +307,28 @@ def inputs_merger( - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. """ - _, patch_size, _ = image_hidden_states.shape + _, patch_size, _ = image_hidden_states.shape image_mask = input_ids == self.image_token_id num_image_tokens = image_mask.sum(dim=1) if not torch.all(num_image_tokens % patch_size == 0): - raise ValueError("At least one sample has tokens not divisible by patch_size.") - + raise ValueError("Number of tokens not divisible by patch_size.") blocks_per_sample = num_image_tokens // patch_size - offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) block_offset = offsets[:-1] row_cum = image_mask.cumsum(dim=-1) chunk_idx = (row_cum - 1) // patch_size local_idx = (row_cum - 1) % patch_size block_idx = block_offset.unsqueeze(1) + chunk_idx - image_embeds = torch.zeros_like(inputs_embeds) image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] - - merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) - return merged_embeds + return torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, @@ -400,76 +342,22 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs_embeds is None: inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) - - # START VISUAL INPUTS INTEGRATION - if pixel_values is not None and image_hidden_states is not None: - raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") - elif pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values + if pixel_values is not None: + batch_size, num_images, _, _, _ = pixel_values.shape pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. nb_values_per_image = pixel_values.shape[1:].numel() real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - if not any(real_images_inds): - # no images, leave one empty image. real_images_inds[0] = True - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=[pixel_values.shape[i] for i in (0, 2, 3)], - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - - # patch_size = self.config.vision_config.patch_size - # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - # patch_attention_mask=patch_attention_mask, - ).last_hidden_state - - # Modality projection & resampling + image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state image_hidden_states = self.connector(image_hidden_states) - elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - if inputs_embeds is not None and image_hidden_states is not None: - # When we embed, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self.inputs_merger( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - image_hidden_states=image_hidden_states, - ) - + inputs_embeds = self.inputs_merger(input_ids, inputs_embeds, image_hidden_states) outputs = self.text_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, @@ -478,99 +366,64 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - if not return_dict: return tuple(v for v in [*outputs, image_hidden_states] if v is not None) - - return VBertBaseModelOutput( + return ModernVBertBaseModelOutput( last_hidden_state=outputs.last_hidden_state, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_hidden_states, ) -class VBertForMaskedLM(VBertPreTrainedModel): - # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] +class ModernVBertLMHead(nn.Module): + def __init__(self, config, **kwargs): + super().__init__() + pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True, **kwargs) + pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs) + self.head = pretrained_model.head + self.decoder = pretrained_model.decoder + + def forward(self, hidden_states): + return self.decoder(self.head(hidden_states)) + +class ModernVBertForMaskedLM(ModernVBertPreTrainedModel): def __init__(self, config, **kwargs): super().__init__(config) - self.image_token_id = config.image_token_id self.in_features = config.hidden_size self.out_additional_features = config.additional_vocab_size self.vocab_size = config.vocab_size - - if config.is_decoder: - logger.warning( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.model = VBertModel(config, **kwargs) - self.lm_head = VBertForMaskedLM.init_lm_head(config, **kwargs) + self.model = ModernVBertModel(config, **kwargs) + self.lm_head = ModernVBertLMHead(config, **kwargs) if self.out_additional_features > 0: - self.additional_fc = nn.Linear( - in_features=self.in_features, - out_features=self.out_additional_features, - bias=False, - ) - - # Initialize weights and apply final processing + self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False) self.post_init() - @staticmethod - def init_lm_head(config, **kwargs): - # Get the pretrained model config - text_model_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) - # Get the lm head - lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None - if lm_head is None: - logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") - lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) - return lm_head - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, VBertMaskedLMOutput]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). - Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only - computed for the tokens with labels in `[0, ..., config.vocab_size]`. - ```""" + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, ModernVBertMaskedLMOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Pass the inputs to VBertModel outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - past_key_values=past_key_values, inputs_embeds=inputs_embeds, pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask, @@ -579,29 +432,21 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - - # Pass the outputs to the MLM head hidden_states = outputs[0] - logits = self.lm_head(hidden_states) if self.out_additional_features > 0: - additional_features = self.additional_fc(hidden_states) + proj_states = self.lm_head.head(hidden_states) + additional_features = self.additional_fc(proj_states) logits = torch.cat((logits, additional_features), -1) - logits = logits.float() - - masked_lm_loss = None + loss = None if labels is not None: - # print the ratio of not ignored tokens - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) - + loss = CrossEntropyLoss()(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return VBertMaskedLMOutput( - loss=masked_lm_loss, - logits=logits, + return ((loss,) + output) if loss is not None else output + return ModernVBertMaskedLMOutput( + loss=loss, + logits=logits.float(), hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, diff --git a/colpali_engine/models/modernvbert/modeling_vbert.py b/colpali_engine/models/modernvbert/modeling_vbert.py deleted file mode 100644 index fe7903e0..00000000 --- a/colpali_engine/models/modernvbert/modeling_vbert.py +++ /dev/null @@ -1,935 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss -from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging -from transformers.cache_utils import DynamicCache -from transformers.modeling_outputs import BaseModelOutput -from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput - -from .configuration_vbert import VBertConfig - -logger = logging.get_logger(__name__) - -torch.set_float32_matmul_precision('high') - -class DecoupledEmbedding(nn.Embedding): - # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding - """ - Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. - In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained. - If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. - """ - - def __init__( - self, - num_embeddings, - num_additional_embeddings, - embedding_dim, - partially_freeze=False, - device=None, - dtype=None, - padding_idx=None, - **kwargs, - ) -> None: - """ - num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`. - partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen. - - Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these. - """ - if padding_idx is not None and padding_idx > num_embeddings: - raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") - super().__init__( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - device=device, - dtype=dtype, - padding_idx=padding_idx, - **kwargs, - ) - self.num_embeddings = num_embeddings - self.padding_idx = padding_idx - self.num_additional_embeddings = num_additional_embeddings - self.partially_freeze = partially_freeze - - if partially_freeze: - self.weight.requires_grad_(False) - - if self.num_additional_embeddings > 0: - self.additional_embedding = nn.Embedding( - num_embeddings=self.num_additional_embeddings, - embedding_dim=embedding_dim, - device=device, - dtype=dtype, - ) - - def forward(self, input_ids): - """ - we have 2 embeddings, with different indices - one pretrained self.weight and another - self.additional_embedding.weight that is being trained. - - in order to make a lookup of the input ids, we: - 1. find out the indices of the entries belonging to the 2nd embedding - 2. extract those values while subtracting the size of the first embedding (num_embeddings), - since the 2nd embedding starts from 0 and not num_embeddings - 3. perform the 2nd embedding lookup - 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index - 5. perform the 1st embedding lookup - 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup - - note: for the 1st embedding lookup we could have looked up only the low indices and not do - the padding, but then we have to create a new tensor and populate it with 2 tensors that are - spread out across various indices - i.e. not a simple concat - I haven't benchmarked the - complex case if it's any faster, given that seqlens are usually relatively short it's - probably not faster or if faster not by much - but might be a good idea to measure. - - """ - if self.num_additional_embeddings == 0: - return self.additional_embedding(input_ids) - - # Clone so that we don't modify the original input_ids later on - input_ids = input_ids.clone() - additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) - input_ids_additional_vocab = input_ids[additional_vocab_indices] - additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) - - # for successful lookup replace input_ids with 0, the results of these will be discarded anyway - input_ids[additional_vocab_indices] = 0 - full_vector = F.embedding(input_ids, self.weight) - - # overwrite the records with high indices - full_vector[additional_vocab_indices] = additional_embeddings - - return full_vector - - def extra_repr(self) -> str: - return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( - self.num_embeddings, - self.num_additional_embeddings, - self.embedding_dim, - self.partially_freeze, - ) - -@dataclass -class VBertBaseModelOutput(BaseModelOutput): - """ - Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - -@dataclass -class VBertMaskedLMOutput(MaskedLMOutput): - """ - Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). - Args: - loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): - Masked language modeling (MLM) loss. - logits (`torch.FloatTensor`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder - """ - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - image_hidden_states: Optional[torch.FloatTensor] = None - -class VBertSimpleMLP(nn.Module): - def __init__(self, input_size, output_size): - super().__init__() - self.proj = nn.Linear(input_size, output_size, bias=False) - - def forward(self, x): - return self.proj(x) - -class VBertConnector(nn.Module): - def __init__(self, config): - super().__init__() - self.scale_factor = config.pixel_shuffle_factor - self.modality_projection = VBertSimpleMLP( - input_size=config.vision_config.hidden_size * (config.scale_factor**2), - output_size=config.text_config.hidden_size - ) - - def pixel_shuffle(self, x, scale_factor): - bsz, seq, embed_dim = x.size() - height = width = int(seq**0.5) - x = x.view(bsz, height, width, embed_dim) - x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) - x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) - x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) - return x - - def forward(self, image_hidden_states): - image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) - image_hidden_states = self.modality_projection(image_hidden_states) - return image_hidden_states - -class VBertPreTrainedModel(PreTrainedModel): - config_class = VBertConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["VBertDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - """Initialize the weights.""" - - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - -class VBertModel(VBertPreTrainedModel): - """ - A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger - in forward. Instead, we override inputs_merger here with custom logic. - """ - - def __init__(self, config: VBertConfig, **kwargs): - super().__init__(config) - - self.vision_model = VBertModel.init_vision_model(config, **kwargs) - self.connector = VBertConnector(config) - self.text_model = VBertModel.init_language_model(config, **kwargs) - - self.image_seq_len = int( - ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) - ) - self.image_token_id = self.config.image_token_id - - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - self.post_init() - - @staticmethod - def init_vision_model(config: VBertConfig, **kwargs): - vision_model_config = AutoConfig.from_pretrained( - config.vision_config.vision_model_name, - _attn_implementation=config._attn_implementation, - torch_dtype=config.torch_dtype, - **kwargs, - ) - - vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) - - if hasattr(vision_model, "vision_model"): - # If the model has a vision_model attribute, it means it's a wrapper around another model - vision_model = vision_model.vision_model - - return vision_model - - @staticmethod - def init_language_model(config: VBertConfig, **kwargs): - text_model_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - _attn_implementation=config._attn_implementation, - torch_dtype=config.torch_dtype, - trust_remote_code=True, - **kwargs, - ) - - text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) - - embed_layer = DecoupledEmbedding( - num_embeddings=text_model_config.vocab_size, - num_additional_embeddings=config.additional_vocab_size, - embedding_dim=config.hidden_size, - partially_freeze=config.freeze_config["freeze_text_layers"], - padding_idx=config.pad_token_id, - ) - - text_model.set_input_embeddings(embed_layer) - - return text_model - - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. - - This is useful for lora when using gradient checkpointing. - c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 - - Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. - """ - - def get_lowest_module(module): - if len(list(module.children())) == 0: - # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) - return module - else: - # Recursively call the function on each child module - return get_lowest_module(list(module.children())[0]) - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( - make_inputs_require_grads - ) - - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - - def get_input_embeddings(self): - return self.text_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.text_model.set_input_embeddings(value) - - def inputs_merger( - self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor - ): - """ - This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. - The merging happens as follows: - - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. - - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space. - We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer. - - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. - """ - _, patch_size, _ = image_hidden_states.shape - - image_mask = input_ids == self.image_token_id - num_image_tokens = image_mask.sum(dim=1) - if not torch.all(num_image_tokens % patch_size == 0): - raise ValueError("At least one sample has tokens not divisible by patch_size.") - - blocks_per_sample = num_image_tokens // patch_size - - offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) - block_offset = offsets[:-1] - row_cum = image_mask.cumsum(dim=-1) - chunk_idx = (row_cum - 1) // patch_size - local_idx = (row_cum - 1) % patch_size - block_idx = block_offset.unsqueeze(1) + chunk_idx - - image_embeds = torch.zeros_like(inputs_embeds) - image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] - - merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) - return merged_embeds - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.training and self.text_model.gradient_checkpointing and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # retrieve input_ids and inputs_embeds - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_seen_tokens = 0 - if use_cache: - if past_key_values is None: - past_key_values = DynamicCache() - past_seen_tokens = past_key_values.get_seq_length() - - if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: - raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") - - if inputs_embeds is None: - inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) - - # START VISUAL INPUTS INTEGRATION - if pixel_values is not None and image_hidden_states is not None: - raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") - elif pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values - pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - - if not any(real_images_inds): - # no images, leave one empty image. - real_images_inds[0] = True - - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=[pixel_values.shape[i] for i in (0, 2, 3)], - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - - # patch_size = self.config.vision_config.patch_size - # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - # patch_attention_mask=patch_attention_mask, - ).last_hidden_state - - # Modality projection & resampling - image_hidden_states = self.connector(image_hidden_states) - - elif image_hidden_states is not None: - image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - - if inputs_embeds is not None and image_hidden_states is not None: - # When we embed, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self.inputs_merger( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - image_hidden_states=image_hidden_states, - ) - - outputs = self.text_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - # past_key_values=past_key_values, - # use_cache=use_cache, - # cache_position=cache_position, - ) - - if not return_dict: - return tuple(v for v in [*outputs, image_hidden_states] if v is not None) - - return VBertBaseModelOutput( - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_hidden_states, - ) - -class VBertForMaskedLM(VBertPreTrainedModel): - # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] - - def __init__(self, config, **kwargs): - super().__init__(config) - - self.image_token_id = config.image_token_id - self.in_features = config.hidden_size - self.out_additional_features = config.additional_vocab_size - self.vocab_size = config.vocab_size - - if config.is_decoder: - logger.warning( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.model = VBertModel(config, **kwargs) - self.lm_head = VBertForMaskedLM.init_lm_head(config, **kwargs) - if self.out_additional_features > 0: - self.additional_fc = nn.Linear( - in_features=self.in_features, - out_features=self.out_additional_features, - bias=False, - ) - - # Initialize weights and apply final processing - self.post_init() - - @staticmethod - def init_lm_head(config, **kwargs): - # Get the pretrained model config - text_model_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) - # Get the lm head - lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None - if lm_head is None: - logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") - lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) - return lm_head - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, VBertMaskedLMOutput]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). - Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only - computed for the tokens with labels in `[0, ..., config.vocab_size]`. - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - # Pass the inputs to VBertModel - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # Pass the outputs to the MLM head - hidden_states = outputs[0] - - logits = self.lm_head(hidden_states) - if self.out_additional_features > 0: - additional_features = self.additional_fc(hidden_states) - logits = torch.cat((logits, additional_features), -1) - logits = logits.float() - - masked_lm_loss = None - if labels is not None: - # print the ratio of not ignored tokens - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return VBertMaskedLMOutput( - loss=masked_lm_loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - # @classmethod - # def from_pretrained_models( - # cls, - # text_model_name, - # vision_model_name, - # vl_config, - # *args, - # **kwargs - # ): - # """ - # Use this method when creating a new vloom model that hasn't been yet trained and it'll be - # composed of 2 pre-trained models - hence `pretrained_models`. - # """ - # model = super().from_pretrained_models( - # text_model_name=text_model_name, - # vision_model_name=vision_model_name, - # vl_config=vl_config, - # *args, - # **kwargs - # ) - # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - # # fetch the pretrained text model w/o zero.Init - # pretrained_lm_head = AutoModelForMaskedLM.from_pretrained( - # text_model_name, trust_remote_code=True, **kwargs - # ).lm_head - - # # Load the lm_head - # load_state_dict_into_model(model.lm_head, pretrained_lm_head.state_dict(), start_prefix="") - - # return model - -class VModernBertLMHead(nn.Module): - def __init__(self, config, **kwargs): - super().__init__() - pretrained_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs) - - self.head = pretrained_model.head - self.decoder = pretrained_model.decoder - - def forward(self, hidden_states): - hidden_states = self.head(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - # @classmethod - # def from_pretrained( - # cls, - # text_model_name, - # vl_config, - # *args, - # **kwargs - # ): - # """ - # Use this method when creating a new vloom model that hasn't been yet trained and it'll be - # composed of 2 pre-trained models - hence `pretrained_models`. - # """ - # lm_head = cls(vl_config, *args, **kwargs) - - # with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - # # fetch the pretrained text model w/o zero.Init - # pretrained_model = AutoModelForMaskedLM.from_pretrained( - # text_model_name, trust_remote_code=True, **kwargs - # ) - - # pretrained_head = pretrained_model.head - # pretrained_decoder = pretrained_model.decoder - - # # Load the head - # load_state_dict_into_model(lm_head.head, pretrained_head.state_dict(), start_prefix="") - # load_state_dict_into_model(lm_head.decoder, pretrained_decoder.state_dict(), start_prefix="") - - # return lm_head - -class VModernBertForMaskedLM(VBertPreTrainedModel): - # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] - - def __init__(self, config, **kwargs): - super().__init__(config) - - self.image_token_id = config.image_token_id - self.in_features = config.hidden_size - self.out_additional_features = config.additional_vocab_size - self.vocab_size = config.vocab_size - - if config.is_decoder: - logger.warning( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." - ) - - self.model = VBertModel(config, **kwargs) - self.lm_head = VModernBertLMHead(config, **kwargs) - - if self.out_additional_features > 0: - self.additional_fc = nn.Linear( - in_features=self.in_features, - out_features=self.out_additional_features, - bias=False, - ) - - # Initialize weights and apply final processing - self.post_init() - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, VBertMaskedLMOutput]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). - Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only - computed for the tokens with labels in `[0, ..., config.vocab_size]`. - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - # Pass the inputs to VBertModel - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - # Pass the outputs to the MLM head - hidden_states = outputs[0] - - logits = self.lm_head(hidden_states) - if self.out_additional_features > 0: - proj_states = self.lm_head.head(hidden_states) - additional_features = self.additional_fc(proj_states) - logits = torch.cat((logits, additional_features), -1) - logits = logits.float() - - masked_lm_loss = None - if labels is not None: - # print the ratio of not ignored tokens - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return VBertMaskedLMOutput( - loss=masked_lm_loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images, max_num_tokens=None): - config_vl_model = self.config - - lm_config = config_vl_model.text_config - - language_embed_size = lm_config.hidden_size - num_language_layers = lm_config.num_hidden_layers - ffn_inner_size = lm_config.intermediate_size - - vision_config = config_vl_model.vision_config - - # Get vision model blocks infos - vision_patch_size = vision_config.patch_size - vision_hidden_size = vision_config.embed_dim - num_vision_layers = vision_config.num_hidden_layers - # The +1 is for the CLS token - single_image_vision_encoder_seq_len = int(((vision_config.image_size // vision_patch_size) ** 2) // (self.config.pixel_shuffle_factor**2)) - vision_exp_factor = vision_config.intermediate_size // vision_hidden_size - - # Get language blocks infos - language_seq_len = max_num_tokens if max_num_tokens is not None else data_param.max_seq_len - language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4 - - # Get modality projection infos - vision_pipeline_output_seq_len = ( - self.config.perceiver_config.resampler_n_latents - if self.config.use_resampler - else single_image_vision_encoder_seq_len - ) - - language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( - num_layers=num_language_layers, - batch_size=hparams.batch_size_per_gpu, - q_seq_len=language_seq_len, - k_seq_len=language_seq_len, - hidden_size=language_embed_size, - kv_in_dim=language_embed_size, - ff_exp_factor=language_exp_factor, - grad_acc_size=hparams.grad_acc_size, - swiglu=True, - vocab_size=tokenizer.vocab_size, - count_backward=True, # Always True regardless of freezing, because gradients are computed for vision adaptor - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - modality_projection_tflops_per_batch_per_gpu = compute_linear_tflops_per_batch_per_gpu( - batch_size=hparams.batch_size_per_gpu * max_num_images, - seq_len=vision_pipeline_output_seq_len, - in_features=vision_hidden_size, - out_features=language_embed_size, - count_backward=True, - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - - vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu( - num_layers=num_vision_layers, - batch_size=hparams.batch_size_per_gpu * max_num_images, - q_seq_len=single_image_vision_encoder_seq_len, - k_seq_len=single_image_vision_encoder_seq_len, - hidden_size=vision_hidden_size, - kv_in_dim=vision_hidden_size, - ff_exp_factor=vision_exp_factor, - grad_acc_size=hparams.grad_acc_size, - swiglu=False, - vocab_size=None, - count_backward=not hparams.model_config["freeze_config"]["freeze_vision_layers"], - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - if self.config.use_resampler: - perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu( - num_layers=self.config.perceiver_config.resampler_depth, - batch_size=hparams.batch_size_per_gpu * max_num_images, - q_seq_len=self.config.perceiver_config.resampler_n_latents, - vision_embed_seq_len=single_image_vision_encoder_seq_len, - q_k_v_input_dim=vision_hidden_size, - attention_hidden_size=self.config.perceiver_config.resampler_n_heads - * self.config.perceiver_config.resampler_head_dim, - ff_exp_factor=4, - count_backward=True, - use_grad_checkpointing=hparams.gradient_checkpointing, - ) - tflop_count = ( - language_tflops_per_batch_per_gpu - + modality_projection_tflops_per_batch_per_gpu - + perceiver_tflops_per_batch_per_gpu - + vision_tflops_per_batch_per_gpu - ) - else: - tflop_count = ( - language_tflops_per_batch_per_gpu - + modality_projection_tflops_per_batch_per_gpu - + vision_tflops_per_batch_per_gpu - ) - return tflop_count - - @classmethod - def from_pretrained_models( - cls, - text_model_name, - vision_model_name, - vl_config, - *args, - **kwargs - ): - """ - Use this method when creating a new vloom model that hasn't been yet trained and it'll be - composed of 2 pre-trained models - hence `pretrained_models`. - """ - model = super().from_pretrained_models( - text_model_name=text_model_name, - vision_model_name=vision_model_name, - vl_config=vl_config, - *args, - **kwargs - ) - - # Load the lm_head - model.lm_head = VModernBertLMHead.from_pretrained( - text_model_name=text_model_name, - vl_config=vl_config, - *args, - **kwargs - ) - - return model diff --git a/colpali_engine/models/siglip/__init__.py b/colpali_engine/models/siglip/__init__.py deleted file mode 100644 index f1bb314b..00000000 --- a/colpali_engine/models/siglip/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_bisiglip import BiSiglip -from .processing_bisiglip import BiSiglipProcessor \ No newline at end of file diff --git a/colpali_engine/models/siglip/modeling_bisiglip.py b/colpali_engine/models/siglip/modeling_bisiglip.py deleted file mode 100644 index 97f65a6f..00000000 --- a/colpali_engine/models/siglip/modeling_bisiglip.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import ClassVar - -from transformers import SiglipModel - - -class BiSiglip(SiglipModel): - main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related - - def forward(self, *args, **kwargs): - """ - Forward pass through Llama and the linear layer for dimensionality reduction - - Args: - - input_ids (torch.LongTensor): The input tokens tensor. - - attention_mask (torch.LongTensor): The attention mask tensor. - - Returns: - - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) - """ - - output_attentions = kwargs.pop("output_attentions", None) - output_hidden_states = kwargs.pop("output_hidden_states", None) - return_dict = kwargs.pop("return_dict", None) - interpolate_pos_encoding = kwargs.pop("interpolate_pos_encoding", None) - - if "pixel_values" in kwargs: - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - pixel_values = kwargs.pop("pixel_values") - - embeds = self.vision_model( - pixel_values=pixel_values.to(dtype=self.dtype), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - interpolate_pos_encoding=interpolate_pos_encoding, - ).pooler_output - - else: - embeds = self.text_model( - input_ids=kwargs.pop("input_ids", None), - attention_mask=kwargs.pop("attention_mask", None), - position_ids=kwargs.pop("position_ids", None), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ).pooler_output - - # normalized features - embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True) - return embeds \ No newline at end of file diff --git a/colpali_engine/models/siglip/processing_bisiglip.py b/colpali_engine/models/siglip/processing_bisiglip.py deleted file mode 100644 index 073b6372..00000000 --- a/colpali_engine/models/siglip/processing_bisiglip.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import ClassVar, List, Optional, Tuple, Union - -import torch -from PIL import Image -from transformers import BatchEncoding, BatchFeature -from transformers.models.siglip import SiglipProcessor - -from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor - - -class BiSiglipProcessor(BaseVisualRetrieverProcessor, SiglipProcessor): # noqa: N801 - """ - Processor for BiSiglip - """ - - query_augmentation_token: ClassVar[str] = "" - - def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: - """ - Args: - texts: List of input texts. - - Returns: - Union[BatchFeature, BatchEncoding]: Processed texts. - """ - return self( - text=texts, - return_tensors="pt", - padding="max_length", # the model was trained with max_length padding - max_length=64, - truncation=True, - ) - - def process_images( - self, - images: List[Image.Image], - ) -> Union[BatchFeature, BatchEncoding]: - """ - Args: - images: List of PIL images. - """ - images = [image.convert("RGB") for image in images] - - batch_doc = self( - images=images, - return_tensors="pt", - padding="longest", # the model was trained with max_length padding - ) - return batch_doc - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the cosine similarity for the given query and passage embeddings. - """ - return self.score_single_vector(qs, ps, device=device) - - def get_n_patches( - self, - image_size: Tuple[int, int], - spatial_merge_size: int, - ) -> Tuple[int, int]: - """ - Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of - size (height, width) with the given patch size. - - The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored in - as a `Qwen2VLForConditionalGeneration` attribute under `model.spatial_merge_size`. - """ - raise NotImplementedError("BiSiglip does not support the `get_n_patches` method. ") \ No newline at end of file diff --git a/colpali_engine/models/vllama/__init__.py b/colpali_engine/models/vllama/__init__.py deleted file mode 100644 index 534ea814..00000000 --- a/colpali_engine/models/vllama/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .bivllama import BiVLlama, BiVLlamaProcessor -from .colvllama import ColVLlama, ColVLlamaProcessor diff --git a/colpali_engine/models/vllama/bivllama/__init__.py b/colpali_engine/models/vllama/bivllama/__init__.py deleted file mode 100644 index 55f602b0..00000000 --- a/colpali_engine/models/vllama/bivllama/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_bivllama import BiVLlama -from .processing_bivllama import BiVLlamaProcessor diff --git a/colpali_engine/models/vllama/bivllama/modeling_bivllama.py b/colpali_engine/models/vllama/bivllama/modeling_bivllama.py deleted file mode 100644 index a1ec56e6..00000000 --- a/colpali_engine/models/vllama/bivllama/modeling_bivllama.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Literal - -import torch - -from colpali_engine.models.vllama.modeling_vllama import VLlamaModel, VLlamaPreTrainedModel - - -class BiVLlama(VLlamaPreTrainedModel): - """ - Initializes the BiVLlama model. - - Args: - config : The model configuration. - """ - supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def __init__(self, config, pooling_strategy = "last", **kwargs): - super().__init__(config=config) - self.model = VLlamaModel(config, **kwargs) - self.pooling_strategy = pooling_strategy - self.post_init() - - def forward( - self, - pooling_strategy: Literal["cls", "last", "mean"] = None, - *args, - **kwargs, - ) -> torch.Tensor: - """ - Forward pass through model and pooling. - - Args: - - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean". - - input_ids (torch.LongTensor): The input tokens tensor. - - attention_mask (torch.LongTensor): The attention mask tensor. - - Returns: - - torch.Tensor: Embeddings of shape (batch_size, dim) - """ - outputs = self.model(*args, **kwargs) - last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) - - pooling_strategy = pooling_strategy or self.pooling_strategy - - # Get CLS token embedding, last token, or mean pool over sequence - if pooling_strategy == "cls": - # Use CLS token (first token) embedding - pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) - elif pooling_strategy == "last": - # use last token since we are left padding - pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) - elif pooling_strategy == "mean": - # Mean pooling over sequence length - mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) - pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) - else: - raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") - - # L2 normalization - pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True) - return pooled_output diff --git a/colpali_engine/models/vllama/bivllama/processing_bivllama.py b/colpali_engine/models/vllama/bivllama/processing_bivllama.py deleted file mode 100644 index e6fa2908..00000000 --- a/colpali_engine/models/vllama/bivllama/processing_bivllama.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import List, Optional, Union - -import torch -from transformers import BatchFeature, BatchEncoding - -from colpali_engine.models.vllama.colvllama import ColVLlamaProcessor - - -class BiVLlamaProcessor(ColVLlamaProcessor): # noqa: N801 - """ - Processor for BiVLlama model. - """ - - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for BiVLlama. - - NOTE: `max_length` is not used and kept only for trainer compatibility. - """ - if suffix is None: - suffix = self.query_augmentation_token # we remove buffer tokens - if contexts is None: - contexts = [""] * len(texts) - - prompts = [context + text + suffix for context, text in zip(contexts, texts)] - - batch_texts = self( - text=prompts, - return_tensors="pt", - padding="longest", - ) - - return batch_texts - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the cosine similarity for the given query and passage embeddings. - """ - return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/vllama/colvllama/__init__.py b/colpali_engine/models/vllama/colvllama/__init__.py deleted file mode 100644 index 00dae459..00000000 --- a/colpali_engine/models/vllama/colvllama/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_colvllama import ColVLlama -from .processing_colvllama import ColVLlamaProcessor diff --git a/colpali_engine/models/vllama/colvllama/modeling_colvllama.py b/colpali_engine/models/vllama/colvllama/modeling_colvllama.py deleted file mode 100644 index a5a0114e..00000000 --- a/colpali_engine/models/vllama/colvllama/modeling_colvllama.py +++ /dev/null @@ -1,51 +0,0 @@ -from torch import nn - -from colpali_engine.models.vllama.modeling_vllama import VLlamaModel, VLlamaPreTrainedModel - - -class ColVLlama(VLlamaPreTrainedModel): - """ - Initializes the ColVBert model. - - Args: - config : The model configuration. - mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings - except those of the image at inference. - Defaults to False --> Do not mask any embeddings during forward pass. - """ - supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): - super().__init__(config=config) - self.model = VLlamaModel(config, **kwargs) - self.dim = 128 - self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) - self.mask_non_image_embeddings = mask_non_image_embeddings - self.main_input_name = "doc_input_ids" - - def forward(self, *args, **kwargs): - """ - Forward pass through the model and the linear layer for dimensionality reduction - - Args: - - input_ids (torch.LongTensor): The input tokens tensor. - - attention_mask (torch.LongTensor): The attention mask tensor. - - Returns: - - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) - """ - outputs = self.model(*args, **kwargs) - last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) - proj = self.custom_text_proj(last_hidden_states) - # normalize l2 norm - proj = proj / proj.norm(dim=-1, keepdim=True) - proj = proj * kwargs["attention_mask"].unsqueeze(-1) - - if "pixel_values" in kwargs and self.mask_non_image_embeddings: - # Pools only the image embeddings - image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) - proj = proj * image_mask - return proj diff --git a/colpali_engine/models/vllama/colvllama/processing_colvllama.py b/colpali_engine/models/vllama/colvllama/processing_colvllama.py deleted file mode 100644 index d983e075..00000000 --- a/colpali_engine/models/vllama/colvllama/processing_colvllama.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import ClassVar, List, Optional, Tuple, Union - -import torch -from PIL import Image -from transformers import BatchEncoding, BatchFeature, Idefics3Processor - -from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor - - -class ColVLlamaProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): - """ - Processor for ColVLlama. - """ - - query_augmentation_token: ClassVar[str] = "" - image_token: ClassVar[str] = "" - visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:Describe the image.\nAssistant:" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.tokenizer.padding_side = "left" - - def process_images( - self, - images: List[Image.Image], - contexts: Optional[List[str]] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process images for ColVLlama. - - Args: - images: List of PIL images. - contexts: List of optional context prompts, i.e. some text description of the context of the image. - """ - # if contexts is None: - # contexts = [self.visual_prompt_prefix] * len(images) - contexts = [self.visual_prompt_prefix] * len(images) - - images = [image.convert("RGB") for image in images] - - batch_doc = self( - text=contexts, - images=images, - padding="longest", - return_tensors="pt", - ) - return batch_doc - - def process_texts( - self, - texts: List[str], - max_length: int = 50, - contexts: Optional[List[str]] = None, - suffix: Optional[str] = None, - ) -> Union[BatchFeature, BatchEncoding]: - """ - Process texts for ColVLlama. - - NOTE: `max_length` is not used and kept only for trainer compatibility. - """ - if suffix is None: - suffix = self.query_augmentation_token * 10 - if contexts is None: - contexts = [""] * len(texts) - - prompts = [context + text + suffix for context, text in zip(contexts, texts)] - - batch_texts = self( - text=prompts, - return_tensors="pt", - padding="longest", - ) - - return batch_texts - - def score( - self, - qs: List[torch.Tensor], - ps: List[torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ) -> torch.Tensor: - """ - Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. - """ - return self.score_multi_vector(qs, ps, device=device, **kwargs) - - def get_n_patches( - self, - image_size: Tuple[int, int], - patch_size: int, - ) -> Tuple[int, int]: - raise NotImplementedError("This method is not implemented for ColIdefics3.") diff --git a/colpali_engine/models/vllama/configuration_vllama.py b/colpali_engine/models/vllama/configuration_vllama.py deleted file mode 100644 index 576b6497..00000000 --- a/colpali_engine/models/vllama/configuration_vllama.py +++ /dev/null @@ -1,232 +0,0 @@ -import copy -import os -from typing import Any, Dict, Union - -from transformers import AutoConfig -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -def collect_arg_in_candidates(config, candidates, default = None) -> Any: - """ Gets the argument in a config given a list of candidates """ - for c in candidates: - if hasattr(config, c): - return getattr(config, c) - elif c in config: - return config[c] - if default is not None: - return default - raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config)) - -class VLlamaTextConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - embed_dim (`int`, *optional*, defaults to 1152): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) - image_size (`int`, *optional*, defaults to 384): - The size (resolution) of each image. - """ - model_type = "VLlama" - - def __init__( - self, - # Case for when vllama3 is from the hub with no vision_model_name - text_model_name="HuggingFaceTB/SmolLM2-135M-Instruct", - **kwargs, - ): - self.text_model_name = text_model_name - text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) - if hasattr(text_config, "text_config"): - text_config = text_config.text_config - - self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) - self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) - self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) - self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False) - self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) - - super().__init__(text_model_name=text_model_name, **kwargs) - -class VLlamaVisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - embed_dim (`int`, *optional*, defaults to 1152): - Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`) - image_size (`int`, *optional*, defaults to 384): - The size (resolution) of each image. - """ - model_type = "VLlama" - attribute_map = { - "hidden_size": "embed_dim", - } - - def __init__( - self, - # Case for when vllama3 is from the hub with no vision_model_name - vision_model_name="google/siglip2-base-patch16-512", - **kwargs, - ): - self.vision_model_name = vision_model_name - vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) - if hasattr(vision_config, "vision_config"): - vision_config = vision_config.vision_config - - self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) - self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) - self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) - self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) - self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) - - super().__init__(vision_model_name=vision_model_name, **kwargs) - -class VLlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a - SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM - [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should cache the key/value pairs of the attention mechanism. Only - relevant if `config.is_decoder=True`. - image_token_id (`int`, *optional*, defaults to 128257): - The id of the "image" token. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether or not to tie the word embeddings with the token embeddings. - vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): - Custom vision config or dict for the vision tower - text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): - Custom text config or dict for the text model - scale_factor (`int`, *optional*, defaults to 2): - The scale factor for the image encoder. - pad_token_id (`int`, *optional*, defaults to 128002): - The id of the padding token. - - Example: - ```python - >>> from transformers import SmolVLMModel, SmolVLMConfig - >>> # Initializing configuration - >>> configuration = SmolVLMConfig() - >>> # Initializing a model from the configuration - >>> model = SmolVLMModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "VLlama" - is_composition = True - # sub_configs = {"text_config": VLlamaTextConfig, "vision_config": VLlamaVisionConfig} - - DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m" - DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" - - def __init__( - self, - text_config: Union[PretrainedConfig, Dict[str, Any]] = None, - vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, - image_token_id: int = 128_257, - vocab_size=128_256, - use_cache = True, - tie_word_embeddings = False, - freeze_config = None, - pad_token_id = None, - initializer_range = 0.02, - pixel_shuffle_factor = 4, - use_resampler = False, - additional_vocab_size = 0, - neftune_noise_alpha = 0.0, - **kwargs, - ): - self.image_token_id = image_token_id - self.use_cache = use_cache - self.tie_word_embeddings = tie_word_embeddings - self.scale_factor = pixel_shuffle_factor - self.additional_vocab_size = additional_vocab_size - - if text_config is None: - text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) - elif isinstance(text_config, dict): - text_config = VLlamaTextConfig(text_config["text_model_name"]) - self.text_config = text_config - - if vision_config is None: - vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) - elif isinstance(vision_config, dict): - vision_config = VLlamaVisionConfig(vision_config["vision_model_name"]) - self.vision_config = vision_config - - self.freeze_config = freeze_config - - # Pixel shuffle factor - self.pixel_shuffle_factor = pixel_shuffle_factor - self.use_resampler = use_resampler - - self.neftune_noise_alpha = neftune_noise_alpha - - self.initializer_range = initializer_range - - hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) - - super().__init__( - **kwargs, - pad_token_id=pad_token_id, - tie_word_embeddings=tie_word_embeddings, - vocab_size=vocab_size, - hidden_size=hidden_size, - ) - - def to_dict(self): - """ - Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. - Returns: - `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, - """ - output = copy.deepcopy(self.__dict__) - - output["model_type"] = self.__class__.model_type - output["vision_config"] = self.vision_config.to_dict() - output["text_config"] = self.text_config.to_dict() - # output["freeze_config"] = self.freeze_config.to_dict() - - return output - - # @classmethod - # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - # outputs = super(VLlamaConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs) - # return outputs - - @classmethod - def from_pretrained_models( - cls, - text_model_name: Union[str, os.PathLike], - vision_model_name: Union[str, os.PathLike], - **kwargs - ) -> "PretrainedConfig": - # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) - # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) - text_model_config = VLlamaTextConfig(text_model_name) - vision_model_config = VLlamaVisionConfig(vision_model_name) - return cls( - text_config=text_model_config, - vision_config=vision_model_config, - **kwargs - ) diff --git a/colpali_engine/models/vllama/modeling_vllama.py b/colpali_engine/models/vllama/modeling_vllama.py deleted file mode 100644 index e1d9793e..00000000 --- a/colpali_engine/models/vllama/modeling_vllama.py +++ /dev/null @@ -1,887 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss -from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, GenerationMixin, logging -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_outputs import BaseModelOutput -from transformers.modeling_utils import PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs - -# from transformers.models.smolvlm import SmolVLMModel, SmolVLMPreTrainedModel -from .configuration_vllama import VLlamaConfig - -logger = logging.get_logger(__name__) - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - -class DecoupledEmbedding(nn.Embedding): - # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding - """ - Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. - In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained. - If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. - """ - - def __init__( - self, - num_embeddings, - num_additional_embeddings, - embedding_dim, - partially_freeze=False, - device=None, - dtype=None, - padding_idx=None, - **kwargs, - ) -> None: - """ - num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`. - partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen. - Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these. - """ - if padding_idx is not None and padding_idx > num_embeddings: - raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") - super().__init__( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - device=device, - dtype=dtype, - padding_idx=padding_idx, - **kwargs, - ) - self.num_embeddings = num_embeddings - self.padding_idx = padding_idx - self.num_additional_embeddings = num_additional_embeddings - self.partially_freeze = partially_freeze - - if partially_freeze: - self.weight.requires_grad_(False) - - if self.num_additional_embeddings > 0: - self.additional_embedding = nn.Embedding( - num_embeddings=self.num_additional_embeddings, - embedding_dim=embedding_dim, - device=device, - dtype=dtype, - ) - - def forward(self, input_ids): - """ - we have 2 embeddings, with different indices - one pretrained self.weight and another - self.additional_embedding.weight that is being trained. - in order to make a lookup of the input ids, we: - 1. find out the indices of the entries belonging to the 2nd embedding - 2. extract those values while subtracting the size of the first embedding (num_embeddings), - since the 2nd embedding starts from 0 and not num_embeddings - 3. perform the 2nd embedding lookup - 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index - 5. perform the 1st embedding lookup - 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup - note: for the 1st embedding lookup we could have looked up only the low indices and not do - the padding, but then we have to create a new tensor and populate it with 2 tensors that are - spread out across various indices - i.e. not a simple concat - I haven't benchmarked the - complex case if it's any faster, given that seqlens are usually relatively short it's - probably not faster or if faster not by much - but might be a good idea to measure. - """ - if self.num_additional_embeddings == 0: - return self.additional_embedding(input_ids) - - # Clone so that we don't modify the original input_ids later on - input_ids = input_ids.clone() - additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) - input_ids_additional_vocab = input_ids[additional_vocab_indices] - additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) - - # for successful lookup replace input_ids with 0, the results of these will be discarded anyway - input_ids[additional_vocab_indices] = 0 - full_vector = F.embedding(input_ids, self.weight) - - # overwrite the records with high indices - full_vector[additional_vocab_indices] = additional_embeddings - - return full_vector - - def extra_repr(self) -> str: - return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( - self.num_embeddings, - self.num_additional_embeddings, - self.embedding_dim, - self.partially_freeze, - ) - -@dataclass -class VLlamaBaseModelOutputWithPast(BaseModelOutput): - """ - Base class for VLlama3 model's outputs that may also contain a past key/values (to speed up sequential decoding). - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - -@dataclass -class VLlamaCausalLMOutputWithPast(BaseModelOutput): - """ - Base class for VLlama3 causal language model (or autoregressive) outputs. - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - - -class VLlamaSimpleMLP(nn.Module): - def __init__(self, input_size, output_size): - super().__init__() - self.proj = nn.Linear(input_size, output_size, bias=False) - - def forward(self, x): - return self.proj(x) - -class VLlamaConnector(nn.Module): - def __init__(self, config): - super().__init__() - self.scale_factor = config.pixel_shuffle_factor - self.modality_projection = VLlamaSimpleMLP( - input_size=config.vision_config.hidden_size * (config.scale_factor**2), - output_size=config.text_config.hidden_size - ) - - def pixel_shuffle(self, x, scale_factor): - bsz, seq, embed_dim = x.size() - height = width = int(seq**0.5) - x = x.view(bsz, height, width, embed_dim) - x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) - x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) - x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) - return x - - def forward(self, image_hidden_states): - image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) - image_hidden_states = self.modality_projection(image_hidden_states) - return image_hidden_states - -class VLlamaPreTrainedModel(PreTrainedModel): - config_class = VLlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["VLlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - """Initialize the weights.""" - - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - -class VLlamaModel(VLlamaPreTrainedModel): - """ - A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger - in forward. Instead, we override inputs_merger here with custom logic. - """ - - def __init__(self, config: VLlamaConfig, **kwargs): - super().__init__(config) - - self.vision_model = VLlamaModel.init_vision_model(config, **kwargs) - self.connector = VLlamaConnector(config) - self.text_model = VLlamaModel.init_language_model(config, **kwargs) - - self.image_seq_len = int( - ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) - ) - self.image_token_id = self.config.image_token_id - - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - self.post_init() - - @staticmethod - def init_vision_model(config: VLlamaConfig, **kwargs): - vision_model_config = AutoConfig.from_pretrained( - config.vision_config.vision_model_name, - _attn_implementation=config._attn_implementation, - torch_dtype=config.torch_dtype, - **kwargs, - ) - - vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) - - if hasattr(vision_model, "vision_model"): - # If the model has a vision_model attribute, it means it's a wrapper around another model - vision_model = vision_model.vision_model - - return vision_model - - @staticmethod - def init_language_model(config: VLlamaConfig, **kwargs): - text_model_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - _attn_implementation=config._attn_implementation, - torch_dtype=config.torch_dtype, - trust_remote_code=True, - **kwargs, - ) - - text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) - - embed_layer = DecoupledEmbedding( - num_embeddings=text_model_config.vocab_size, - num_additional_embeddings=config.additional_vocab_size, - embedding_dim=config.hidden_size, - partially_freeze=config.freeze_config["freeze_text_layers"], - padding_idx=config.pad_token_id, - ) - - text_model.set_input_embeddings(embed_layer) - - return text_model - - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. - This is useful for lora when using gradient checkpointing. - c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 - Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. - """ - - def get_lowest_module(module): - if len(list(module.children())) == 0: - # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) - return module - else: - # Recursively call the function on each child module - return get_lowest_module(list(module.children())[0]) - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( - make_inputs_require_grads - ) - - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - - def get_input_embeddings(self): - return self.text_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.text_model.set_input_embeddings(value) - - def inputs_merger( - self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor - ): - """ - This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. - The merging happens as follows: - - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. - - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space. - We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer. - - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. - """ - _, patch_size, _ = image_hidden_states.shape - - image_mask = input_ids == self.image_token_id - num_image_tokens = image_mask.sum(dim=1) - if not torch.all(num_image_tokens % patch_size == 0): - raise ValueError("At least one sample has tokens not divisible by patch_size.") - - blocks_per_sample = num_image_tokens // patch_size - - offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) - block_offset = offsets[:-1] - row_cum = image_mask.cumsum(dim=-1) - chunk_idx = (row_cum - 1) // patch_size - local_idx = (row_cum - 1) % patch_size - block_idx = block_offset.unsqueeze(1) + chunk_idx - - image_embeds = torch.zeros_like(inputs_embeds) - image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] - - merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) - return merged_embeds - - def embed_tokens(self, input_ids: torch.LongTensor) -> torch.FloatTensor: - """ - Override the embed_tokens method to use the text model's input embeddings. - This is necessary to ensure that the image token ID is correctly handled. - """ - if self.text_model.get_input_embeddings() is None: - raise ValueError("The text model does not have input embeddings.") - - return self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, VLlamaBaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if self.training and self.text_model.gradient_checkpointing and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if inputs_embeds is not None and input_ids is None: - raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") - - # START VISUAL INPUTS INTEGRATION - if pixel_values is not None and image_hidden_states is not None: - raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") - elif pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values - pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - - if not any(real_images_inds): - # no images, leave one empty image. - real_images_inds[0] = True - - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=[pixel_values.shape[i] for i in (0, 2, 3)], - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - - # patch_size = self.config.vision_config.patch_size - # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - # patch_attention_mask=patch_attention_mask, - ).last_hidden_state - - # Modality projection & resampling - image_hidden_states = self.connector(image_hidden_states) - - elif image_hidden_states is not None: - image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - - if inputs_embeds is not None and image_hidden_states is not None: - # When we embed, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self.inputs_merger( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - image_hidden_states=image_hidden_states, - ) - - outputs = self.text_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - ) - - if not return_dict: - return tuple(v for v in [*outputs, image_hidden_states] if v is not None) - - return VLlamaBaseModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_hidden_states, - ) - -class VLlamaForCausalLM(VLlamaPreTrainedModel): - # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] - - def __init__(self, config, **kwargs): - super().__init__(config) - - self.image_token_id = config.image_token_id - self.in_features = config.hidden_size - self.out_additional_features = config.additional_vocab_size - self.vocab_size = config.vocab_size - - self.model = VLlamaModel(config, **kwargs) - self.lm_head = VLlamaForCausalLM.init_lm_head(config, **kwargs) - if self.out_additional_features > 0: - self.additional_fc = nn.Linear( - in_features=self.in_features, - out_features=self.out_additional_features, - bias=False, - ) - - # Initialize weights and apply final processing - self.post_init() - - @staticmethod - def init_lm_head(config, **kwargs): - # Get the pretrained model config - text_model_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) - # Get the lm head - lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None - if lm_head is None: - logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") - lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) - return lm_head - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, VLlamaCausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). - Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only - computed for the tokens with labels in `[0, ..., config.vocab_size]`. - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - # Pass the inputs to VLlamaModel - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - # Pass the outputs to the MLM head - hidden_states = outputs[0] - - logits = self.lm_head(hidden_states) - if self.out_additional_features > 0: - additional_features = self.additional_fc(hidden_states) - logits = torch.cat((logits, additional_features), -1) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] - shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return VLlamaCausalLMOutputWithPast( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - -class VLlamaForVision2Seq(VLlamaPreTrainedModel, GenerationMixin): - def __init__(self, config, **kwargs): - super().__init__(config) - - self.image_token_id = config.image_token_id - self.in_features = config.hidden_size - self.out_additional_features = config.additional_vocab_size - self.vocab_size = config.vocab_size - - self.model = VLlamaModel(config, **kwargs) - self.lm_head = VLlamaForVision2Seq.init_lm_head(config, **kwargs) - if self.out_additional_features > 0: - self.additional_fc = nn.Linear( - in_features=self.in_features, - out_features=self.out_additional_features, - bias=False, - ) - - self.loss_fct = CrossEntropyLoss() - - # Initialize weights and apply final processing - self.post_init() - - @staticmethod - def init_lm_head(config, **kwargs): - # Get the pretrained model config - text_model_config = AutoConfig.from_pretrained( - config.text_config.text_model_name, - trust_remote_code=True, - **kwargs, - ) - model = AutoModelForMaskedLM.from_config(text_model_config, trust_remote_code=True, **kwargs) - # Get the lm head - lm_head = model.lm_head if hasattr(model, "lm_head") else model.decoder if hasattr(model, "decoder") else None - if lm_head is None: - logger.warning(f"No lm head was found for {config.text_config.text_model_name}, initializing a new one.") - lm_head = nn.Linear(config.hidden_size, config.vocab_size, False) - return lm_head - - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping - the model weights fixed. - """ - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook( - make_inputs_require_grads - ) - - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - - def get_input_embeddings(self): - return self.model.text_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.text_model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, VLlamaCausalLMOutputWithPast]: - r""" - pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): - Mask to avoid performing attention on padding pixel indices. - image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): - The hidden states of the image encoder after modality projection. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `SmolVLMForConditionalGeneration`). - Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only - computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> import requests - >>> import torch - >>> from PIL import Image - >>> from io import BytesIO - - >>> from transformers import AutoProcessor, AutoModelForImageTextToText - >>> from transformers.image_utils import load_image - - >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible - >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg") - >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg") - >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg") - - >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct") - >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto") - - >>> # Create inputs - >>> messages = [ - ... { - ... "role": "user", - ... "content": [ - ... {"type": "video", "path": path/to/video}, - ... {"type": "text", "text": "What is happening in this video?"}, - ... ] - ... } - ... ] - - >>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True) - - >>> # Generate - >>> generated_ids = model.generate(**inputs, max_new_tokens=256) - >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True) - - >>> print(generated_texts) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - return_dict=True, - **kwargs, - ) - - hidden_states = outputs[0] - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - hidden_states = hidden_states[:, slice_indices, :] - logits = self.lm_head(hidden_states) - if self.out_additional_features > 0: - additional_features = self.additional_fc(hidden_states) - logits = torch.cat((logits, additional_features), -1) - logits = logits.float() - - loss = None - if labels is not None: - loss = self.loss_fct( - logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs - ) - - return VLlamaCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - pixel_values=None, - pixel_attention_mask=None, - image_hidden_states=None, - logits_to_keep=None, - **kwargs, - ): - # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take - # precedence is moved to the model, we can remove this fn) - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - logits_to_keep=logits_to_keep, - **kwargs, - ) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - # but IDEFICS requires both ids and embeds to be present - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs["input_ids"] = input_ids - - if image_hidden_states is not None: - model_inputs["pixel_values"] = None - model_inputs["pixel_attention_mask"] = None - - return model_inputs - - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - **kwargs, - ) - # Get the precomputed image_hidden_states - model_kwargs["image_hidden_states"] = outputs.image_hidden_states - return model_kwargs - - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past diff --git a/pyproject.toml b/pyproject.toml index 7850ff53..e38ed4b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,9 @@ dependencies = [ "pillow>=10.0.0", "requests", "scipy", - "torch>=2.5.0,<2.8.0", + "torch>=2.2.0,<2.8.0", "torchvision", - "transformers>=4.51.1,<4.52.0" + "transformers>=4.53.1,<4.54.0" ] [project.optional-dependencies] From 20e78cc27f2858adc7061bef86cdc0bbe10cc7e3 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Thu, 16 Oct 2025 23:56:44 +0200 Subject: [PATCH 35/42] add tests for modernvbert --- .../test_modeling_colmodernvbert.py | 148 ++++++++++++++++++ .../test_processing_colmodernvbert.py | 64 ++++++++ 2 files changed, 212 insertions(+) create mode 100644 tests/models/modernvbert/test_modeling_colmodernvbert.py create mode 100644 tests/models/modernvbert/test_processing_colmodernvbert.py diff --git a/tests/models/modernvbert/test_modeling_colmodernvbert.py b/tests/models/modernvbert/test_modeling_colmodernvbert.py new file mode 100644 index 00000000..ed1305a3 --- /dev/null +++ b/tests/models/modernvbert/test_modeling_colmodernvbert.py @@ -0,0 +1,148 @@ +import logging +from typing import Generator, cast + +import pytest +import torch +from datasets import load_dataset +from PIL import Image +from transformers.utils.import_utils import is_flash_attn_2_available + +from colpali_engine.models import ColModernVBert, ColModernVBertProcessor +from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "ModernVBERT/colmodernvbert" + + +@pytest.fixture(scope="module") +def model_without_mask(model_name: str) -> Generator[ColModernVBert, None, None]: + device = get_torch_device("auto") + logger.info(f"Device used: {device}") + + yield cast( + ColModernVBert, + ColModernVBert.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, + attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, + mask_non_image_embeddings=False, + ).eval(), + ) + tear_down_torch() + + +@pytest.fixture(scope="module") +def model_with_mask(model_name: str) -> Generator[ColModernVBert, None, None]: + device = get_torch_device("auto") + logger.info(f"Device used: {device}") + + yield cast( + ColModernVBert, + ColModernVBert.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, + attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, + mask_non_image_embeddings=True, + ).eval(), + ) + tear_down_torch() + + +@pytest.fixture(scope="module") +def processor(model_name: str) -> Generator[ColModernVBertProcessor, None, None]: + yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name)) + + +class TestColModernVBert_Model: # noqa N801 + @pytest.mark.slow + def test_load_model_from_pretrained(self, model_without_mask: ColModernVBert): + assert isinstance(model_without_mask, ColModernVBert) + + +class TestColModernVBert_ModelIntegration: # noqa N801 + @pytest.mark.slow + def test_forward_images_integration( + self, + model_without_mask: ColModernVBert, + processor: ColModernVBertProcessor, + ): + # Create a batch of dummy images + images = [ + Image.new("RGB", (64, 64), color="white"), + Image.new("RGB", (32, 32), color="black"), + ] + + # Process the image + batch_images = processor.process_images(images).to(model_without_mask.device) + + # Forward pass + with torch.no_grad(): + outputs = model_without_mask(**batch_images) + + # Assertions + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_visual_tokens, emb_dim = outputs.shape + assert batch_size == len(images) + assert emb_dim == model_without_mask.dim + + @pytest.mark.slow + def test_forward_queries_integration( + self, + model_without_mask: ColModernVBert, + processor: ColModernVBertProcessor, + ): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Process the queries + batch_queries = processor.process_queries(queries).to(model_without_mask.device) + + # Forward pass + with torch.no_grad(): + outputs = model_without_mask(**batch_queries) + + # Assertions + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_query_tokens, emb_dim = outputs.shape + assert batch_size == len(queries) + assert emb_dim == model_without_mask.dim + + @pytest.mark.slow + def test_retrieval_integration( + self, + model_without_mask: ColModernVBert, + processor: ColModernVBertProcessor, + ): + # Load the test dataset + ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") + + # Preprocess the examples + batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device) + batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device) + + # Run inference + with torch.inference_mode(): + image_embeddings = model_without_mask(**batch_images) + query_embeddings = model_without_mask(**batch_queries) + + # Compute retrieval scores + scores = processor.score_multi_vector( + qs=query_embeddings, + ps=image_embeddings, + ) # (len(qs), len(ps)) + + assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}" + assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}" + + # Check if the maximum scores per row are in the diagonal of the matrix score + assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all() diff --git a/tests/models/modernvbert/test_processing_colmodernvbert.py b/tests/models/modernvbert/test_processing_colmodernvbert.py new file mode 100644 index 00000000..236ebc8a --- /dev/null +++ b/tests/models/modernvbert/test_processing_colmodernvbert.py @@ -0,0 +1,64 @@ +from typing import Generator, cast + +import pytest +import torch +from PIL import Image + +from colpali_engine.models import ColModernVBertProcessor + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "ModernVBERT/colmodernvbert" + + +@pytest.fixture(scope="module") +def processor_from_pretrained(model_name: str) -> Generator[ColModernVBertProcessor, None, None]: + yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name)) + + +def test_load_processor_from_pretrained(processor_from_pretrained: ColModernVBertProcessor): + assert isinstance(processor_from_pretrained, ColModernVBertProcessor) + + +def test_process_images(processor_from_pretrained: ColModernVBertProcessor): + # Create a dummy image + image_size = (64, 32) + image = Image.new("RGB", image_size, color="black") + images = [image] + + # Process the image + batch_feature = processor_from_pretrained.process_images(images) + + # Assertions + assert "pixel_values" in batch_feature + assert isinstance(batch_feature["pixel_values"], torch.Tensor) + assert batch_feature["pixel_values"].shape[0] == 1 + +def test_process_texts(processor_from_pretrained: ColModernVBertProcessor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Process the queries + batch_encoding = processor_from_pretrained.process_texts(queries) + + # Assertions + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) + +def test_process_queries(processor_from_pretrained: ColModernVBertProcessor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Process the queries + batch_encoding = processor_from_pretrained.process_queries(queries) + + # Assertions + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) From 43fba9831c650a3c22990206b1095fc10fa008d9 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Fri, 17 Oct 2025 00:00:26 +0200 Subject: [PATCH 36/42] f test --- tests/models/modernvbert/test_modeling_colmodernvbert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/modernvbert/test_modeling_colmodernvbert.py b/tests/models/modernvbert/test_modeling_colmodernvbert.py index ed1305a3..ecf3ced1 100644 --- a/tests/models/modernvbert/test_modeling_colmodernvbert.py +++ b/tests/models/modernvbert/test_modeling_colmodernvbert.py @@ -27,7 +27,7 @@ def model_without_mask(model_name: str) -> Generator[ColModernVBert, None, None] ColModernVBert, ColModernVBert.from_pretrained( model_name, - torch_dtype=torch.bfloat16, + torch_dtype=torch.float16, device_map=device, attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, mask_non_image_embeddings=False, @@ -45,7 +45,7 @@ def model_with_mask(model_name: str) -> Generator[ColModernVBert, None, None]: ColModernVBert, ColModernVBert.from_pretrained( model_name, - torch_dtype=torch.bfloat16, + torch_dtype=torch.float16, device_map=device, attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, mask_non_image_embeddings=True, From 058a2998b103437e14287de7caa7d0228129f2a1 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Fri, 17 Oct 2025 00:05:04 +0200 Subject: [PATCH 37/42] f --- .../models/modernvbert/test_modeling_colmodernvbert.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/modernvbert/test_modeling_colmodernvbert.py b/tests/models/modernvbert/test_modeling_colmodernvbert.py index ecf3ced1..05e055cf 100644 --- a/tests/models/modernvbert/test_modeling_colmodernvbert.py +++ b/tests/models/modernvbert/test_modeling_colmodernvbert.py @@ -29,7 +29,7 @@ def model_without_mask(model_name: str) -> Generator[ColModernVBert, None, None] model_name, torch_dtype=torch.float16, device_map=device, - attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, + attn_implementation="eager", mask_non_image_embeddings=False, ).eval(), ) @@ -47,7 +47,7 @@ def model_with_mask(model_name: str) -> Generator[ColModernVBert, None, None]: model_name, torch_dtype=torch.float16, device_map=device, - attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, + attn_implementation="eager", mask_non_image_embeddings=True, ).eval(), ) @@ -104,7 +104,7 @@ def test_forward_queries_integration( ] # Process the queries - batch_queries = processor.process_queries(queries).to(model_without_mask.device) + batch_queries = processor.process_queries(queries).to(model_without_mask.device).to(torch.float16) # Forward pass with torch.no_grad(): @@ -127,8 +127,8 @@ def test_retrieval_integration( ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") # Preprocess the examples - batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device) - batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device) + batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device).to(torch.float16) + batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device).to(torch.float16) # Run inference with torch.inference_mode(): From d1e3f387803ef56c5318b8c2d0ab04001394f1dd Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Fri, 17 Oct 2025 00:07:55 +0200 Subject: [PATCH 38/42] ff --- .../modernvbert/test_modeling_colmodernvbert.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/models/modernvbert/test_modeling_colmodernvbert.py b/tests/models/modernvbert/test_modeling_colmodernvbert.py index 05e055cf..098f8223 100644 --- a/tests/models/modernvbert/test_modeling_colmodernvbert.py +++ b/tests/models/modernvbert/test_modeling_colmodernvbert.py @@ -27,7 +27,7 @@ def model_without_mask(model_name: str) -> Generator[ColModernVBert, None, None] ColModernVBert, ColModernVBert.from_pretrained( model_name, - torch_dtype=torch.float16, + torch_dtype=torch.float32, device_map=device, attn_implementation="eager", mask_non_image_embeddings=False, @@ -45,7 +45,7 @@ def model_with_mask(model_name: str) -> Generator[ColModernVBert, None, None]: ColModernVBert, ColModernVBert.from_pretrained( model_name, - torch_dtype=torch.float16, + torch_dtype=torch.float32, device_map=device, attn_implementation="eager", mask_non_image_embeddings=True, @@ -104,7 +104,7 @@ def test_forward_queries_integration( ] # Process the queries - batch_queries = processor.process_queries(queries).to(model_without_mask.device).to(torch.float16) + batch_queries = processor.process_queries(queries).to(model_without_mask.device).to(torch.float32) # Forward pass with torch.no_grad(): @@ -127,8 +127,8 @@ def test_retrieval_integration( ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") # Preprocess the examples - batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device).to(torch.float16) - batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device).to(torch.float16) + batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device).to(torch.float32) + batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device).to(torch.float32) # Run inference with torch.inference_mode(): @@ -144,5 +144,5 @@ def test_retrieval_integration( assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}" assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}" - # Check if the maximum scores per row are in the diagonal of the matrix score - assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all() + # # Check if the maximum scores per row are in the diagonal of the matrix score + # assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all() From 133bc51a32bf9da613cc4bfceeed7bfc4ced3e87 Mon Sep 17 00:00:00 2001 From: Paul Teiletche <73120933+paultltc@users.noreply.github.com> Date: Fri, 17 Oct 2025 00:53:17 -0700 Subject: [PATCH 39/42] update dtype assign (#349) --- .../modernvbert/modeling_modernvbert.py | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/colpali_engine/models/modernvbert/modeling_modernvbert.py b/colpali_engine/models/modernvbert/modeling_modernvbert.py index 94736d17..2dc468ba 100644 --- a/colpali_engine/models/modernvbert/modeling_modernvbert.py +++ b/colpali_engine/models/modernvbert/modeling_modernvbert.py @@ -202,11 +202,8 @@ class ModernVBertPreTrainedModel(PreTrainedModel): config_class = ModernVBertConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["ModernVBertDecoderLayer"] - _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_cache_class = True def _init_weights(self, module): std = getattr(self.config, "initializer_range", 0.02) @@ -221,39 +218,44 @@ def _init_weights(self, module): class ModernVBertModel(ModernVBertPreTrainedModel): - def __init__(self, config: ModernVBertConfig, **kwargs): + def __init__(self, config: ModernVBertConfig): super().__init__(config) - self.vision_model = ModernVBertModel.init_vision_model(config, **kwargs) + self.vision_model = ModernVBertModel.init_vision_model(config) self.connector = ModernVBertConnector(config) - self.text_model = ModernVBertModel.init_language_model(config, **kwargs) + self.text_model = ModernVBertModel.init_language_model(config) self.image_seq_len = int( ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) ) self.image_token_id = config.image_token_id self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + # set the correct dtype for vision and text models + self.vision_model.to(self.dtype) + self.text_model.to(self.dtype) self.post_init() @staticmethod - def init_vision_model(config: ModernVBertConfig, **kwargs): + def init_vision_model(config: ModernVBertConfig): vision_model_config = AutoConfig.from_pretrained( config.vision_config.vision_model_name, _attn_implementation=config._attn_implementation, - dtype=config.torch_dtype, - **kwargs, ) - vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs) + vision_model = AutoModel.from_config( + vision_model_config, + trust_remote_code=True, + ) return getattr(vision_model, "vision_model", vision_model) @staticmethod - def init_language_model(config: ModernVBertConfig, **kwargs): + def init_language_model(config: ModernVBertConfig): text_model_config = AutoConfig.from_pretrained( config.text_config.text_model_name, _attn_implementation=config._attn_implementation, - dtype=config.torch_dtype, trust_remote_code=True, - **kwargs, ) - text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs) + text_model = AutoModel.from_config( + text_model_config, + trust_remote_code=True + ) embed_layer = DecoupledEmbedding( num_embeddings=text_model_config.vocab_size, num_additional_embeddings=config.additional_vocab_size, @@ -376,10 +378,10 @@ def forward( ) class ModernVBertLMHead(nn.Module): - def __init__(self, config, **kwargs): + def __init__(self, config): super().__init__() - pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True, **kwargs) - pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs) + pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True) + pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True) self.head = pretrained_model.head self.decoder = pretrained_model.decoder @@ -388,16 +390,17 @@ def forward(self, hidden_states): class ModernVBertForMaskedLM(ModernVBertPreTrainedModel): - def __init__(self, config, **kwargs): + def __init__(self, config): super().__init__(config) self.image_token_id = config.image_token_id self.in_features = config.hidden_size self.out_additional_features = config.additional_vocab_size self.vocab_size = config.vocab_size - self.model = ModernVBertModel(config, **kwargs) - self.lm_head = ModernVBertLMHead(config, **kwargs) + self.model = ModernVBertModel(config) + self.lm_head = ModernVBertLMHead(config) if self.out_additional_features > 0: self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False) + self.lm_head.to(self.dtype) self.post_init() def forward( From df0d1a89d94392ac0eff03c68e796813218d836a Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Fri, 17 Oct 2025 17:38:59 +0200 Subject: [PATCH 40/42] oopsie --- colpali_engine/collators/collator_copy.py | 142 ---------------------- 1 file changed, 142 deletions(-) delete mode 100644 colpali_engine/collators/collator_copy.py diff --git a/colpali_engine/collators/collator_copy.py b/colpali_engine/collators/collator_copy.py deleted file mode 100644 index f131c233..00000000 --- a/colpali_engine/collators/collator_copy.py +++ /dev/null @@ -1,142 +0,0 @@ -import random -import torch -from typing import Any, Dict, List, Union - -from PIL.Image import Image - -from colpali_engine.data.dataset import ColPaliEngineDataset -from colpali_engine.models.paligemma import ColPaliProcessor -from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor - - -def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]: - """ - Prefix all keys in a dictionary with the given prefix. - """ - return {f"{prefix}{k}": v for k, v in data.items()} - - -class VisualRetrieverCollator: - """ - Collator for training vision retrieval models. - """ - - # Prefixes - query_prefix = "query_" - pos_doc_prefix = "doc_" - neg_doc_prefix = "neg_doc_" - - def __init__( - self, - processor: BaseVisualRetrieverProcessor, - max_length: int = 2048, - ): - self.processor = processor - self.max_length = max_length - self.image_token_id = None - - # If processor is one of the supported types, extract the token id. - if isinstance(self.processor, (ColPaliProcessor,)): - image_token = "" - try: - idx = self.processor.tokenizer.additional_special_tokens.index(image_token) - self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[idx] - except ValueError: - self.image_token_id = None - - # Force padding to be on the right for ColPaliProcessor. - if isinstance(self.processor, ColPaliProcessor) and self.processor.tokenizer.padding_side != "right": - print("Setting padding side to right") - self.processor.tokenizer.padding_side = "right" - - def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]: - queries: List[Union[None, str, Image]] = [] - pos_targets: List[Union[str, Image]] = [] - neg_targets: List[Union[str, Image]] = [] - selected_ids: List[int] = [] - - # Parse the examples. - positive_ids_tensor = -torch.ones((len(examples), 100), dtype=torch.long) - for i, example in enumerate(examples): - assert ColPaliEngineDataset.QUERY_KEY in example, f"Missing {ColPaliEngineDataset.QUERY_KEY} in example." - query = example[ColPaliEngineDataset.QUERY_KEY] - sampled_query = random.choice(query) if isinstance(query, list) else query - queries.append(sampled_query) - - assert ColPaliEngineDataset.POS_TARGET_KEY in example, ( - f"Missing {ColPaliEngineDataset.POS_TARGET_KEY} in example." - ) - pos_tgt = example[ColPaliEngineDataset.POS_TARGET_KEY] - positive_ids = example.get("positive_ids", None) - if isinstance(pos_tgt, list): - sample_tuple = random.choice([(t, id_) for t, id_ in zip(pos_tgt, positive_ids)]) - sample_pos = sample_tuple[0] - selected_ids.append(sample_tuple[1]) - else: - sample_pos = pos_tgt - pos_targets.append(sample_pos) - if positive_ids is not None: - positive_ids_tensor[i, :len(positive_ids)] = torch.tensor(positive_ids) - - neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None) - if neg_tgt is not None: - # sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt - # neg_targets.append(random.choice(neg_tgt)) #neg_tgts) - neg_targets.append(neg_tgt) - - # Ensure all queries are strings or images. - assert all(isinstance(q, str) for q in queries), ( - "All queries must be strings, this collator does not support images in queries." - ) - - # Process queries. - queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries] - batch_query = self.auto_collate(queries, key_prefix=self.query_prefix) - - # Process targets. - batch_pos_target = self.auto_collate(pos_targets, key_prefix=self.pos_doc_prefix) - batch_neg_target = self.auto_collate(neg_targets, key_prefix=self.neg_doc_prefix) if neg_targets else {} - - return { - **batch_query, - **batch_pos_target, - **batch_neg_target, - "selected_ids": torch.Tensor(selected_ids), - "positive_ids_tensor": positive_ids_tensor, - } - - def auto_collate(self, batch: List[Union[str, Image, List[str], List[Image]]], key_prefix: str = "") -> Dict[str, Any]: - """Automatically collate a batch of documents.""" - # Convert Document objects to their underlying data. - # if type is mixed across the batch, raise an error. - all_types = set(type(item) for item in batch) - if str in all_types and Image in all_types: - raise ValueError(f"Batch contains mixed types: {all_types}. Expected all items to be of the same type.") - if isinstance(batch[0], str): - proc_batch = self.processor.process_texts(texts=batch) - elif isinstance(batch[0], Image): - proc_batch = self.processor.process_images(images=batch) - elif isinstance(batch[0], list): - if isinstance(batch[0][0], str): - proc_texts_batch = [] - batch_size = len(batch) - all_texts = [text for texts in batch for text in texts] - num_negatives = len(all_texts) // batch_size - proc_batch = self.processor.process_texts(texts=all_texts) - elif isinstance(batch[0][0], Image): - proc_imgs_batch = [] - batch_size = len(batch) - all_imgs = [img for imgs in batch for img in imgs] - num_negatives = len(all_imgs) // batch_size - proc_batch = self.processor.process_images(images=all_imgs) - else: - raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.") - for k, v in proc_batch.items(): - if isinstance(v, torch.Tensor): - proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:]) - else: - proc_batch[k] = v - else: - raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.") - - return prefix_keys(proc_batch, key_prefix) \ No newline at end of file From 8c89c49aac3aecac4c0fd836fbe18bb1ca493ba0 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Mon, 20 Oct 2025 21:28:36 +0200 Subject: [PATCH 41/42] update other losses --- colpali_engine/loss/bi_encoder_losses.py | 8 ++++---- colpali_engine/loss/late_interaction_losses.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/colpali_engine/loss/bi_encoder_losses.py b/colpali_engine/loss/bi_encoder_losses.py index a5ea142f..274dfd02 100644 --- a/colpali_engine/loss/bi_encoder_losses.py +++ b/colpali_engine/loss/bi_encoder_losses.py @@ -332,16 +332,16 @@ def forward( Args: query_embeddings (Tensor[B, D]): Query vectors. doc_embeddings (Tensor[B, D]): Positive document vectors. - neg_doc_embeddings (Tensor[B, D]): Negative document vectors. + neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors. Returns: Tensor: Scalar loss value. """ # dot product for matching pairs only - pos = (query_embeddings * doc_embeddings).sum(dim=1) - neg = (query_embeddings * neg_doc_embeddings).sum(dim=1) + pos = (query_embeddings * doc_embeddings).sum(dim=1) # B + neg = (query_embeddings.unsqueeze(1) * neg_doc_embeddings).sum(dim=2) # B x N - loss = torch.nn.functional.softplus((neg - pos) / self.temperature).mean() + loss = torch.nn.functional.softplus((neg - pos.unsqueeze(1)) / self.temperature).mean() if self.in_batch_term_weight > 0: loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings) diff --git a/colpali_engine/loss/late_interaction_losses.py b/colpali_engine/loss/late_interaction_losses.py index 2219abf3..eb4d060c 100644 --- a/colpali_engine/loss/late_interaction_losses.py +++ b/colpali_engine/loss/late_interaction_losses.py @@ -374,23 +374,23 @@ def forward( Args: query_embeddings (Tensor): [B, Nq, D] doc_embeddings (Tensor): [B, Nd, D] positive docs - neg_doc_embeddings (Tensor): [B, Nneg, D] negative docs + neg_doc_embeddings (Tensor): [B, Nneg, Lneg, D] negative docs offset (int): Positional offset for positives. Returns: Tensor: Scalar loss value. """ lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1) - pos_raw = torch.einsum("bnd,bsd->bns", query_embeddings, doc_embeddings) - neg_raw = torch.einsum("bnd,bsd->bns", query_embeddings, neg_doc_embeddings) + pos_raw = torch.einsum("bnd,bld->bnl", query_embeddings, doc_embeddings) + neg_raw = torch.einsum("bnd,bsld->bsnl", query_embeddings, neg_doc_embeddings) # B x Nneg x Nq x Lneg pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1) - neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=2, dim_sum=1) + neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2) if self.normalize_scores: pos_scores = self._apply_normalization(pos_scores, lengths) neg_scores = self._apply_normalization(neg_scores, lengths) - loss = F.softplus((neg_scores - pos_scores) / self.temperature).mean() + loss = F.softplus((neg_scores - pos_scores.unsqueeze(1)) / self.temperature).mean() if self.in_batch_term_weight > 0: loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings, offset) From c6d4dd0a5d47b2872e61b0de378a4abffebe05ab Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Tue, 21 Oct 2025 11:06:07 +0200 Subject: [PATCH 42/42] correct tests to handle multiple neg --- tests/loss/test_bi_losses.py | 8 ++++---- tests/loss/test_li_losses.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/loss/test_bi_losses.py b/tests/loss/test_bi_losses.py index d4c36ad1..b96ebafd 100644 --- a/tests/loss/test_bi_losses.py +++ b/tests/loss/test_bi_losses.py @@ -110,20 +110,20 @@ def test_forward_with_filtering(self): class TestBiPairwiseNegativeCELoss: def test_forward_no_inbatch(self): loss_fn = BiPairwiseNegativeCELoss(temperature=1.0, in_batch_term_weight=0) - B, D = 5, 4 + B, Nneg, D = 5, 2, 4 query = torch.zeros(B, D) pos = torch.zeros(B, D) - neg = torch.zeros(B, D) + neg = torch.zeros(B, Nneg,D) loss = loss_fn(query, pos, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) def test_forward_with_inbatch(self): loss_fn = BiPairwiseNegativeCELoss(temperature=1.0, in_batch_term_weight=0.5) - B, D = 2, 3 + B, Nneg, D = 2, 3, 4 query = torch.zeros(B, D) pos = torch.zeros(B, D) - neg = torch.zeros(B, D) + neg = torch.zeros(B, Nneg, D) loss = loss_fn(query, pos, neg) # both explicit and in-batch pairwise yield ln(2), average remains ln(2) expected = F.softplus(torch.tensor(0.0)) diff --git a/tests/loss/test_li_losses.py b/tests/loss/test_li_losses.py index b363baaf..4b34f586 100644 --- a/tests/loss/test_li_losses.py +++ b/tests/loss/test_li_losses.py @@ -156,10 +156,10 @@ def test_no_inbatch(self): pos_aware_negative_filtering=False, in_batch_term_weight=0, ) - B, Nq, D = 2, 1, 3 - query = torch.zeros(B, Nq, D) - doc = torch.zeros(B, Nq, D) - neg = torch.zeros(B, 1, D) + B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1 + query = torch.zeros(B, Lq, D) + doc = torch.zeros(B, Lq, D) + neg = torch.zeros(B, Nneg, Lneg, D) loss = loss_fn(query, doc, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) @@ -172,10 +172,10 @@ def test_with_inbatch(self): pos_aware_negative_filtering=False, in_batch_term_weight=0.5, ) - B, Nq, D = 2, 1, 3 - query = torch.zeros(B, Nq, D) - doc = torch.zeros(B, Nq, D) - neg = torch.zeros(B, 1, D) + B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1 + query = torch.zeros(B, Lq, D) + doc = torch.zeros(B, Lq, D) + neg = torch.zeros(B, Nneg, Lneg, D) loss = loss_fn(query, doc, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected)