Vbert #339
base: main
Model re-exports:

```diff
@@ -1,18 +1,22 @@
 from .models import (
     BiPali,
     BiPaliProj,
     BiQwen2,
     BiQwen2_5,
     BiQwen2_5_Processor,
     BiQwen2Processor,
+    BiModernVBert,
+    BiModernVBertProcessor,
     ColIdefics3,
     ColIdefics3Processor,
     ColPali,
     ColPaliProcessor,
     ColQwen2,
     ColQwen2_5,
     ColQwen2_5_Processor,
-    ColQwen2_5Omni,
-    ColQwen2_5OmniProcessor,
+    # ColQwen2_5Omni,
+    # ColQwen2_5OmniProcessor,
     ColQwen2Processor,
+    ColModernVBert,
+    ColModernVBertProcessor,
 )
```
Collator:

```diff
@@ -1,11 +1,12 @@
 import random
 from typing import Any, Dict, List, Union
+import torch

 from PIL.Image import Image

 from colpali_engine.data.dataset import ColPaliEngineDataset
 from colpali_engine.models.paligemma import ColPaliProcessor
 from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


 def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
```

```diff
@@ -69,16 +70,18 @@
             neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None)
             if neg_tgt is not None:
-                sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt
-                neg_targets.append(sampled_neg)
+                neg_targets.append(neg_tgt)

         # Ensure all queries are strings or images.
         assert all(isinstance(q, str) for q in queries), (
             "All queries must be strings, this collator does not support images in queries."
         )

+        is_str = isinstance(queries[0], str)
+
         # Process queries.
-        queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
+        # queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
+        queries = [q + self.processor.query_augmentation_token * 10 for q in queries] if is_str else queries

         batch_query = self.auto_collate(queries, key_prefix=self.query_prefix)

         # Process targets.
```

Review comment (on the commented-out query-prefix line): remove commented lines if not useful. — Reply: actually useful, in ModernVBert.

Review comment (on the new query augmentation line): put 10 into a constant (e.g. …).
@@ -102,6 +105,26 @@ | |
| proc_batch = self.processor.process_texts(texts=batch) | ||
| elif isinstance(batch[0], Image): | ||
| proc_batch = self.processor.process_images(images=batch) | ||
| elif isinstance(batch[0], list): | ||
| if isinstance(batch[0][0], str): | ||
| proc_texts_batch = [] | ||
| batch_size = len(batch) | ||
| all_texts = [text for texts in batch for text in texts] | ||
| num_negatives = len(all_texts) // batch_size | ||
| proc_batch = self.processor.process_texts(texts=all_texts) | ||
| elif isinstance(batch[0][0], Image): | ||
| proc_imgs_batch = [] | ||
| batch_size = len(batch) | ||
| all_imgs = [img for imgs in batch for img in imgs] | ||
| num_negatives = len(all_imgs) // batch_size | ||
| proc_batch = self.processor.process_images(images=all_imgs) | ||
| else: | ||
| raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.") | ||
| for k, v in proc_batch.items(): | ||
| if isinstance(v, torch.Tensor): | ||
| proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:]) | ||
| else: | ||
| proc_batch[k] = v | ||
|
Comment on lines
+126
to
+127
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unnecessary |
||
| else: | ||
| raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.") | ||
| return prefix_keys(proc_batch, key_prefix) | ||
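For context on the new list branch above: each example can now carry several negatives, so the collator flattens all of them, runs the processor once over the flat list, and then restores the (batch_size, num_negatives, ...) grouping on every tensor. A minimal standalone sketch of that pattern, using a toy tokenizer in place of the real processor (`flatten_and_reshape` and `toy_tokenize` are illustrative names, not part of the PR):

```python
from typing import Dict, List

import torch


def toy_tokenize(texts: List[str], max_len: int = 8) -> Dict[str, torch.Tensor]:
    # Stand-in for processor.process_texts: map each string to fixed-length ids.
    ids = [[ord(c) % 100 for c in t[:max_len]] + [0] * (max_len - len(t[:max_len])) for t in texts]
    return {"input_ids": torch.tensor(ids)}


def flatten_and_reshape(batch: List[List[str]]) -> Dict[str, torch.Tensor]:
    # batch: one list of negatives per example, all lists the same length.
    batch_size = len(batch)
    flat = [item for items in batch for item in items]
    num_negatives = len(flat) // batch_size
    proc = toy_tokenize(flat)
    # Restore the per-example grouping: (B * N, L) -> (B, N, L).
    return {
        k: v.view(batch_size, num_negatives, *v.shape[1:]) if isinstance(v, torch.Tensor) else v
        for k, v in proc.items()
    }


out = flatten_and_reshape([["neg a1", "neg a2"], ["neg b1", "neg b2"]])
print(out["input_ids"].shape)  # torch.Size([2, 2, 8])
```

Note that this only works when every example carries the same number of negatives, which is what the integer division `len(flat) // batch_size` silently assumes.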
Bi-encoder losses:

```diff
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F  # noqa: N812
 from torch.nn import CrossEntropyLoss
```
@@ -111,6 +112,60 @@ def forward( | |
|
|
||
| return self.ce_loss(scores / self.temperature, pos_idx) | ||
|
|
||
| class BiPairedEncoderLoss(BiEncoderModule): | ||
| """ | ||
| InfoNCE loss for bi-encoders without explicit negatives. | ||
|
|
||
| Args: | ||
| temperature (float): Scaling factor for logits. | ||
| pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True. | ||
| max_batch_size (int): Max batch size for index buffer caching. | ||
| filter_threshold (float): Threshold ratio for negative filtering. | ||
| filter_factor (float): Factor to down-weight filtered negatives. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| temperature: float = 0.02, | ||
| pos_aware_negative_filtering: bool = False, | ||
| max_batch_size: int = 1024, | ||
| filter_threshold: float = 0.95, | ||
| filter_factor: float = 0.5, | ||
| ): | ||
| super().__init__(max_batch_size, temperature, filter_threshold, filter_factor) | ||
| self.pos_aware_negative_filtering = pos_aware_negative_filtering | ||
| self.ce_loss = CrossEntropyLoss() | ||
|
|
||
| def forward( | ||
| self, | ||
| query_embeddings: torch.Tensor, | ||
| doc_embeddings: torch.Tensor, | ||
| offset: int = 0, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Compute the InfoNCE loss over a batch of bi-encoder embeddings. | ||
|
|
||
| Args: | ||
| query_embeddings (Tensor[B, D]): Query vectors. | ||
| doc_embeddings (Tensor[B, D]): Document vectors. | ||
| offset (int): Offset for positive indices (multi-GPU). | ||
|
|
||
| Returns: | ||
| Tensor: Scalar cross-entropy loss. | ||
| """ | ||
| # Compute in-batch similarity matrix | ||
| scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings) | ||
| batch_size = scores.size(0) | ||
| idx, pos_idx = self._get_idx(batch_size, offset, scores.device) | ||
|
|
||
| if self.pos_aware_negative_filtering: | ||
| self._filter_high_negatives(scores, pos_idx) | ||
|
|
||
| q2t = self.ce_loss(scores / self.temperature, pos_idx) | ||
| t2q = self.ce_loss(scores.T / self.temperature, ...) | ||
|
|
||
| return (q2t + t2q) / 2.0 | ||
|
Comment on lines
+164
to
+167
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be good to add a comment (either here or in function comment) to mention the symmetric loss |
||
|
|
||
|
|
||
| class BiNegativeCELoss(BiEncoderModule): | ||
| """ | ||
|
|
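Regarding the symmetric loss the reviewer asks to document: BiPairedEncoderLoss averages a query→document and a document→query cross-entropy over the same in-batch similarity matrix. A minimal sketch of that term, assuming a single process (offset = 0) so the positive for row i is simply column i (`symmetric_info_nce` is an illustrative name, not the PR's API):

```python
import torch
import torch.nn.functional as F


def symmetric_info_nce(
    query_embeddings: torch.Tensor,  # (B, D), assumed L2-normalized
    doc_embeddings: torch.Tensor,    # (B, D), assumed L2-normalized
    temperature: float = 0.02,
) -> torch.Tensor:
    scores = query_embeddings @ doc_embeddings.T  # (B, B) in-batch similarity matrix
    labels = torch.arange(scores.size(0), device=scores.device)
    q2d = F.cross_entropy(scores / temperature, labels)    # query -> doc direction
    d2q = F.cross_entropy(scores.T / temperature, labels)  # doc -> query direction
    return (q2d + d2q) / 2.0


q = F.normalize(torch.randn(4, 16), dim=-1)
d = F.normalize(torch.randn(4, 16), dim=-1)
print(symmetric_info_nce(q, d).item())
```

In the single-process case both directions share the same arange labels; with a non-zero offset (gathered documents across devices), the doc→query direction would need its labels shifted accordingly.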
@@ -161,17 +216,18 @@ def forward( | |
| Args: | ||
| query_embeddings (Tensor[B, D]): Query vectors. | ||
| doc_embeddings (Tensor[B, D]): Positive document vectors. | ||
| neg_doc_embeddings (Tensor[B, D]): Negative document vectors. | ||
| neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors. | ||
| offset (int): Offset for in-batch CE positives. | ||
|
|
||
| Returns: | ||
| Tensor: Scalar loss value. | ||
| """ | ||
| # Dot-product only for matching pairs | ||
| pos_scores = (query_embeddings * doc_embeddings).sum(dim=1) / self.temperature | ||
| neg_scores = (query_embeddings * neg_doc_embeddings).sum(dim=1) / self.temperature | ||
| pos_scores = (query_embeddings * doc_embeddings[offset:offset + neg_doc_embeddings.size(0)]).sum(dim=1) | ||
| pos_scores /= self.temperature | ||
| neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / self.temperature | ||
|
|
||
| loss = torch.nn.functional.softplus(neg_scores - pos_scores).mean() | ||
| loss = F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean() | ||
|
|
||
| if self.in_batch_term_weight > 0: | ||
| loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset) | ||
|
|
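To make the new shapes concrete: with neg_doc_embeddings now (B, N, D), the positive score stays a length-B vector while the negative scores become a B × N matrix, and the softplus margin is taken against the broadcast positive. A standalone sketch of that scoring step, assuming unit-normalized embeddings (`explicit_negative_softplus_loss` is an illustrative name, not the PR's API):

```python
import torch
import torch.nn.functional as F


def explicit_negative_softplus_loss(
    query_embeddings: torch.Tensor,    # (B, D)
    doc_embeddings: torch.Tensor,      # (B, D) positives, aligned with the queries
    neg_doc_embeddings: torch.Tensor,  # (B, N, D) negatives per query
    temperature: float = 0.02,
) -> torch.Tensor:
    pos_scores = (query_embeddings * doc_embeddings).sum(dim=1) / temperature                     # (B,)
    neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / temperature   # (B, N)
    # Penalize every negative that scores close to (or above) its query's positive.
    return F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean()


B, N, D = 4, 3, 16
q = F.normalize(torch.randn(B, D), dim=-1)
pos = F.normalize(torch.randn(B, D), dim=-1)
negs = F.normalize(torch.randn(B, N, D), dim=-1)
print(explicit_negative_softplus_loss(q, pos, negs).item())
```

The PR's version additionally slices doc_embeddings by offset so that, when documents are gathered across devices, each local query is compared against its own shard's positive.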
@@ -206,6 +262,7 @@ def forward( | |
| self, | ||
| query_embeddings: torch.Tensor, | ||
| doc_embeddings: torch.Tensor, | ||
| offset: int = 0, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we really want to add this argument if it's not supported in this loss? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice catch, legacy bug for this loss I suppose, all losses should support |
||
| ) -> torch.Tensor: | ||
| """ | ||
| Compute softplus(hardest_neg - pos) where hardest_neg is the highest off-diagonal score. | ||
|
|
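On the hardest-negative objective described in the docstring above: the positive is the diagonal score and the negative is the highest off-diagonal score in the same row. A minimal sketch of that selection, assuming a square in-batch score matrix (`hardest_negative_softplus_loss` is an illustrative name, not the PR's API):

```python
import torch
import torch.nn.functional as F


def hardest_negative_softplus_loss(
    query_embeddings: torch.Tensor,  # (B, D)
    doc_embeddings: torch.Tensor,    # (B, D)
    temperature: float = 0.02,
) -> torch.Tensor:
    scores = query_embeddings @ doc_embeddings.T  # (B, B)
    pos = scores.diagonal()                       # matching (query, doc) pairs
    # Mask the diagonal so the row-wise max picks the hardest *negative*.
    diag_mask = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)
    hardest_neg = scores.masked_fill(diag_mask, float("-inf")).max(dim=1).values
    return F.softplus((hardest_neg - pos) / temperature).mean()


q = F.normalize(torch.randn(4, 16), dim=-1)
d = F.normalize(torch.randn(4, 16), dim=-1)
print(hardest_negative_softplus_loss(q, d).item())
```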
@@ -267,26 +324,93 @@ def forward( | |
| query_embeddings: torch.Tensor, | ||
| doc_embeddings: torch.Tensor, | ||
| neg_doc_embeddings: torch.Tensor, | ||
| offset: int = 0, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we really want to add this argument if it's not supported in this loss? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as before |
||
| ) -> torch.Tensor: | ||
| """ | ||
| Compute softplus(neg-explicit - pos) plus optional pairwise in-batch loss. | ||
|
|
||
| Args: | ||
| query_embeddings (Tensor[B, D]): Query vectors. | ||
| doc_embeddings (Tensor[B, D]): Positive document vectors. | ||
| neg_doc_embeddings (Tensor[B, D]): Negative document vectors. | ||
| neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors. | ||
|
|
||
| Returns: | ||
| Tensor: Scalar loss value. | ||
| """ | ||
| # dot product for matching pairs only | ||
| pos = (query_embeddings * doc_embeddings).sum(dim=1) | ||
| neg = (query_embeddings * neg_doc_embeddings).sum(dim=1) | ||
| pos = (query_embeddings * doc_embeddings).sum(dim=1) # B | ||
| neg = (query_embeddings.unsqueeze(1) * neg_doc_embeddings).sum(dim=2) # B x N | ||
|
|
||
| loss = torch.nn.functional.softplus((neg - pos) / self.temperature).mean() | ||
| loss = torch.nn.functional.softplus((neg - pos.unsqueeze(1)) / self.temperature).mean() | ||
|
|
||
| if self.in_batch_term_weight > 0: | ||
| loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings) | ||
| loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight | ||
|
|
||
| return loss | ||
|
|
||
```diff
+class BiSigmoidLoss(BiEncoderModule):
+    """
+    Sigmoid loss for ColBERT with in-batch negatives.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
+        max_batch_size (int): Max batch size for index buffer caching.
+        filter_threshold (float): Threshold ratio for negative filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+
+    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
+        """
+        Compute the sigmoid loss for a batch of bi-encoder embeddings.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Document vectors.
+            offset (int): Offset for positive indices (multi-GPU).
+
+        Returns:
+            Tensor: Scalar cross-entropy loss.
+        """
+
+        # Compute in-batch similarity matrix
+        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
+
+        batch_size, num_targets = scores.shape
+        device = scores.device
+
+        _, pos_idx = self._get_idx(batch_size, offset, device)
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        all_losses = []
+        for k in range(num_targets // batch_size):
+            # mask equal to 1 on offset -> offset + batch_size
+            curr_idx = torch.arange(offset, offset + batch_size, device=device)
+            # keep only the scores for the current batch
+            curr_scores = scores[:, curr_idx].view(-1) / self.temperature
+            # compute the labels
+            labels = -torch.ones(batch_size * batch_size, device=device)
+            if k == 0:
+                flat_pos = (pos_idx - offset) * (batch_size + 1)
+                labels[flat_pos] = 1.0
+            # compute the loss
+            block_loss = F.softplus(curr_scores * labels)
+            all_losses.append(block_loss)
+            # shift the offset for the next batch
+            offset = (offset + batch_size) % num_targets
+
+        return torch.stack(all_losses, dim=0).mean()
```
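For the block loop in BiSigmoidLoss above, the underlying objective is the SigLIP-style pairwise sigmoid loss: every query–document pair gets a ±1 label (+1 for the matching pair, -1 everywhere else) and contributes -log σ(label · score / temperature), i.e. softplus(-label · score / temperature). A minimal single-block sketch of that standard formulation, assuming a square in-batch score matrix (`pairwise_sigmoid_loss` is an illustrative name, not the PR's exact implementation):

```python
import torch
import torch.nn.functional as F


def pairwise_sigmoid_loss(
    query_embeddings: torch.Tensor,  # (B, D)
    doc_embeddings: torch.Tensor,    # (B, D)
    temperature: float = 0.02,
) -> torch.Tensor:
    scores = query_embeddings @ doc_embeddings.T / temperature  # (B, B)
    labels = -torch.ones_like(scores)
    labels.fill_diagonal_(1.0)  # matching (query, doc) pairs are the positives
    # -log sigmoid(label * score) == softplus(-label * score)
    return F.softplus(-labels * scores).mean()


q = F.normalize(torch.randn(4, 16), dim=-1)
d = F.normalize(torch.randn(4, 16), dim=-1)
print(pairwise_sigmoid_loss(q, d).item())
```

The PR's loop generalizes this to num_targets // batch_size square blocks so that, when the gathered document side is larger than the local query batch, only the first block (k == 0) carries the +1 labels for the local positives.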
General review comment: Add a comment to the README if ColQwen 2.5 Omni is not supported anymore.