diff --git a/colpali_engine/__init__.py b/colpali_engine/__init__.py index 67b4dbc1..6934e80f 100644 --- a/colpali_engine/__init__.py +++ b/colpali_engine/__init__.py @@ -5,6 +5,8 @@ BiQwen2_5, BiQwen2_5_Processor, BiQwen2Processor, + BiModernVBert, + BiModernVBertProcessor, ColIdefics3, ColIdefics3Processor, ColPali, @@ -12,7 +14,9 @@ ColQwen2, ColQwen2_5, ColQwen2_5_Processor, - ColQwen2_5Omni, - ColQwen2_5OmniProcessor, + # ColQwen2_5Omni, + # ColQwen2_5OmniProcessor, ColQwen2Processor, + ColModernVBert, + ColModernVBertProcessor, ) diff --git a/colpali_engine/collators/visual_retriever_collator.py b/colpali_engine/collators/visual_retriever_collator.py index 21a2f222..bae3066e 100644 --- a/colpali_engine/collators/visual_retriever_collator.py +++ b/colpali_engine/collators/visual_retriever_collator.py @@ -1,5 +1,6 @@ import random from typing import Any, Dict, List, Union +import torch from PIL.Image import Image @@ -69,16 +70,18 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]: neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None) if neg_tgt is not None: - sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt - neg_targets.append(sampled_neg) + neg_targets.append(neg_tgt) # Ensure all queries are strings or images. assert all(isinstance(q, str) for q in queries), ( "All queries must be strings, this collator does not support images in queries." ) + is_str = isinstance(queries[0], str) + # Process queries. - queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries] + # queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries] + queries = [q + self.processor.query_augmentation_token * 10 for q in queries] if is_str else queries batch_query = self.auto_collate(queries, key_prefix=self.query_prefix) # Process targets. @@ -102,6 +105,26 @@ def auto_collate(self, batch: List[Union[str, Image]], key_prefix: str = "") -> proc_batch = self.processor.process_texts(texts=batch) elif isinstance(batch[0], Image): proc_batch = self.processor.process_images(images=batch) + elif isinstance(batch[0], list): + if isinstance(batch[0][0], str): + proc_texts_batch = [] + batch_size = len(batch) + all_texts = [text for texts in batch for text in texts] + num_negatives = len(all_texts) // batch_size + proc_batch = self.processor.process_texts(texts=all_texts) + elif isinstance(batch[0][0], Image): + proc_imgs_batch = [] + batch_size = len(batch) + all_imgs = [img for imgs in batch for img in imgs] + num_negatives = len(all_imgs) // batch_size + proc_batch = self.processor.process_images(images=all_imgs) + else: + raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.") + for k, v in proc_batch.items(): + if isinstance(v, torch.Tensor): + proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:]) + else: + proc_batch[k] = v else: raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.") return prefix_keys(proc_batch, key_prefix) diff --git a/colpali_engine/data/dataset.py b/colpali_engine/data/dataset.py index 311d7421..8eec842e 100644 --- a/colpali_engine/data/dataset.py +++ b/colpali_engine/data/dataset.py @@ -77,6 +77,7 @@ def __init__( query_column_name: str = "query", pos_target_column_name: str = "pos_target", neg_target_column_name: str = None, + num_negatives: int = 3, ): """ Initialize the dataset with the provided data and external document corpus. 
@@ -94,6 +95,7 @@ def __init__(
         self.pos_target_column_name = pos_target_column_name
         self.neg_target_column_name = neg_target_column_name
+        self.num_negatives = num_negatives
 
         assert isinstance(
             self.data,
             (list, Dataset, HFDataset),
@@ -131,8 +133,8 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
             pos_targets = [self.corpus.retrieve(doc_id) for doc_id in pos_targets]
 
         if neg_targets is not None:
             # to avoid oveflowing CPU memory
-            if len(neg_targets) > 5:
-                neg_targets = random.sample(neg_targets, 5)
+            if len(neg_targets) > self.num_negatives:
+                neg_targets = random.sample(neg_targets, self.num_negatives)
             neg_targets = [self.corpus.retrieve(doc_id) for doc_id in neg_targets]
 
         return {
diff --git a/colpali_engine/loss/__init__.py b/colpali_engine/loss/__init__.py
index db060015..0e3ecbc2 100644
--- a/colpali_engine/loss/__init__.py
+++ b/colpali_engine/loss/__init__.py
@@ -4,6 +4,7 @@
     BiNegativeCELoss,
     BiPairwiseCELoss,
     BiPairwiseNegativeCELoss,
+    BiSigmoidLoss,
 )
 from .late_interaction_losses import (
     ColbertLoss,
@@ -11,4 +12,5 @@
     ColbertNegativeCELoss,
     ColbertPairwiseCELoss,
     ColbertPairwiseNegativeCELoss,
+    ColbertSigmoidLoss,
 )
diff --git a/colpali_engine/loss/bi_encoder_losses.py b/colpali_engine/loss/bi_encoder_losses.py
index b8dbef34..274dfd02 100644
--- a/colpali_engine/loss/bi_encoder_losses.py
+++ b/colpali_engine/loss/bi_encoder_losses.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F  # noqa: N812
 from torch.nn import CrossEntropyLoss
 
 
@@ -111,6 +112,61 @@ def forward(
         return self.ce_loss(scores / self.temperature, pos_idx)
 
 
+class BiPairedEncoderLoss(BiEncoderModule):
+    """
+    Symmetric (query-to-doc and doc-to-query) InfoNCE loss for bi-encoders without explicit negatives.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
+        max_batch_size (int): Max batch size for index buffer caching.
+        filter_threshold (float): Threshold ratio for negative filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+        self.ce_loss = CrossEntropyLoss()
+
+    def forward(
+        self,
+        query_embeddings: torch.Tensor,
+        doc_embeddings: torch.Tensor,
+        offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute the symmetric InfoNCE loss over a batch of bi-encoder embeddings.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Document vectors.
+            offset (int): Offset for positive indices (multi-GPU).
+
+        Returns:
+            Tensor: Scalar cross-entropy loss.
+        """
+        # Compute in-batch similarity matrix
+        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
+        batch_size = scores.size(0)
+        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        q2t = self.ce_loss(scores / self.temperature, pos_idx)
+        # doc-to-query direction; targets mirror q2t (assumes a square in-batch score matrix)
+        t2q = self.ce_loss(scores.T / self.temperature, pos_idx)
+
+        return (q2t + t2q) / 2.0
+
 
 class BiNegativeCELoss(BiEncoderModule):
     """
@@ -161,17 +216,18 @@ def forward(
         Args:
             query_embeddings (Tensor[B, D]): Query vectors.
             doc_embeddings (Tensor[B, D]): Positive document vectors.
-            neg_doc_embeddings (Tensor[B, D]): Negative document vectors.
+            neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.
             offset (int): Offset for in-batch CE positives.
 
         Returns:
             Tensor: Scalar loss value.
         """
         # Dot-product only for matching pairs
-        pos_scores = (query_embeddings * doc_embeddings).sum(dim=1) / self.temperature
-        neg_scores = (query_embeddings * neg_doc_embeddings).sum(dim=1) / self.temperature
+        pos_scores = (query_embeddings * doc_embeddings[offset:offset + neg_doc_embeddings.size(0)]).sum(dim=1)
+        pos_scores /= self.temperature
+        neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / self.temperature
 
-        loss = torch.nn.functional.softplus(neg_scores - pos_scores).mean()
+        loss = F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean()
 
         if self.in_batch_term_weight > 0:
             loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset)
@@ -206,6 +262,7 @@ def forward(
         self,
         query_embeddings: torch.Tensor,
         doc_embeddings: torch.Tensor,
+        offset: int = 0,
     ) -> torch.Tensor:
         """
         Compute softplus(hardest_neg - pos) where hardest_neg is the highest off-diagonal score.
@@ -267,6 +324,7 @@ def forward(
         query_embeddings: torch.Tensor,
         doc_embeddings: torch.Tensor,
         neg_doc_embeddings: torch.Tensor,
+        offset: int = 0,
     ) -> torch.Tensor:
         """
         Compute softplus(neg-explicit - pos) plus optional pairwise in-batch loss.
@@ -274,19 +332,85 @@ def forward(
         Args:
             query_embeddings (Tensor[B, D]): Query vectors.
             doc_embeddings (Tensor[B, D]): Positive document vectors.
-            neg_doc_embeddings (Tensor[B, D]): Negative document vectors.
+            neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.
 
         Returns:
             Tensor: Scalar loss value.
         """
         # dot product for matching pairs only
-        pos = (query_embeddings * doc_embeddings).sum(dim=1)
-        neg = (query_embeddings * neg_doc_embeddings).sum(dim=1)
+        pos = (query_embeddings * doc_embeddings).sum(dim=1)  # B
+        neg = (query_embeddings.unsqueeze(1) * neg_doc_embeddings).sum(dim=2)  # B x N
 
-        loss = torch.nn.functional.softplus((neg - pos) / self.temperature).mean()
+        loss = torch.nn.functional.softplus((neg - pos.unsqueeze(1)) / self.temperature).mean()
 
         if self.in_batch_term_weight > 0:
             loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings)
             loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
 
         return loss
+
+class BiSigmoidLoss(BiEncoderModule):
+    """
+    Sigmoid loss for bi-encoders with in-batch negatives.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
+        max_batch_size (int): Max batch size for index buffer caching.
+        filter_threshold (float): Threshold ratio for negative filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+
+    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
+        """
+        Compute the sigmoid loss for a batch of bi-encoder embeddings.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Document vectors.
+            offset (int): Offset for positive indices (multi-GPU).
+
+        Returns:
+            Tensor: Scalar sigmoid loss.
+ """ + + # Compute in-batch similarity matrix + scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) + + batch_size, num_targets = scores.shape + device = scores.device + + _, pos_idx = self._get_idx(batch_size, offset, device) + + if self.pos_aware_negative_filtering: + self._filter_high_negatives(scores, pos_idx) + + all_losses = [] + for k in range(num_targets // batch_size): + # mask equal to 1 on offset -> offset + batch_size + curr_idx = torch.arange(offset, offset + batch_size, device=device) + # keep only the scores for the current batch + curr_scores = scores[:, curr_idx].view(-1) / self.temperature + # compute the labels + labels = -torch.ones(batch_size * batch_size, device=device) + if k == 0: + flat_pos = (pos_idx - offset) * (batch_size + 1) + labels[flat_pos] = 1.0 + # compute the loss + block_loss = F.softplus(curr_scores * labels) + all_losses.append(block_loss) + # shift the offset for the next batch + offset = (offset + batch_size) % num_targets + + return torch.stack(all_losses, dim=0).mean() diff --git a/colpali_engine/loss/late_interaction_losses.py b/colpali_engine/loss/late_interaction_losses.py index 95bcf6ff..eb4d060c 100644 --- a/colpali_engine/loss/late_interaction_losses.py +++ b/colpali_engine/loss/late_interaction_losses.py @@ -152,7 +152,6 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1) raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings) scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2) - if self.normalize_scores: scores = self._apply_normalization(scores, lengths) @@ -163,7 +162,6 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, self._filter_high_negatives(scores, pos_idx) # print(f"Scores shape: {scores.shape}, offset: {offset}") - return self.ce_loss(scores / self.temperature, pos_idx) @@ -226,25 +224,29 @@ def forward( Compute InfoNCE loss with explicit negatives and optional in-batch term. Args: - query_embeddings (Tensor): [B, Nq, D] - doc_embeddings (Tensor): [B, Nd, D] positive docs - neg_doc_embeddings (Tensor): [B, Nneg, D] negative docs + query_embeddings (Tensor): [B, Lq, D] + doc_embeddings (Tensor): [B, Ld, D] positive docs + neg_doc_embeddings (Tensor): [B, Nneg, Lneg, D] negative docs offset (int): Positional offset for in-batch CE. Returns: Tensor: Scalar loss. 
""" lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1) - pos_raw = torch.einsum("bnd,bsd->bns", query_embeddings, doc_embeddings) - neg_raw = torch.einsum("bnd,bsd->bns", query_embeddings, neg_doc_embeddings) + pos_raw = torch.einsum( + "bnd,bsd->bns", + query_embeddings, + doc_embeddings[offset:offset + neg_doc_embeddings.size(0)] + ) + neg_raw = torch.einsum("bnd,blsd->blns", query_embeddings, neg_doc_embeddings) pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1) - neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=2, dim_sum=1) + neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2) if self.normalize_scores: pos_scores = self._apply_normalization(pos_scores, lengths) neg_scores = self._apply_normalization(neg_scores, lengths) - loss = F.softplus((neg_scores - pos_scores) / self.temperature).mean() + loss = F.softplus((neg_scores - pos_scores.unsqueeze(1)) / self.temperature).mean() if self.in_batch_term_weight > 0: loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset) @@ -372,26 +374,93 @@ def forward( Args: query_embeddings (Tensor): [B, Nq, D] doc_embeddings (Tensor): [B, Nd, D] positive docs - neg_doc_embeddings (Tensor): [B, Nneg, D] negative docs + neg_doc_embeddings (Tensor): [B, Nneg, Lneg, D] negative docs offset (int): Positional offset for positives. Returns: Tensor: Scalar loss value. """ lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1) - pos_raw = torch.einsum("bnd,bsd->bns", query_embeddings, doc_embeddings) - neg_raw = torch.einsum("bnd,bsd->bns", query_embeddings, neg_doc_embeddings) + pos_raw = torch.einsum("bnd,bld->bnl", query_embeddings, doc_embeddings) + neg_raw = torch.einsum("bnd,bsld->bsnl", query_embeddings, neg_doc_embeddings) # B x Nneg x Nq x Lneg pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1) - neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=2, dim_sum=1) + neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2) if self.normalize_scores: pos_scores = self._apply_normalization(pos_scores, lengths) neg_scores = self._apply_normalization(neg_scores, lengths) - loss = F.softplus((neg_scores - pos_scores) / self.temperature).mean() + loss = F.softplus((neg_scores - pos_scores.unsqueeze(1)) / self.temperature).mean() if self.in_batch_term_weight > 0: loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings, offset) loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight return loss + + +class ColbertSigmoidLoss(ColbertModule): + """ + Sigmoid loss for ColBERT with explicit negatives. + + Args: + temperature (float): Scaling for logits. + normalize_scores (bool): Normalize scores by query lengths. + use_smooth_max (bool): Use log-sum-exp instead of amax. + pos_aware_negative_filtering (bool): Apply pos-aware negative filtering. 
+ """ + + def __init__( + self, + temperature: float = 0.02, + normalize_scores: bool = True, + use_smooth_max: bool = False, + pos_aware_negative_filtering: bool = False, + max_batch_size: int = 1024, + tau: float = 0.1, + norm_tol: float = 1e-3, + filter_threshold: float = 0.95, + filter_factor: float = 0.5, + ): + super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor) + self.temperature = temperature + self.normalize_scores = normalize_scores + self.use_smooth_max = use_smooth_max + self.pos_aware_negative_filtering = pos_aware_negative_filtering + self.ce_loss = CrossEntropyLoss() + + def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor: + """ + Compute sigmoid loss over positive and negative document pairs. + + Args: + query_embeddings (Tensor): [B, Nq, D] + doc_embeddings (Tensor): [B, Nd, D] positive docs + + Returns: + Tensor: Scalar loss value. + """ + + lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1) + raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings) + scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2) + + if self.normalize_scores: + scores = self._apply_normalization(scores, lengths) + + batch_size = scores.size(0) + idx, pos_idx = self._get_idx(batch_size, offset, scores.device) + + if self.pos_aware_negative_filtering: + self._filter_high_negatives(scores, pos_idx) + + # for each idx in pos_idx, the 2D index (idx, idx) → flat index = idx * B + idx + # build a 1-D mask of length B*B with ones at those positions + flat_pos = pos_idx * (batch_size + 1) + pos_mask = -torch.ones(batch_size * batch_size, device=scores.device) + pos_mask[flat_pos] = 1.0 + + # flatten the scores to [B * B] + scores = scores.view(-1) / self.temperature + + return F.softplus(scores * pos_mask).mean() diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index 0f0c8118..ae92178d 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -2,4 +2,4 @@ from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor -from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor +from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor \ No newline at end of file diff --git a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py index afed014b..497fe12c 100644 --- a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py +++ b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py @@ -18,6 +18,7 @@ class ColIdefics3Processor(BaseVisualRetrieverProcessor, Idefics3Processor): def __init__(self, *args, image_seq_len=64, **kwargs): super().__init__(*args, image_seq_len=image_seq_len, **kwargs) + self.tokenizer.padding_side = "left" def process_images( self, diff --git a/colpali_engine/models/modernvbert/__init__.py b/colpali_engine/models/modernvbert/__init__.py new file mode 100644 index 00000000..d6626781 --- /dev/null +++ b/colpali_engine/models/modernvbert/__init__.py @@ -0,0 +1,2 @@ +from .bivbert import BiModernVBert, BiModernVBertProcessor +from .colvbert import ColModernVBert, ColModernVBertProcessor diff --git a/colpali_engine/models/modernvbert/bivbert/__init__.py 
b/colpali_engine/models/modernvbert/bivbert/__init__.py new file mode 100644 index 00000000..e6098099 --- /dev/null +++ b/colpali_engine/models/modernvbert/bivbert/__init__.py @@ -0,0 +1,2 @@ +from .modeling_bimodernvbert import BiModernVBert +from .processing_bimodernvbert import BiModernVBertProcessor diff --git a/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py b/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py new file mode 100644 index 00000000..cf60cb1c --- /dev/null +++ b/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py @@ -0,0 +1,64 @@ +from typing import Literal + +import torch + +from colpali_engine.models.modernvbert.modeling_modernvbert import ModernVBertModel, ModernVBertPreTrainedModel + + +class BiModernVBert(ModernVBertPreTrainedModel): + """ + Initializes the BiModernVBert model. + + Args: + config : The model configuration. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, pooling_strategy = "mean", **kwargs): + super().__init__(config=config) + self.model = ModernVBertModel(config, **kwargs) + self.pooling_strategy = pooling_strategy + self.post_init() + + def forward( + self, + pooling_strategy: Literal["cls", "last", "mean"] = None, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass through model and pooling. + + Args: + - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean". + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. + + Returns: + - torch.Tensor: Embeddings of shape (batch_size, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + + pooling_strategy = pooling_strategy or self.pooling_strategy + + # Get CLS token embedding, last token, or mean pool over sequence + if pooling_strategy == "cls": + # Use CLS token (first token) embedding + pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size) + elif pooling_strategy == "last": + # Use last token + pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size) + elif pooling_strategy == "mean": + # Mean pooling over sequence length + mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1) + pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size) + else: + raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") + + # L2 normalization + pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True).clamp_min(1e-12) + return pooled_output diff --git a/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py b/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py new file mode 100644 index 00000000..80a961ac --- /dev/null +++ b/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py @@ -0,0 +1,42 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.modernvbert.colvbert import ColModernVBertProcessor # noqa: N801 + + +class BiModernVBertProcessor(ColModernVBertProcessor): # noqa: N801 + """ + Processor for BiVBert. + """ + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiModernVBert. + + Args: + texts: List of input texts. 
+ + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=4096 + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/modernvbert/colvbert/__init__.py b/colpali_engine/models/modernvbert/colvbert/__init__.py new file mode 100644 index 00000000..1b073552 --- /dev/null +++ b/colpali_engine/models/modernvbert/colvbert/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colmodernvbert import ColModernVBert +from .processing_colmodernvbert import ColModernVBertProcessor diff --git a/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py new file mode 100644 index 00000000..7db8bc9b --- /dev/null +++ b/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py @@ -0,0 +1,52 @@ +from torch import nn + +from colpali_engine.models.modernvbert.modeling_modernvbert import ModernVBertModel, ModernVBertPreTrainedModel + + +class ColModernVBert(ModernVBertPreTrainedModel): + """ + Initializes the ColModernVBert model. + + Args: + config : The model configuration. + mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings + except those of the image at inference. + Defaults to False --> Do not mask any embeddings during forward pass. + """ + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs): + super().__init__(config=config) + self.model = ModernVBertModel(config, **kwargs) + self.dim = 128 + self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim) + self.mask_non_image_embeddings = mask_non_image_embeddings + self.main_input_name = "doc_input_ids" + + def forward(self, *args, **kwargs): + """ + Forward pass through the model and the linear layer for dimensionality reduction + + Args: + - input_ids (torch.LongTensor): The input tokens tensor. + - attention_mask (torch.LongTensor): The attention mask tensor. 
+ + Returns: + - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim) + """ + outputs = self.model(*args, **kwargs) + last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size) + proj = self.custom_text_proj(last_hidden_states) + # normalize l2 norm + # proj = torch.where(kwargs["attention_mask"].unsqueeze(-1).bool(), proj / proj.norm(dim=-1, keepdim=True), torch.zeros_like(proj)) + proj = proj / proj.norm(dim=-1, keepdim=True).clamp_min(1e-12) + proj = proj * kwargs["attention_mask"].unsqueeze(-1) + + if "pixel_values" in kwargs and self.mask_non_image_embeddings: + # Pools only the image embeddings + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + return proj diff --git a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py new file mode 100644 index 00000000..f9e5515c --- /dev/null +++ b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py @@ -0,0 +1,84 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature, Idefics3Processor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class ColModernVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): + """ + Processor for ColIdefics3. + """ + + query_augmentation_token: ClassVar[str] = "" + image_token: ClassVar[str] = "" + visual_prompt_prefix: ClassVar[str] = "<|begin_of_text|>User:Describe the image.\nAssistant:" + + def __init__(self, *args, image_seq_len=64, **kwargs): + super().__init__(*args, image_seq_len=image_seq_len, **kwargs) + self.tokenizer.padding_side = "left" + + # @property + # def image_token_id(self) -> int: + # return self.tokenizer.convert_tokens_to_ids(self.image_token) + + def process_images( + self, + images: List[Image.Image], + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process images for ColModernVBert. + + Args: + images: List of PIL images. + """ + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=[self.visual_prompt_prefix] * len(images), + images=images, + padding="longest", + return_tensors="pt", + truncation=True, + max_length=8192, + ) + return batch_doc + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for ColModernVBert. + + Args: + texts: List of input texts. + + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=4096, + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. 
+ """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + patch_size: int, + ) -> Tuple[int, int]: + raise NotImplementedError("This method is not implemented for ColIdefics3.") diff --git a/colpali_engine/models/modernvbert/configuration_modernvbert.py b/colpali_engine/models/modernvbert/configuration_modernvbert.py new file mode 100644 index 00000000..d5225f8b --- /dev/null +++ b/colpali_engine/models/modernvbert/configuration_modernvbert.py @@ -0,0 +1,273 @@ +import copy +import os +from typing import Any, Dict, Union + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m" +DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512" + +def collect_arg_in_candidates(config, candidates, default=None) -> Any: + """Gets the first available argument in a config given a list of candidate names.""" + for c in candidates: + if hasattr(config, c): + return getattr(config, c) + elif c in config: + return config[c] + if default is not None: + return default + raise ValueError( + f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}" + ) + +class ModernVBertTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBERT`]. It is used to instantiate an ModernBERT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + """ + model_type = "modernvbert_text" + + def __init__( + self, + text_model_name=DEFAULT_TEXT_MODEL_NAME, + hidden_size=768, + num_hidden_layers=22, + intermediate_size=1152, + mlp_bias=False, + vocab_size=50368, + **kwargs, + ): + super().__init__( + text_model_name=text_model_name, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + mlp_bias=mlp_bias, + vocab_size=vocab_size, + **kwargs, + ) + + @classmethod + def from_base_model( + cls, + text_model_name=DEFAULT_TEXT_MODEL_NAME, + **kwargs, + ): + text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True) + if hasattr(text_config, "text_config"): + text_config = text_config.text_config + + hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"]) + num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"]) + intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"]) + mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default=False) + vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"]) + + return cls( + text_model_name=text_model_name, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + mlp_bias=mlp_bias, + vocab_size=vocab_size, + **kwargs, + ) + +class ModernVBertVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SigLIP`]. 
It is used to instantiate the vision encoder part of the ModernVBERT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the SigLIP. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + """ + model_type = "modernvbert_vision" + + attribute_map = { + "hidden_size": "embed_dim", + } + + def __init__( + self, + vision_model_name=DEFAULT_VISION_MODEL_NAME, + embed_dim=768, + image_size=512, + patch_size=16, + num_hidden_layers=12, + intermediate_size=3072, + **kwargs, + ): + super().__init__( + vision_model_name=vision_model_name, + embed_dim=embed_dim, + image_size=image_size, + patch_size=patch_size, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + **kwargs, + ) + + @classmethod + def from_base_model( + cls, + vision_model_name=DEFAULT_VISION_MODEL_NAME, + **kwargs, + ): + vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True) + if hasattr(vision_config, "vision_config"): + vision_config = vision_config.vision_config + + embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"]) + image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"]) + patch_size = collect_arg_in_candidates(vision_config, ["patch_size"]) + num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"]) + intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"]) + + return cls( + vision_model_name=vision_model_name, + embed_dim=embed_dim, + image_size=image_size, + patch_size=patch_size, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + **kwargs, + ) + + +class ModernVBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a `ModernVBert` model. It is used to + instantiate a ModernVBert model according to the specified arguments and defines the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. + See the documentation for [`PretrainedConfig`] for more details. + + Args: + text_config (`PretrainedConfig` or `dict`, optional): + Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the + default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used. + vision_config (`PretrainedConfig` or `dict`, optional): + Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the + default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used. + image_token_id (`int`, optional, defaults to 128257): + Token id reserved for image tokens inserted into the text stream. + vocab_size (`int`, optional, defaults to 128256): + Vocabulary size used by the text embeddings. + use_cache (`bool`, optional, defaults to `True`): + Whether to cache key/value tensors for attention (relevant for decoder architectures). + tie_word_embeddings (`bool`, optional, defaults to `False`): + Whether to tie input token embeddings and output token embeddings. + pixel_shuffle_factor (`int`, optional, defaults to 4): + Scale factor used by any pixel-shuffle / upsampling operations in the vision head. 
+ additional_vocab_size (`int`, optional, defaults to 0): + Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens). + pad_token_id (`int`, optional): + Padding token id. + initializer_range (`float`, optional, defaults to 0.02): + Stddev used for weight initialization. + freeze_config (`Any`, optional): + Optional config describing which submodules to freeze during training. + use_resampler (`bool`, optional, defaults to `False`): + Whether to enable an additional resampler on visual features. + neftune_noise_alpha (`float`, optional, defaults to 0.0): + Alpha parameter for neftune noise injection. + + Example: + ```python + >>> from modernvbert import ModernVBertConfig + >>> # Initializing configuration + >>> configuration = ModernVBertConfig() + >>> # Initializing a model from the configuration (model class is implemented in + >>> # `modernvbert.modeling_modernvbert`) + >>> # from modernvbert import ModernVBertModel + >>> # model = ModernVBertModel(configuration) + >>> # Accessing the model configuration + >>> # cfg = model.config + ```""" + + model_type = "modernvbert" + is_composition = True + + def __init__( + self, + text_config: Union[PretrainedConfig, Dict[str, Any]] = None, + vision_config: Union[PretrainedConfig, Dict[str, Any]] = None, + image_token_id: int = 50407, + vocab_size=50368, + use_cache=True, + tie_word_embeddings=False, + freeze_config=None, + pad_token_id=None, + initializer_range=0.02, + pixel_shuffle_factor=4, + use_resampler=False, + additional_vocab_size=0, + neftune_noise_alpha=0.0, + **kwargs, + ): + self.image_token_id = image_token_id + self.use_cache = use_cache + self.tie_word_embeddings = tie_word_embeddings + self.scale_factor = pixel_shuffle_factor + self.additional_vocab_size = additional_vocab_size + + if text_config is None: + base_text_config = AutoConfig.from_pretrained(DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True) + text_config = ModernVBertTextConfig(base_text_config) + elif isinstance(text_config, dict): + text_config = ModernVBertTextConfig.from_dict(text_config) + self.text_config = text_config + + if vision_config is None: + base_vision_config = AutoConfig.from_pretrained(DEFAULT_VISION_MODEL_NAME, trust_remote_code=True) + vision_config = ModernVBertVisionConfig(base_vision_config) + elif isinstance(vision_config, dict): + vision_config = ModernVBertVisionConfig.from_dict(vision_config) + self.vision_config = vision_config + + self.freeze_config = freeze_config + self.pixel_shuffle_factor = pixel_shuffle_factor + self.use_resampler = use_resampler + self.neftune_noise_alpha = neftune_noise_alpha + self.initializer_range = initializer_range + + hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size) + + super().__init__( + **kwargs, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + vocab_size=vocab_size, + hidden_size=hidden_size, + ) + + def to_dict(self): + output = copy.deepcopy(self.__dict__) + output["model_type"] = self.__class__.model_type + output["vision_config"] = self.vision_config.to_dict() + output["text_config"] = self.text_config.to_dict() + return output + + @classmethod + def from_pretrained_models( + cls, + text_model_name: Union[str, os.PathLike], + vision_model_name: Union[str, os.PathLike], + **kwargs, + ) -> "PretrainedConfig": + text_model_config = ModernVBertTextConfig.from_base_model(text_model_name) + vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name) + return cls( + text_config=text_model_config, + 
vision_config=vision_model_config, + **kwargs, + ) \ No newline at end of file diff --git a/colpali_engine/models/modernvbert/modeling_modernvbert.py b/colpali_engine/models/modernvbert/modeling_modernvbert.py new file mode 100644 index 00000000..2dc468ba --- /dev/null +++ b/colpali_engine/models/modernvbert/modeling_modernvbert.py @@ -0,0 +1,456 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging +from transformers.modeling_outputs import BaseModelOutput +from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput + +from .configuration_modernvbert import ModernVBertConfig + +logger = logging.get_logger(__name__) + + +class DecoupledEmbedding(nn.Embedding): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. + In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained. + If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze=False, + device=None, + dtype=None, + padding_idx=None, + **kwargs, + ) -> None: + """ + num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`. + partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen. + + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these. + """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}") + + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.weight.requires_grad_(False) + + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), + since the 2nd embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. 
now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do + the padding, but then we have to create a new tensor and populate it with 2 tensors that are + spread out across various indices - i.e. not a simple concat - I haven't benchmarked the + complex case if it's any faster, given that seqlens are usually relatively short it's + probably not faster or if faster not by much - but might be a good idea to measure. + + """ + if self.num_additional_embeddings == 0: + return super().forward(input_ids) + + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + full_vector[additional_vocab_indices] = additional_embeddings # overwrite the records with high indices + return full_vector + + +@dataclass +class ModernVBertBaseModelOutput(BaseModelOutput): + """ + Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class ModernVBertMaskedLMOutput(MaskedLMOutput): + """ + Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class ModernVBertSimpleMLP(nn.Module): + """A simple linear projection layer to project the vision hidden states to the text hidden states.""" + def __init__(self, input_size, output_size): + super().__init__() + self.proj = nn.Linear(input_size, output_size, bias=False) + + def forward(self, x): + return self.proj(x) + + +class ModernVBertConnector(nn.Module): + """ + Connector module for ModernVBERT. It performs a pixel shuffle operation followed by a linear projection to match the text model's hidden size. 
+ Based on https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html + """ + def __init__(self, config): + super().__init__() + self.scale_factor = config.pixel_shuffle_factor + self.modality_projection = ModernVBertSimpleMLP( + input_size=config.vision_config.hidden_size * (config.scale_factor**2), + output_size=config.text_config.hidden_size, + ) + + def pixel_shuffle(self, x, scale_factor): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + return x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + return self.modality_projection(image_hidden_states) + + +class ModernVBertPreTrainedModel(PreTrainedModel): + config_class = ModernVBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", 0.02) + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class ModernVBertModel(ModernVBertPreTrainedModel): + def __init__(self, config: ModernVBertConfig): + super().__init__(config) + self.vision_model = ModernVBertModel.init_vision_model(config) + self.connector = ModernVBertConnector(config) + self.text_model = ModernVBertModel.init_language_model(config) + self.image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) + ) + self.image_token_id = config.image_token_id + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + # set the correct dtype for vision and text models + self.vision_model.to(self.dtype) + self.text_model.to(self.dtype) + self.post_init() + + @staticmethod + def init_vision_model(config: ModernVBertConfig): + vision_model_config = AutoConfig.from_pretrained( + config.vision_config.vision_model_name, + _attn_implementation=config._attn_implementation, + ) + vision_model = AutoModel.from_config( + vision_model_config, + trust_remote_code=True, + ) + return getattr(vision_model, "vision_model", vision_model) + + @staticmethod + def init_language_model(config: ModernVBertConfig): + text_model_config = AutoConfig.from_pretrained( + config.text_config.text_model_name, + _attn_implementation=config._attn_implementation, + trust_remote_code=True, + ) + text_model = AutoModel.from_config( + text_model_config, + trust_remote_code=True + ) + embed_layer = DecoupledEmbedding( + num_embeddings=text_model_config.vocab_size, + num_additional_embeddings=config.additional_vocab_size, + embedding_dim=config.hidden_size, + partially_freeze=config.freeze_config["freeze_text_layers"], + padding_idx=config.pad_token_id, + ) + text_model.set_input_embeddings(embed_layer) + return text_model + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. 
+ + This is useful for lora when using gradient checkpointing. + c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 + + Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. + """ + + def get_lowest_module(module): + if len(list(module.children())) == 0: + # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) + return module + else: + # Recursively call the function on each child module + return get_lowest_module(list(module.children())[0]) + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( + make_inputs_require_grads + ) + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def inputs_merger(self, input_ids, inputs_embeds, image_hidden_states): + """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py + + This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM. + The merging happens as follows: + - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`. + - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space. + We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer. + - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. + - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. 
+ """ + + _, patch_size, _ = image_hidden_states.shape + image_mask = input_ids == self.image_token_id + num_image_tokens = image_mask.sum(dim=1) + if not torch.all(num_image_tokens % patch_size == 0): + raise ValueError("Number of tokens not divisible by patch_size.") + blocks_per_sample = num_image_tokens // patch_size + offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) + block_offset = offsets[:-1] + row_cum = image_mask.cumsum(dim=-1) + chunk_idx = (row_cum - 1) // patch_size + local_idx = (row_cum - 1) % patch_size + block_idx = block_offset.unsqueeze(1) + chunk_idx + image_embeds = torch.zeros_like(inputs_embeds) + image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] + return torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if inputs_embeds is None: + inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) + if pixel_values is not None: + batch_size, num_images, _, _, _ = pixel_values.shape + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + if not any(real_images_inds): + real_images_inds[0] = True + pixel_values = pixel_values[real_images_inds].contiguous() + image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state + image_hidden_states = self.connector(image_hidden_states) + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + if inputs_embeds is not None and image_hidden_states is not None: + inputs_embeds = self.inputs_merger(input_ids, inputs_embeds, image_hidden_states) + outputs = self.text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + return ModernVBertBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_hidden_states, + ) + +class ModernVBertLMHead(nn.Module): + def __init__(self, config): + super().__init__() + pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True) + pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, 
trust_remote_code=True) + self.head = pretrained_model.head + self.decoder = pretrained_model.decoder + + def forward(self, hidden_states): + return self.decoder(self.head(hidden_states)) + + +class ModernVBertForMaskedLM(ModernVBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.image_token_id = config.image_token_id + self.in_features = config.hidden_size + self.out_additional_features = config.additional_vocab_size + self.vocab_size = config.vocab_size + self.model = ModernVBertModel(config) + self.lm_head = ModernVBertLMHead(config) + if self.out_additional_features > 0: + self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False) + self.lm_head.to(self.dtype) + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, ModernVBertMaskedLMOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + if self.out_additional_features > 0: + proj_states = self.lm_head.head(hidden_states) + additional_features = self.additional_fc(proj_states) + logits = torch.cat((logits, additional_features), -1) + loss = None + if labels is not None: + loss = CrossEntropyLoss()(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + return ModernVBertMaskedLMOutput( + loss=loss, + logits=logits.float(), + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) \ No newline at end of file diff --git a/colpali_engine/models/qwen_omni/__init__.py b/colpali_engine/models/qwen_omni/__init__.py deleted file mode 100644 index 7dd08129..00000000 --- a/colpali_engine/models/qwen_omni/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .colqwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor diff --git a/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py b/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py deleted file mode 100644 index b754b552..00000000 --- a/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .modeling_colqwen_omni import ColQwen2_5Omni -from .processing_colqwen_omni import ColQwen2_5OmniProcessor diff --git 
a/colpali_engine/trainer/contrastive_trainer.py b/colpali_engine/trainer/contrastive_trainer.py
index 78514eba..ccd699c3 100644
--- a/colpali_engine/trainer/contrastive_trainer.py
+++ b/colpali_engine/trainer/contrastive_trainer.py
@@ -18,31 +18,41 @@ def concat_all_gather(t: torch.Tensor) -> torch.Tensor:
     return t
 
 
+def concat_datasets(datasets: list[Dataset], batch_size: int) -> Dataset:
+    """
+    Concatenate a list of datasets into a single dataset.
+    Each dataset is first truncated to a multiple of `batch_size` so that every batch can be drawn from a single dataset.
+    """
+    # Round each dataset down if its length is not divisible by the global batch size
+    for i in range(len(datasets)):
+        if len(datasets[i]) % batch_size != 0:
+            total_samples = (len(datasets[i]) // batch_size) * batch_size
+            datasets[i] = datasets[i].take(total_samples)
+
+    return ConcatDataset(datasets)
+
+
 class ContrastiveTrainer(Trainer):
-    def __init__(self, loss_func, is_vision_model, *args, **kwargs):
-        if isinstance(kwargs["train_dataset"], DatasetDict):
-            dataset_list = list(kwargs["train_dataset"].values())
-        elif isinstance(kwargs["train_dataset"], list):
-            dataset_list = kwargs["train_dataset"]
+    def __init__(self, loss_func, is_vision_model, compute_symetric_loss=False, *args, **kwargs):
+        if isinstance(kwargs["train_dataset"], list):
+            train_dataset_list = kwargs["train_dataset"]
+            kwargs["train_dataset"] = concat_datasets(train_dataset_list, batch_size=kwargs["args"].train_batch_size)
         else:
-            dataset_list = None
+            train_dataset_list = None
 
-        if isinstance(dataset_list, list):
-            # round down each dataset if not divible by global batch size
-            batch_size = kwargs["args"].train_batch_size
-            for i in range(len(dataset_list)):
-                if len(dataset_list[i]) % batch_size != 0:
-                    total_samples = (len(dataset_list[i]) // batch_size) * batch_size
-                    dataset_list[i] = dataset_list[i].take(total_samples)
-
-        if dataset_list is not None:
-            kwargs["train_dataset"] = ConcatDataset(dataset_list)
+        if isinstance(kwargs.get("eval_dataset"), list):
+            eval_dataset_list = kwargs["eval_dataset"]
+            kwargs["eval_dataset"] = concat_datasets(eval_dataset_list, batch_size=kwargs["args"].eval_batch_size)
+        else:
+            eval_dataset_list = None
 
         super().__init__(*args, **kwargs)
         self.loss_func = loss_func
         self.is_vision_model = is_vision_model  # Unused argument, will be removed in 0.4.0
         self.args.remove_unused_columns = False  # Safety, don't remove dataset columns from dataloader
-        self.dataset_list = dataset_list
+        self.train_dataset_list = train_dataset_list
+        self.eval_dataset_list = eval_dataset_list
+        self.compute_symetric_loss = compute_symetric_loss
 
     def get_train_dataloader(self) -> DataLoader:
         """
@@ -55,6 +65,10 @@ def get_train_dataloader(self) -> DataLoader:
         """
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
+
+        if self.train_dataset_list is None:
+            # If no dataset list, use the default behavior
+            return super().get_train_dataloader()
 
         dataset = self.train_dataset
         description = "Training"
@@ -63,9 +77,6 @@ def get_train_dataloader(self) -> DataLoader:
         is_training = True
         dataloader_key = None
 
-        if self.dataset_list is None:
-            return super()._get_dataloader(dataset, description, batch_size, sampler_fn, is_training, dataloader_key)
-
         data_collator = self.data_collator
         if is_datasets_available() and isinstance(dataset, datasets.Dataset):
             dataset = self._remove_unused_columns(dataset, description=description)
@@ -83,7 +94,7 @@ def get_train_dataloader(self) -> DataLoader:
         if not isinstance(dataset, torch.utils.data.IterableDataset):
             if sampler_fn is not None:
                 ###### batch_sampler set instead
of sampler in trainer code ####### - dataloader_params["batch_sampler"] = sampler_fn(dataset) + dataloader_params["batch_sampler"] = sampler_fn() dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor if is_training: @@ -103,9 +114,9 @@ def get_train_dataloader(self) -> DataLoader: return self.accelerator.prepare(dataloader) - def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]: - if self.dataset_list is None: - return super()._get_train_sampler(train_dataset=train_dataset) + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset_list is None: + return super()._get_train_sampler() # Use SingleDatasetBatchSampler to ensure that each dataset in the list is sampled independently # Note: Surely breaks in distributed training @@ -113,33 +124,84 @@ def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optiona generator = torch.Generator() generator.manual_seed(self.args.seed) return SingleDatasetBatchSampler( - self.dataset_list, + self.train_dataset_list, self.args.train_batch_size, drop_last=self.args.dataloader_drop_last, generator=generator, ) + + def _compute_loss_from_outputs( + self, + query_outputs, + pos_target_outputs, + neg_target_outputs=None, + ): + offset = 0 + batch_size = query_outputs.size(0) + if self.accelerator.num_processes > 1 and self.accelerator.sync_gradients: + # gather docs across all processes + pos_target_outputs = self.accelerator.pad_across_processes(pos_target_outputs, dim=1, pad_index=0, pad_first=True) + pos_target_outputs = concat_all_gather(pos_target_outputs) + rank = self.accelerator.process_index + offset = rank * batch_size + + if neg_target_outputs is not None: + loss = self.loss_func( + query_embeddings=query_outputs, + doc_embeddings=pos_target_outputs, + neg_doc_embeddings=neg_target_outputs, + offset=offset + ) + else: + loss = self.loss_func( + query_embeddings=query_outputs, + doc_embeddings=pos_target_outputs, + offset=offset + ) + + return loss + + def _reshape_neg_doc_inputs(self, inputs): + """ + Helper function to reshape negative doc inputs to (batch_size * num_neg_docs, ...) + """ + neg_doc_inputs = {k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")} + + for k in neg_doc_inputs: + # go from (batch_size, num_neg_docs, ...) to (batch_size * num_neg_docs, ...) + neg_doc_inputs[k] = neg_doc_inputs[k].view(-1, *neg_doc_inputs[k].shape[2:]) + + return neg_doc_inputs + + def _reshape_neg_doc_outputs(self, neg_doc_outputs, num_neg_docs): + """ + Helper function to reshape negative doc outputs to (batch_size, num_neg_docs, ...) 
+ """ + neg_doc_outputs = neg_doc_outputs.view(-1, num_neg_docs, *neg_doc_outputs.shape[1:]) + + return neg_doc_outputs def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]) + query_outputs = model(**{k[6:]: v for k, v in inputs.items() if k.startswith("query")}) # feed only kwargs with 'doc_' prefix doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")}) if "neg_doc_input_ids" in inputs: # Negative docs are not gathered across processes, so we can use them without offset - neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")}) - loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs) - return (loss, (query_outputs, doc_outputs, neg_doc_outputs)) if return_outputs else loss - - offset = 0 - if self.accelerator.num_processes > 1 and self.accelerator.sync_gradients: - # gather docs across all processes - if num_items_in_batch is None: - num_items_in_batch = inputs["doc_input_ids"].shape[0] - doc_outputs = self.accelerator.pad_across_processes(doc_outputs, dim=1, pad_index=0, pad_first=True) - doc_outputs = concat_all_gather(doc_outputs) - rank = self.accelerator.process_index - offset = rank * num_items_in_batch - - loss = self.loss_func(query_outputs, doc_outputs, offset=offset) + num_negs = inputs["neg_doc_input_ids"].size(1) + neg_doc_inputs = self._reshape_neg_doc_inputs(inputs) + neg_doc_outputs = model(**neg_doc_inputs) + neg_doc_outputs = self._reshape_neg_doc_outputs(neg_doc_outputs, num_negs) + else: + neg_doc_outputs = None + + # query -> doc loss + loss = self._compute_loss_from_outputs(query_outputs, doc_outputs, neg_doc_outputs) + + if self.compute_symetric_loss: + assert neg_doc_outputs is None, "Symmetric loss is not compatible with negative documents." + # doc -> query loss + sym_loss = self._compute_loss_from_outputs(doc_outputs, query_outputs) + loss = (loss + sym_loss) / 2 return (loss, (query_outputs, doc_outputs)) if return_outputs else loss diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index 537b21a2..4c1d9617 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -22,7 +22,7 @@ class BaseVisualRetrieverProcessor(ABC): Base class for visual retriever processors. """ - query_prefix: ClassVar[str] = "Query: " # Default prefix for queries. Override in subclasses if needed. + query_prefix: ClassVar[str] = "" # Default prefix for queries. Override in subclasses if needed. 
@abstractmethod def process_images( @@ -56,6 +56,7 @@ def process_queries( texts: Optional[List[str]] = None, queries: Optional[List[str]] = None, max_length: int = 50, + contexts: Optional[List[str]] = None, suffix: Optional[str] = None, ) -> Union[BatchFeature, BatchEncoding]: """ @@ -109,14 +110,17 @@ def score_single_vector( """ device = device or get_torch_device("auto") - if len(qs) == 0: - raise ValueError("No queries provided") - if len(ps) == 0: - raise ValueError("No passages provided") + if isinstance(qs, list) and isinstance(ps, list): + if len(qs) == 0: + raise ValueError("No queries provided") + if len(ps) == 0: + raise ValueError("No passages provided") - if isinstance(qs, list): qs = torch.stack(qs).to(device) ps = torch.stack(ps).to(device) + else: + qs = qs.to(device) + ps = ps.to(device) scores = torch.einsum("bd,cd->bc", qs, ps) assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" diff --git a/pyproject.toml b/pyproject.toml index 74b5a372..e38ed4b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,9 @@ dependencies = [ "pillow>=10.0.0", "requests", "scipy", - "torch>=2.5.0,<2.8.0", + "torch>=2.2.0,<2.8.0", "torchvision", - "transformers>=4.53.1,<4.54.0", + "transformers>=4.53.1,<4.54.0" ] [project.optional-dependencies] diff --git a/tests/loss/test_bi_losses.py b/tests/loss/test_bi_losses.py index ea4cf5c3..b96ebafd 100644 --- a/tests/loss/test_bi_losses.py +++ b/tests/loss/test_bi_losses.py @@ -64,10 +64,10 @@ def test_forward_with_filtering(self): class TestBiNegativeCELoss: def test_forward_no_inbatch(self): loss_fn = BiNegativeCELoss(temperature=1.0, in_batch_term_weight=0, pos_aware_negative_filtering=False) - B, D = 3, 4 + B, D, Nneg = 3, 4, 1 query = torch.zeros(B, D) pos = torch.zeros(B, D) - neg = torch.zeros(B, D) + neg = torch.zeros(B, Nneg, D) loss = loss_fn(query, pos, neg) # softplus(0 - 0) = ln(2) expected = F.softplus(torch.tensor(0.0)) @@ -75,10 +75,10 @@ def test_forward_no_inbatch(self): def test_forward_with_inbatch(self): loss_fn = BiNegativeCELoss(temperature=1.0, in_batch_term_weight=0.5, pos_aware_negative_filtering=False) - B, D = 2, 3 + B, D, Nneg = 2, 3, 1 query = torch.zeros(B, D) pos = torch.zeros(B, D) - neg = torch.zeros(B, D) + neg = torch.zeros(B, Nneg, D) loss = loss_fn(query, pos, neg) # in-batch CE on zeros: log(B) ce = torch.log(torch.tensor(float(B))) @@ -110,20 +110,20 @@ def test_forward_with_filtering(self): class TestBiPairwiseNegativeCELoss: def test_forward_no_inbatch(self): loss_fn = BiPairwiseNegativeCELoss(temperature=1.0, in_batch_term_weight=0) - B, D = 5, 4 + B, Nneg, D = 5, 2, 4 query = torch.zeros(B, D) pos = torch.zeros(B, D) - neg = torch.zeros(B, D) + neg = torch.zeros(B, Nneg,D) loss = loss_fn(query, pos, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) def test_forward_with_inbatch(self): loss_fn = BiPairwiseNegativeCELoss(temperature=1.0, in_batch_term_weight=0.5) - B, D = 2, 3 + B, Nneg, D = 2, 3, 4 query = torch.zeros(B, D) pos = torch.zeros(B, D) - neg = torch.zeros(B, D) + neg = torch.zeros(B, Nneg, D) loss = loss_fn(query, pos, neg) # both explicit and in-batch pairwise yield ln(2), average remains ln(2) expected = F.softplus(torch.tensor(0.0)) diff --git a/tests/loss/test_li_losses.py b/tests/loss/test_li_losses.py index 77faf0f1..4b34f586 100644 --- a/tests/loss/test_li_losses.py +++ b/tests/loss/test_li_losses.py @@ -109,10 +109,10 @@ def test_no_inbatch(self): pos_aware_negative_filtering=False, in_batch_term_weight=0, ) - 
B, Nq, D, Nneg = 2, 1, 3, 1 - query = torch.zeros(B, Nq, D) - doc = torch.zeros(B, Nq, D) - neg = torch.zeros(B, Nneg, D) + B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1 + query = torch.zeros(B, Lq, D) + doc = torch.zeros(B, Lq, D) + neg = torch.zeros(B, Nneg, Lneg, D) loss = loss_fn(query, doc, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) @@ -125,10 +125,10 @@ def test_with_inbatch(self): pos_aware_negative_filtering=False, in_batch_term_weight=0.5, ) - B, Nq, D = 2, 1, 3 - query = torch.zeros(B, Nq, D) - doc = torch.zeros(B, Nq, D) - neg = torch.zeros(B, 1, D) + B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1 + query = torch.zeros(B, Lq, D) + doc = torch.zeros(B, Lq, D) + neg = torch.zeros(B, Nneg, Lneg, D) loss = loss_fn(query, doc, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) @@ -156,10 +156,10 @@ def test_no_inbatch(self): pos_aware_negative_filtering=False, in_batch_term_weight=0, ) - B, Nq, D = 2, 1, 3 - query = torch.zeros(B, Nq, D) - doc = torch.zeros(B, Nq, D) - neg = torch.zeros(B, 1, D) + B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1 + query = torch.zeros(B, Lq, D) + doc = torch.zeros(B, Lq, D) + neg = torch.zeros(B, Nneg, Lneg, D) loss = loss_fn(query, doc, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) @@ -172,10 +172,10 @@ def test_with_inbatch(self): pos_aware_negative_filtering=False, in_batch_term_weight=0.5, ) - B, Nq, D = 2, 1, 3 - query = torch.zeros(B, Nq, D) - doc = torch.zeros(B, Nq, D) - neg = torch.zeros(B, 1, D) + B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1 + query = torch.zeros(B, Lq, D) + doc = torch.zeros(B, Lq, D) + neg = torch.zeros(B, Nneg, Lneg, D) loss = loss_fn(query, doc, neg) expected = F.softplus(torch.tensor(0.0)) assert torch.allclose(loss, expected) diff --git a/tests/models/modernvbert/test_modeling_colmodernvbert.py b/tests/models/modernvbert/test_modeling_colmodernvbert.py new file mode 100644 index 00000000..098f8223 --- /dev/null +++ b/tests/models/modernvbert/test_modeling_colmodernvbert.py @@ -0,0 +1,148 @@ +import logging +from typing import Generator, cast + +import pytest +import torch +from datasets import load_dataset +from PIL import Image +from transformers.utils.import_utils import is_flash_attn_2_available + +from colpali_engine.models import ColModernVBert, ColModernVBertProcessor +from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "ModernVBERT/colmodernvbert" + + +@pytest.fixture(scope="module") +def model_without_mask(model_name: str) -> Generator[ColModernVBert, None, None]: + device = get_torch_device("auto") + logger.info(f"Device used: {device}") + + yield cast( + ColModernVBert, + ColModernVBert.from_pretrained( + model_name, + torch_dtype=torch.float32, + device_map=device, + attn_implementation="eager", + mask_non_image_embeddings=False, + ).eval(), + ) + tear_down_torch() + + +@pytest.fixture(scope="module") +def model_with_mask(model_name: str) -> Generator[ColModernVBert, None, None]: + device = get_torch_device("auto") + logger.info(f"Device used: {device}") + + yield cast( + ColModernVBert, + ColModernVBert.from_pretrained( + model_name, + torch_dtype=torch.float32, + device_map=device, + attn_implementation="eager", + mask_non_image_embeddings=True, + ).eval(), + ) + tear_down_torch() + + +@pytest.fixture(scope="module") +def processor(model_name: str) -> Generator[ColModernVBertProcessor, None, None]: 
+ yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name)) + + +class TestColModernVBert_Model: # noqa N801 + @pytest.mark.slow + def test_load_model_from_pretrained(self, model_without_mask: ColModernVBert): + assert isinstance(model_without_mask, ColModernVBert) + + +class TestColModernVBert_ModelIntegration: # noqa N801 + @pytest.mark.slow + def test_forward_images_integration( + self, + model_without_mask: ColModernVBert, + processor: ColModernVBertProcessor, + ): + # Create a batch of dummy images + images = [ + Image.new("RGB", (64, 64), color="white"), + Image.new("RGB", (32, 32), color="black"), + ] + + # Process the image + batch_images = processor.process_images(images).to(model_without_mask.device) + + # Forward pass + with torch.no_grad(): + outputs = model_without_mask(**batch_images) + + # Assertions + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_visual_tokens, emb_dim = outputs.shape + assert batch_size == len(images) + assert emb_dim == model_without_mask.dim + + @pytest.mark.slow + def test_forward_queries_integration( + self, + model_without_mask: ColModernVBert, + processor: ColModernVBertProcessor, + ): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Process the queries + batch_queries = processor.process_queries(queries).to(model_without_mask.device).to(torch.float32) + + # Forward pass + with torch.no_grad(): + outputs = model_without_mask(**batch_queries) + + # Assertions + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_query_tokens, emb_dim = outputs.shape + assert batch_size == len(queries) + assert emb_dim == model_without_mask.dim + + @pytest.mark.slow + def test_retrieval_integration( + self, + model_without_mask: ColModernVBert, + processor: ColModernVBertProcessor, + ): + # Load the test dataset + ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") + + # Preprocess the examples + batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device).to(torch.float32) + batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device).to(torch.float32) + + # Run inference + with torch.inference_mode(): + image_embeddings = model_without_mask(**batch_images) + query_embeddings = model_without_mask(**batch_queries) + + # Compute retrieval scores + scores = processor.score_multi_vector( + qs=query_embeddings, + ps=image_embeddings, + ) # (len(qs), len(ps)) + + assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}" + assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}" + + # # Check if the maximum scores per row are in the diagonal of the matrix score + # assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all() diff --git a/tests/models/modernvbert/test_processing_colmodernvbert.py b/tests/models/modernvbert/test_processing_colmodernvbert.py new file mode 100644 index 00000000..236ebc8a --- /dev/null +++ b/tests/models/modernvbert/test_processing_colmodernvbert.py @@ -0,0 +1,64 @@ +from typing import Generator, cast + +import pytest +import torch +from PIL import Image + +from colpali_engine.models import ColModernVBertProcessor + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "ModernVBERT/colmodernvbert" + + +@pytest.fixture(scope="module") +def processor_from_pretrained(model_name: str) -> 
Generator[ColModernVBertProcessor, None, None]: + yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name)) + + +def test_load_processor_from_pretrained(processor_from_pretrained: ColModernVBertProcessor): + assert isinstance(processor_from_pretrained, ColModernVBertProcessor) + + +def test_process_images(processor_from_pretrained: ColModernVBertProcessor): + # Create a dummy image + image_size = (64, 32) + image = Image.new("RGB", image_size, color="black") + images = [image] + + # Process the image + batch_feature = processor_from_pretrained.process_images(images) + + # Assertions + assert "pixel_values" in batch_feature + assert isinstance(batch_feature["pixel_values"], torch.Tensor) + assert batch_feature["pixel_values"].shape[0] == 1 + +def test_process_texts(processor_from_pretrained: ColModernVBertProcessor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Process the queries + batch_encoding = processor_from_pretrained.process_texts(queries) + + # Assertions + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) + +def test_process_queries(processor_from_pretrained: ColModernVBertProcessor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + # Process the queries + batch_encoding = processor_from_pretrained.process_queries(queries) + + # Assertions + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries)
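
The integration tests above double as documentation for the new ColModernVBert classes. As a quick reference, the sketch below (not part of the patch; it assumes the same "ModernVBERT/colmodernvbert" checkpoint used in the tests) shows the end-to-end retrieval flow: process images and queries, embed both, then score with the processor's late-interaction scorer.

    import torch
    from PIL import Image

    from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
    from colpali_engine.utils.torch_utils import get_torch_device

    model_name = "ModernVBERT/colmodernvbert"
    device = get_torch_device("auto")

    model = ColModernVBert.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map=device,
    ).eval()
    processor = ColModernVBertProcessor.from_pretrained(model_name)

    images = [Image.new("RGB", (64, 64), color="white")]
    queries = ["Is attention really all you need?"]

    batch_images = processor.process_images(images).to(model.device)
    batch_queries = processor.process_queries(queries).to(model.device)

    with torch.inference_mode():
        image_embeddings = model(**batch_images)   # (n_images, n_visual_tokens, dim)
        query_embeddings = model(**batch_queries)  # (n_queries, n_query_tokens, dim)

    # Late-interaction scores of shape (len(queries), len(images))
    scores = processor.score_multi_vector(qs=query_embeddings, ps=image_embeddings)

The trickiest piece of the new modeling code is ModernVBertModel.inputs_merger, which scatters the image hidden states into the positions of the image placeholder tokens, block by block across the batch. The toy example below uses made-up sizes (image_token_id=9 and patch_size=2 are hypothetical, chosen only for illustration) and replays the indexing from the patch on a two-row batch with one image block in the first row and two in the second.

    import torch

    image_token_id = 9
    patch_size = 2  # image tokens per image block, i.e. image_hidden_states.shape[1]

    input_ids = torch.tensor([[1, 9, 9, 2], [9, 9, 9, 9]])
    image_hidden_states = torch.arange(3 * patch_size * 4, dtype=torch.float).view(3, patch_size, 4)
    inputs_embeds = torch.zeros(2, 4, 4)

    image_mask = input_ids == image_token_id
    num_image_tokens = image_mask.sum(dim=1)
    blocks_per_sample = num_image_tokens // patch_size                              # [1, 2]
    offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
    block_offset = offsets[:-1]                                                     # first block index per row
    row_cum = image_mask.cumsum(dim=-1)
    chunk_idx = (row_cum - 1) // patch_size
    local_idx = (row_cum - 1) % patch_size
    block_idx = block_offset.unsqueeze(1) + chunk_idx

    image_embeds = torch.zeros_like(inputs_embeds)
    image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
    merged = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
    # Row 0 is filled from block 0; row 1 is filled from blocks 1 and 2, in order.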