Draft

Vbert #339

Commits (42, changes from all commits)
fc27f3c
update processors to new signatures
paultltc Jun 12, 2025
31c0709
lint
paultltc Jun 12, 2025
32de63c
keep process_queries for back comp
paultltc Jun 12, 2025
d00b267
add vbert/vllama modeling
paultltc Jun 16, 2025
7fba1c6
stage
paultltc Jun 24, 2025
43d3d36
fix typo in vbert modeling
paultltc Jun 30, 2025
ed11060
loss
paultltc Jul 1, 2025
55ebd0c
models
paultltc Jul 8, 2025
4ddc453
losses
paultltc Jul 8, 2025
e54df49
symetric loss + flex biencodr score
paultltc Jul 10, 2025
0375d68
process
paultltc Jul 10, 2025
2ab0cb0
merge
paultltc Jul 14, 2025
00337b1
fix dup
paultltc Jul 15, 2025
ec4d4dd
latest
paultltc Aug 13, 2025
91e9f36
modeling
paultltc Aug 19, 2025
91ba4be
f
QuentinJGMace Aug 8, 2025
81eef80
rebase
QuentinJGMace Aug 8, 2025
9a82c1f
rebase
QuentinJGMace Aug 8, 2025
245bb33
rebase
QuentinJGMace Aug 8, 2025
2ebe2ab
symetric loss + flex biencodr score
QuentinJGMace Jul 30, 2025
44fe1e6
f
QuentinJGMace Sep 29, 2025
1ec65fc
f
QuentinJGMace Sep 29, 2025
1b8510f
f
QuentinJGMace Sep 29, 2025
3dfbe4b
remove bvbert file
QuentinJGMace Sep 29, 2025
d748aa1
negatives loss
QuentinJGMace Sep 29, 2025
afd0e95
prepare collators for multi-hardnegs
QuentinJGMace Sep 29, 2025
dcbbe15
multiple hard negs training
QuentinJGMace Sep 29, 2025
24cd010
f
QuentinJGMace Sep 29, 2025
5c11cd3
rm colqwen_omni init
QuentinJGMace Sep 29, 2025
da868ae
f
QuentinJGMace Sep 30, 2025
fa1ea76
modif tests
QuentinJGMace Sep 30, 2025
3fb3df4
Change default model
ManuelFay Oct 3, 2025
31630d1
Change default text model name in configuration
QuentinJGMace Oct 3, 2025
9ce2871
fix: `ModernVBERT` modeling (#348)
paultltc Oct 16, 2025
20e78cc
add tests for modernvbert
QuentinJGMace Oct 16, 2025
43fba98
f test
QuentinJGMace Oct 16, 2025
058a299
f
QuentinJGMace Oct 16, 2025
d1e3f38
ff
QuentinJGMace Oct 16, 2025
133bc51
update dtype assign (#349)
paultltc Oct 17, 2025
df0d1a8
oopsie
QuentinJGMace Oct 17, 2025
8c89c49
update other losses
QuentinJGMace Oct 20, 2025
c6d4dd0
correct tests to handle multiple neg
QuentinJGMace Oct 21, 2025
8 changes: 6 additions & 2 deletions colpali_engine/__init__.py
@@ -1,18 +1,22 @@
from .models import (
BiPali,
BiPaliProj,
BiQwen2,
BiQwen2_5,
BiQwen2_5_Processor,
BiQwen2Processor,
BiModernVBert,
BiModernVBertProcessor,
ColIdefics3,
ColIdefics3Processor,
ColPali,
ColPaliProcessor,
ColQwen2,
ColQwen2_5,
ColQwen2_5_Processor,
ColQwen2_5Omni,
ColQwen2_5OmniProcessor,
# ColQwen2_5Omni,
# ColQwen2_5OmniProcessor,
Comment on lines -15 to +18 (Collaborator): Add a comment to the README if ColQwen 2.5 Omni is not supported anymore.

ColQwen2Processor,
ColModernVBert,
ColModernVBertProcessor,
)

Check failure on line 22 in colpali_engine/__init__.py (GitHub Actions / ruff): Ruff I001, colpali_engine/__init__.py:1:1: Import block is un-sorted or un-formatted
29 changes: 26 additions & 3 deletions colpali_engine/collators/visual_retriever_collator.py
@@ -1,11 +1,12 @@
import random
from typing import Any, Dict, List, Union
import torch

from PIL.Image import Image

from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.models.paligemma import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor

Check failure on line 9 in colpali_engine/collators/visual_retriever_collator.py (GitHub Actions / ruff): Ruff I001, colpali_engine/collators/visual_retriever_collator.py:1:1: Import block is un-sorted or un-formatted


def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
@@ -69,16 +70,18 @@

neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None)
if neg_tgt is not None:
sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt
neg_targets.append(sampled_neg)
neg_targets.append(neg_tgt)

# Ensure all queries are strings or images.
assert all(isinstance(q, str) for q in queries), (
"All queries must be strings, this collator does not support images in queries."
)

is_str = isinstance(queries[0], str)

# Process queries.
queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
# queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
Collaborator: remove commented lines if not useful

Collaborator (author): Actually useful: in ModernVBert, self.processor.query_prefix is "", but this line is useful if somebody wants to reproduce older models. Thanks for flagging it!

queries = [q + self.processor.query_augmentation_token * 10 for q in queries] if is_str else queries
Collaborator: put 10 into a constant (e.g. N_AUGMENTATION_TOKENS)

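A minimal sketch of the suggested refactor; N_AUGMENTATION_TOKENS and augment_queries are hypothetical names proposed for illustration, not code that exists in this PR:

```python
# Hypothetical refactor sketch (constant and helper names are suggestions, not existing code).
N_AUGMENTATION_TOKENS = 10  # number of query augmentation tokens appended to each text query

def augment_queries(queries, processor):
    """Append the processor's augmentation token N_AUGMENTATION_TOKENS times to each string query."""
    return [q + processor.query_augmentation_token * N_AUGMENTATION_TOKENS for q in queries]
```
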
batch_query = self.auto_collate(queries, key_prefix=self.query_prefix)

# Process targets.
@@ -102,6 +105,26 @@
proc_batch = self.processor.process_texts(texts=batch)
elif isinstance(batch[0], Image):
proc_batch = self.processor.process_images(images=batch)
elif isinstance(batch[0], list):
if isinstance(batch[0][0], str):
proc_texts_batch = []

Check failure on line 110 in colpali_engine/collators/visual_retriever_collator.py (GitHub Actions / ruff): Ruff F841, 110:17: Local variable `proc_texts_batch` is assigned to but never used
batch_size = len(batch)
all_texts = [text for texts in batch for text in texts]
num_negatives = len(all_texts) // batch_size
proc_batch = self.processor.process_texts(texts=all_texts)
elif isinstance(batch[0][0], Image):
proc_imgs_batch = []

Check failure on line 116 in colpali_engine/collators/visual_retriever_collator.py (GitHub Actions / ruff): Ruff F841, 116:17: Local variable `proc_imgs_batch` is assigned to but never used
batch_size = len(batch)
all_imgs = [img for imgs in batch for img in imgs]
num_negatives = len(all_imgs) // batch_size
proc_batch = self.processor.process_images(images=all_imgs)
else:
raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.")
for k, v in proc_batch.items():
if isinstance(v, torch.Tensor):

Check failure on line 124 in colpali_engine/collators/visual_retriever_collator.py (GitHub Actions / ruff): Ruff W291, 124:48: Trailing whitespace
proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:])
else:
proc_batch[k] = v
Comment on lines +126 to +127 (Collaborator): unnecessary

else:
raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.")
return prefix_keys(proc_batch, key_prefix)
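
The new list branch above flattens the per-example negatives, runs them through the processor once, and reshapes the resulting tensors back to (batch_size, num_negatives, ...). A minimal, self-contained sketch of that flatten-then-reshape pattern, using a toy stand-in for the processor (fake_process_texts is hypothetical and not part of colpali_engine):

```python
import torch

# Toy stand-in for processor.process_texts: returns a dict of tensors whose leading
# dimension equals the number of input texts.
def fake_process_texts(texts):
    return {"input_ids": torch.arange(len(texts) * 4).view(len(texts), 4)}

batch = [["neg_a1", "neg_a2", "neg_a3"], ["neg_b1", "neg_b2", "neg_b3"]]  # 2 examples x 3 negatives
batch_size = len(batch)
all_texts = [text for texts in batch for text in texts]   # flatten to 6 texts
num_negatives = len(all_texts) // batch_size               # 3
proc = fake_process_texts(all_texts)                        # tensors of shape (6, ...)
proc = {k: v.view(batch_size, num_negatives, *v.shape[1:]) if isinstance(v, torch.Tensor) else v
        for k, v in proc.items()}
print(proc["input_ids"].shape)  # torch.Size([2, 3, 4])
```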
6 changes: 4 additions & 2 deletions colpali_engine/data/dataset.py
@@ -77,6 +77,7 @@ def __init__(
query_column_name: str = "query",
pos_target_column_name: str = "pos_target",
neg_target_column_name: str = None,
num_negatives: int = 3,
):
"""
Initialize the dataset with the provided data and external document corpus.
@@ -94,6 +95,7 @@
self.pos_target_column_name = pos_target_column_name
self.neg_target_column_name = neg_target_column_name

self.num_negatives = num_negatives
assert isinstance(
self.data,
(list, Dataset, HFDataset),
@@ -131,8 +133,8 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
pos_targets = [self.corpus.retrieve(doc_id) for doc_id in pos_targets]
if neg_targets is not None:
# to avoid overflowing CPU memory
if len(neg_targets) > 5:
neg_targets = random.sample(neg_targets, 5)
if len(neg_targets) > self.num_negatives:
neg_targets = random.sample(neg_targets, self.num_negatives)
neg_targets = [self.corpus.retrieve(doc_id) for doc_id in neg_targets]

return {
2 changes: 2 additions & 0 deletions colpali_engine/loss/__init__.py
@@ -4,11 +4,13 @@
BiNegativeCELoss,
BiPairwiseCELoss,
BiPairwiseNegativeCELoss,
BiSigmoidLoss,
)
from .late_interaction_losses import (
ColbertLoss,
ColbertModule,
ColbertNegativeCELoss,
ColbertPairwiseCELoss,
ColbertPairwiseNegativeCELoss,
ColbertSigmoidLoss,
)
140 changes: 132 additions & 8 deletions colpali_engine/loss/bi_encoder_losses.py
@@ -1,4 +1,5 @@
import torch
import torch.nn.functional as F # noqa: N812
from torch.nn import CrossEntropyLoss


@@ -111,6 +112,60 @@ def forward(

return self.ce_loss(scores / self.temperature, pos_idx)

class BiPairedEncoderLoss(BiEncoderModule):
"""
InfoNCE loss for bi-encoders without explicit negatives.

Args:
temperature (float): Scaling factor for logits.
pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
max_batch_size (int): Max batch size for index buffer caching.
filter_threshold (float): Threshold ratio for negative filtering.
filter_factor (float): Factor to down-weight filtered negatives.
"""

def __init__(
self,
temperature: float = 0.02,
pos_aware_negative_filtering: bool = False,
max_batch_size: int = 1024,
filter_threshold: float = 0.95,
filter_factor: float = 0.5,
):
super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
self.pos_aware_negative_filtering = pos_aware_negative_filtering
self.ce_loss = CrossEntropyLoss()

def forward(
self,
query_embeddings: torch.Tensor,
doc_embeddings: torch.Tensor,
offset: int = 0,
) -> torch.Tensor:
"""
Compute the InfoNCE loss over a batch of bi-encoder embeddings.

Args:
query_embeddings (Tensor[B, D]): Query vectors.
doc_embeddings (Tensor[B, D]): Document vectors.
offset (int): Offset for positive indices (multi-GPU).

Returns:
Tensor: Scalar cross-entropy loss.
"""
# Compute in-batch similarity matrix
scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
batch_size = scores.size(0)
idx, pos_idx = self._get_idx(batch_size, offset, scores.device)

if self.pos_aware_negative_filtering:
self._filter_high_negatives(scores, pos_idx)

q2t = self.ce_loss(scores / self.temperature, pos_idx)
t2q = self.ce_loss(scores.T / self.temperature, ...)

return (q2t + t2q) / 2.0
Comment on lines +164 to +167 (Collaborator): would be good to add a comment (either here or in the function comment) to mention the symmetric loss

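For context on the comment above: BiPairedEncoderLoss averages the query-to-document and document-to-query InfoNCE terms. A minimal sketch of that symmetric formulation in the simple single-device case (offset = 0, square in-batch score matrix); this is an illustrative example with random tensors, not the exact code in this PR:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, temperature = 4, 8, 0.02
query_embeddings = F.normalize(torch.randn(B, D), dim=-1)
doc_embeddings = F.normalize(torch.randn(B, D), dim=-1)

scores = query_embeddings @ doc_embeddings.T            # (B, B) in-batch similarity matrix
pos_idx = torch.arange(B)                                # document i is the positive for query i

q2d = F.cross_entropy(scores / temperature, pos_idx)     # query -> document direction
d2q = F.cross_entropy(scores.T / temperature, pos_idx)   # document -> query direction (square case)
symmetric_loss = (q2d + d2q) / 2.0
print(symmetric_loss.item())
```
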


class BiNegativeCELoss(BiEncoderModule):
"""
@@ -161,17 +216,18 @@ def forward(
Args:
query_embeddings (Tensor[B, D]): Query vectors.
doc_embeddings (Tensor[B, D]): Positive document vectors.
neg_doc_embeddings (Tensor[B, D]): Negative document vectors.
neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.
offset (int): Offset for in-batch CE positives.

Returns:
Tensor: Scalar loss value.
"""
# Dot-product only for matching pairs
pos_scores = (query_embeddings * doc_embeddings).sum(dim=1) / self.temperature
neg_scores = (query_embeddings * neg_doc_embeddings).sum(dim=1) / self.temperature
pos_scores = (query_embeddings * doc_embeddings[offset:offset + neg_doc_embeddings.size(0)]).sum(dim=1)
pos_scores /= self.temperature
neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / self.temperature

loss = torch.nn.functional.softplus(neg_scores - pos_scores).mean()
loss = F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean()

if self.in_batch_term_weight > 0:
loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset)
@@ -206,6 +262,7 @@ def forward(
self,
query_embeddings: torch.Tensor,
doc_embeddings: torch.Tensor,
offset: int = 0,
Collaborator: Do we really want to add this argument if it's not supported in this loss?

Collaborator (author): Nice catch, a legacy bug for this loss I suppose; all losses should support offset, as it's what allows multi-GPU distributed training.

) -> torch.Tensor:
"""
Compute softplus(hardest_neg - pos) where hardest_neg is the highest off-diagonal score.
@@ -267,26 +324,93 @@ def forward(
query_embeddings: torch.Tensor,
doc_embeddings: torch.Tensor,
neg_doc_embeddings: torch.Tensor,
offset: int = 0,
Collaborator: Do we really want to add this argument if it's not supported in this loss?

Collaborator (author): same as before

) -> torch.Tensor:
"""
Compute softplus(neg-explicit - pos) plus optional pairwise in-batch loss.

Args:
query_embeddings (Tensor[B, D]): Query vectors.
doc_embeddings (Tensor[B, D]): Positive document vectors.
neg_doc_embeddings (Tensor[B, D]): Negative document vectors.
neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.

Returns:
Tensor: Scalar loss value.
"""
# dot product for matching pairs only
pos = (query_embeddings * doc_embeddings).sum(dim=1)
neg = (query_embeddings * neg_doc_embeddings).sum(dim=1)
pos = (query_embeddings * doc_embeddings).sum(dim=1) # B
neg = (query_embeddings.unsqueeze(1) * neg_doc_embeddings).sum(dim=2) # B x N

loss = torch.nn.functional.softplus((neg - pos) / self.temperature).mean()
loss = torch.nn.functional.softplus((neg - pos.unsqueeze(1)) / self.temperature).mean()

if self.in_batch_term_weight > 0:
loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings)
loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight

return loss
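
Both negative losses above now take neg_doc_embeddings of shape (B, N, D): each query is scored against its N hard negatives, and softplus(neg - pos) is averaged over all pairs, with the positive score broadcast over the negatives. A small, self-contained sketch of that scoring pattern with illustrative shapes (not the module's actual code):

```python
import torch
import torch.nn.functional as F

B, N, D, temperature = 2, 3, 4, 0.02
query_embeddings = torch.randn(B, D)
doc_embeddings = torch.randn(B, D)          # one positive document per query
neg_doc_embeddings = torch.randn(B, N, D)   # N hard negatives per query

pos_scores = (query_embeddings * doc_embeddings).sum(dim=1) / temperature                     # (B,)
neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / temperature   # (B, N)

# softplus(neg - pos), broadcasting each query's positive score over its N negatives
loss = F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean()
print(loss.item())
```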

class BiSigmoidLoss(BiEncoderModule):
"""
Sigmoid loss for bi-encoders with in-batch negatives.

Args:
temperature (float): Scaling factor for logits.
pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
max_batch_size (int): Max batch size for index buffer caching.
filter_threshold (float): Threshold ratio for negative filtering.
filter_factor (float): Factor to down-weight filtered negatives.
"""

def __init__(
self,
temperature: float = 0.02,
pos_aware_negative_filtering: bool = False,
max_batch_size: int = 1024,
filter_threshold: float = 0.95,
filter_factor: float = 0.5,
):
super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
self.pos_aware_negative_filtering = pos_aware_negative_filtering

def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
"""
Compute the sigmoid loss for a batch of bi-encoder embeddings.

Args:
query_embeddings (Tensor[B, D]): Query vectors.
doc_embeddings (Tensor[B, D]): Document vectors.
offset (int): Offset for positive indices (multi-GPU).

Returns:
Tensor: Scalar sigmoid loss.
"""

# Compute in-batch similarity matrix
scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)

batch_size, num_targets = scores.shape
device = scores.device

_, pos_idx = self._get_idx(batch_size, offset, device)

if self.pos_aware_negative_filtering:
self._filter_high_negatives(scores, pos_idx)

all_losses = []
for k in range(num_targets // batch_size):
# mask equal to 1 on offset -> offset + batch_size
curr_idx = torch.arange(offset, offset + batch_size, device=device)
# keep only the scores for the current batch
curr_scores = scores[:, curr_idx].view(-1) / self.temperature
# compute the labels
labels = -torch.ones(batch_size * batch_size, device=device)
if k == 0:
flat_pos = (pos_idx - offset) * (batch_size + 1)
labels[flat_pos] = 1.0
# compute the loss
block_loss = F.softplus(curr_scores * labels)
all_losses.append(block_loss)
# shift the offset for the next batch
offset = (offset + batch_size) % num_targets

return torch.stack(all_losses, dim=0).mean()
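
For comparison with the block-wise loop above, a minimal sketch of the standard SigLIP-style pairwise sigmoid loss in the single-block, offset-0 case. It uses the textbook sign convention, loss = softplus(-label * score / temperature), which is an assumption for illustration and may differ from the convention used in this PR:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, temperature = 4, 8, 0.02
q = F.normalize(torch.randn(B, D), dim=-1)
d = F.normalize(torch.randn(B, D), dim=-1)

scores = q @ d.T                 # (B, B) in-batch similarity matrix
labels = -torch.ones(B, B)       # -1 for in-batch negatives
labels.fill_diagonal_(1.0)       # +1 for the matching (diagonal) pairs

# -log(sigmoid(label * score / temperature)) == softplus(-label * score / temperature)
loss = F.softplus(-labels * scores / temperature).mean()
print(loss.item())
```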