Skip to content
Draft

Vbert #339

Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
fc27f3c
update processors to new signatures
paultltc Jun 12, 2025
31c0709
lint
paultltc Jun 12, 2025
32de63c
keep process_queries for back comp
paultltc Jun 12, 2025
d00b267
add vbert/vllama modeling
paultltc Jun 16, 2025
7fba1c6
stage
paultltc Jun 24, 2025
43d3d36
fix typo in vbert modeling
paultltc Jun 30, 2025
ed11060
loss
paultltc Jul 1, 2025
55ebd0c
models
paultltc Jul 8, 2025
4ddc453
losses
paultltc Jul 8, 2025
e54df49
symetric loss + flex biencodr score
paultltc Jul 10, 2025
0375d68
process
paultltc Jul 10, 2025
2ab0cb0
merge
paultltc Jul 14, 2025
00337b1
fix dup
paultltc Jul 15, 2025
ec4d4dd
latest
paultltc Aug 13, 2025
91e9f36
modeling
paultltc Aug 19, 2025
91ba4be
f
QuentinJGMace Aug 8, 2025
81eef80
rebase
QuentinJGMace Aug 8, 2025
9a82c1f
rebase
QuentinJGMace Aug 8, 2025
245bb33
rebase
QuentinJGMace Aug 8, 2025
2ebe2ab
symetric loss + flex biencodr score
QuentinJGMace Jul 30, 2025
44fe1e6
f
QuentinJGMace Sep 29, 2025
1ec65fc
f
QuentinJGMace Sep 29, 2025
1b8510f
f
QuentinJGMace Sep 29, 2025
3dfbe4b
remove bvbert file
QuentinJGMace Sep 29, 2025
d748aa1
negatives loss
QuentinJGMace Sep 29, 2025
afd0e95
prepare collators for multi-hardnegs
QuentinJGMace Sep 29, 2025
dcbbe15
multiple hard negs training
QuentinJGMace Sep 29, 2025
24cd010
f
QuentinJGMace Sep 29, 2025
5c11cd3
rm colqwen_omni init
QuentinJGMace Sep 29, 2025
da868ae
f
QuentinJGMace Sep 30, 2025
fa1ea76
modif tests
QuentinJGMace Sep 30, 2025
3fb3df4
Change default model
ManuelFay Oct 3, 2025
31630d1
Change default text model name in configuration
QuentinJGMace Oct 3, 2025
9ce2871
fix: `ModernVBERT` modeling (#348)
paultltc Oct 16, 2025
20e78cc
add tests for modernvbert
QuentinJGMace Oct 16, 2025
43fba98
f test
QuentinJGMace Oct 16, 2025
058a299
f
QuentinJGMace Oct 16, 2025
d1e3f38
ff
QuentinJGMace Oct 16, 2025
133bc51
update dtype assign (#349)
paultltc Oct 17, 2025
df0d1a8
oopsie
QuentinJGMace Oct 17, 2025
8c89c49
update other losses
QuentinJGMace Oct 20, 2025
c6d4dd0
correct tests to handle multiple neg
QuentinJGMace Oct 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions colpali_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from .models import (
BiPali,
BiPaliProj,
BiQwen2,
BiQwen2_5,
BiQwen2_5_Processor,
BiQwen2Processor,
BiModernVBert,
BiModernVBertProcessor,
ColIdefics3,
ColIdefics3Processor,
ColPali,
ColPaliProcessor,
ColQwen2,
ColQwen2_5,
ColQwen2_5_Processor,
ColQwen2_5Omni,
ColQwen2_5OmniProcessor,
# ColQwen2_5Omni,
# ColQwen2_5OmniProcessor,
Comment on lines -15 to +18
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comment to the README if ColQwen 2.5 Omni is not supported anymore

ColQwen2Processor,
ColModernVBert,
ColModernVBertProcessor,
)

Check failure on line 22 in colpali_engine/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

colpali_engine/__init__.py:1:1: I001 Import block is un-sorted or un-formatted
142 changes: 142 additions & 0 deletions colpali_engine/collators/collator_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import random
import torch
from typing import Any, Dict, List, Union

from PIL.Image import Image

from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.models.paligemma import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor

Check failure on line 9 in colpali_engine/collators/collator_copy.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

colpali_engine/collators/collator_copy.py:1:1: I001 Import block is un-sorted or un-formatted


def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """
    Prefix all keys in a dictionary with the given prefix.

    Args:
        data: Mapping whose keys should be renamed. It is not modified.
        prefix: String prepended to every key.

    Returns:
        A new dictionary mapping ``f"{prefix}{key}"`` to the original value.
    """
    return {f"{prefix}{k}": v for k, v in data.items()}


class VisualRetrieverCollator:
    """
    Collator for training vision retrieval models.

    Turns a list of dataset examples into one dictionary of processor outputs
    whose keys are prefixed with ``query_``, ``doc_`` and ``neg_doc_``, plus
    bookkeeping tensors for the positive document ids.
    """

    # Prefixes
    query_prefix = "query_"
    pos_doc_prefix = "doc_"
    neg_doc_prefix = "neg_doc_"

    # Number of query-augmentation tokens appended to every text query.
    n_augmentation_tokens = 10
    # Width of ``positive_ids_tensor``; slots that are not filled hold -1.
    max_positive_ids = 100

    def __init__(
        self,
        processor: BaseVisualRetrieverProcessor,
        max_length: int = 2048,
    ):
        """
        Args:
            processor: Processor used to tokenize texts and preprocess images.
            max_length: Stored on the instance as ``self.max_length``.
        """
        self.processor = processor
        self.max_length = max_length
        self.image_token_id = None

        if isinstance(self.processor, ColPaliProcessor):
            tokenizer = self.processor.tokenizer

            # Extract the <image> token id when the tokenizer declares it.
            try:
                idx = tokenizer.additional_special_tokens.index("<image>")
                self.image_token_id = tokenizer.additional_special_tokens_ids[idx]
            except ValueError:
                self.image_token_id = None

            # Force padding to be on the right for ColPaliProcessor.
            if tokenizer.padding_side != "right":
                print("Setting padding side to right")
                tokenizer.padding_side = "right"

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate a list of dataset examples into a single batch.

        Args:
            examples: Dictionaries holding at least the query and positive-target
                keys of ``ColPaliEngineDataset``; the negative-target key and
                ``"positive_ids"`` are optional.

        Returns:
            The prefixed processor outputs for queries, positive targets and
            (when present) negative targets, plus ``"selected_ids"`` and
            ``"positive_ids_tensor"``.

        Raises:
            AssertionError: If a required key is missing or a query is not a string.
            ValueError: If an example carries more than ``max_positive_ids`` positive ids.
        """
        queries: List[Union[None, str, Image]] = []
        pos_targets: List[Union[str, Image]] = []
        neg_targets: List[Union[str, Image, List[str], List[Image]]] = []
        selected_ids: List[int] = []

        # Parse the examples.
        positive_ids_tensor = -torch.ones((len(examples), self.max_positive_ids), dtype=torch.long)
        for i, example in enumerate(examples):
            assert ColPaliEngineDataset.QUERY_KEY in example, f"Missing {ColPaliEngineDataset.QUERY_KEY} in example."
            query = example[ColPaliEngineDataset.QUERY_KEY]
            queries.append(random.choice(query) if isinstance(query, list) else query)

            assert ColPaliEngineDataset.POS_TARGET_KEY in example, (
                f"Missing {ColPaliEngineDataset.POS_TARGET_KEY} in example."
            )
            pos_tgt = example[ColPaliEngineDataset.POS_TARGET_KEY]
            positive_ids = example.get("positive_ids", None)
            if not isinstance(pos_tgt, list):
                sample_pos = pos_tgt
            elif positive_ids is None:
                # No ids to pair with the targets: sample a target without recording an id
                # (zipping against None would raise a TypeError).
                sample_pos = random.choice(pos_tgt)
            else:
                sample_pos, sample_id = random.choice(list(zip(pos_tgt, positive_ids)))
                selected_ids.append(sample_id)
            pos_targets.append(sample_pos)

            if positive_ids is not None:
                if len(positive_ids) > self.max_positive_ids:
                    raise ValueError(
                        f"Example {i} has {len(positive_ids)} positive ids, "
                        f"more than the supported maximum of {self.max_positive_ids}."
                    )
                positive_ids_tensor[i, : len(positive_ids)] = torch.tensor(positive_ids)

            # Negatives are kept as provided (possibly a list per example); `auto_collate`
            # handles the multi-negative case.
            neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None)
            if neg_tgt is not None:
                neg_targets.append(neg_tgt)

        # Ensure all queries are strings.
        assert all(isinstance(q, str) for q in queries), (
            "All queries must be strings, this collator does not support images in queries."
        )

        # Process queries.
        augmentation_suffix = self.processor.query_augmentation_token * self.n_augmentation_tokens
        queries = [self.processor.query_prefix + q + augmentation_suffix for q in queries]
        batch_query = self.auto_collate(queries, key_prefix=self.query_prefix)

        # Process targets.
        batch_pos_target = self.auto_collate(pos_targets, key_prefix=self.pos_doc_prefix)
        batch_neg_target = self.auto_collate(neg_targets, key_prefix=self.neg_doc_prefix) if neg_targets else {}

        return {
            **batch_query,
            **batch_pos_target,
            **batch_neg_target,
            # NOTE(review): `torch.Tensor(...)` builds a float tensor, and an id is only recorded
            # for examples whose positive target is a list with ids, so this can be shorter than
            # the batch. Kept as-is for backward compatibility - confirm what consumers expect.
            "selected_ids": torch.Tensor(selected_ids),
            "positive_ids_tensor": positive_ids_tensor,
        }

    def auto_collate(
        self,
        batch: List[Union[str, Image, List[str], List[Image]]],
        key_prefix: str = "",
    ) -> Dict[str, Any]:
        """
        Automatically collate a batch of documents.

        Args:
            batch: Either a flat list of texts, a flat list of images, or a list of
                equally sized groups (lists) of texts or of images.
            key_prefix: Prefix added to every key of the processor output.

        Returns:
            The processor output with prefixed keys. For grouped input, every tensor
            value is viewed as ``(len(batch), group_size, *original_trailing_dims)``.

        Raises:
            ValueError: If texts and images are mixed, if the groups are empty or of
                unequal size, or if the element type is unsupported.
        """
        # If type is mixed across the batch, raise an error. `isinstance` is used so that
        # subclasses of `Image` are detected as well.
        has_text = any(isinstance(item, str) for item in batch)
        has_image = any(isinstance(item, Image) for item in batch)
        if has_text and has_image:
            all_types = set(type(item) for item in batch)
            raise ValueError(f"Batch contains mixed types: {all_types}. Expected all items to be of the same type.")

        first = batch[0]
        if isinstance(first, str):
            proc_batch = self.processor.process_texts(texts=batch)
        elif isinstance(first, Image):
            proc_batch = self.processor.process_images(images=batch)
        elif isinstance(first, list):
            # The processor output is reshaped to (batch, group, ...) below; that is only
            # meaningful when every example contributes the same number of items.
            group_sizes = {len(group) for group in batch}
            if len(group_sizes) != 1 or 0 in group_sizes:
                raise ValueError(
                    f"All groups in the batch must be non-empty and of equal size, got sizes: {sorted(group_sizes)}."
                )
            batch_size = len(batch)
            num_negatives = group_sizes.pop()
            flat_items = [item for group in batch for item in group]

            if isinstance(flat_items[0], str):
                proc_batch = self.processor.process_texts(texts=flat_items)
            elif isinstance(flat_items[0], Image):
                proc_batch = self.processor.process_images(images=flat_items)
            else:
                raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.")

            # Restore the per-example grouping on the leading dimension of every tensor.
            for k, v in proc_batch.items():
                if isinstance(v, torch.Tensor):
                    proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:])
        else:
            raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.")

        return prefix_keys(proc_batch, key_prefix)

Check failure on line 142 in colpali_engine/collators/collator_copy.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W292)

colpali_engine/collators/collator_copy.py:142:51: W292 No newline at end of file
29 changes: 26 additions & 3 deletions colpali_engine/collators/visual_retriever_collator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import random
from typing import Any, Dict, List, Union
import torch

from PIL.Image import Image

from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.models.paligemma import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor

Check failure on line 9 in colpali_engine/collators/visual_retriever_collator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

colpali_engine/collators/visual_retriever_collator.py:1:1: I001 Import block is un-sorted or un-formatted


def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -69,16 +70,18 @@

neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None)
if neg_tgt is not None:
sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt
neg_targets.append(sampled_neg)
neg_targets.append(neg_tgt)

# Ensure all queries are strings or images.
assert all(isinstance(q, str) for q in queries), (
"All queries must be strings, this collator does not support images in queries."
)

is_str = isinstance(queries[0], str)

# Process queries.
queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
# queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove commented lines if not useful

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually useful: in ModernVBERT, self.processor.query_prefix is "", but the prefix is still needed if somebody wants to reproduce older models.
Thanks for flagging it!

queries = [q + self.processor.query_augmentation_token * 10 for q in queries] if is_str else queries
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put 10 into a constant (e.g. N_AUGMENTATION_TOKENS)

batch_query = self.auto_collate(queries, key_prefix=self.query_prefix)

# Process targets.
Expand All @@ -102,6 +105,26 @@
proc_batch = self.processor.process_texts(texts=batch)
elif isinstance(batch[0], Image):
proc_batch = self.processor.process_images(images=batch)
elif isinstance(batch[0], list):
if isinstance(batch[0][0], str):
proc_texts_batch = []

Check failure on line 110 in colpali_engine/collators/visual_retriever_collator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

colpali_engine/collators/visual_retriever_collator.py:110:17: F841 Local variable `proc_texts_batch` is assigned to but never used
batch_size = len(batch)
all_texts = [text for texts in batch for text in texts]
num_negatives = len(all_texts) // batch_size
proc_batch = self.processor.process_texts(texts=all_texts)
elif isinstance(batch[0][0], Image):
proc_imgs_batch = []

Check failure on line 116 in colpali_engine/collators/visual_retriever_collator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

colpali_engine/collators/visual_retriever_collator.py:116:17: F841 Local variable `proc_imgs_batch` is assigned to but never used
batch_size = len(batch)
all_imgs = [img for imgs in batch for img in imgs]
num_negatives = len(all_imgs) // batch_size
proc_batch = self.processor.process_images(images=all_imgs)
else:
raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.")
for k, v in proc_batch.items():
if isinstance(v, torch.Tensor):
proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:])
else:
proc_batch[k] = v
Comment on lines +126 to +127
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary

else:
raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.")
return prefix_keys(proc_batch, key_prefix)
6 changes: 4 additions & 2 deletions colpali_engine/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
query_column_name: str = "query",
pos_target_column_name: str = "pos_target",
neg_target_column_name: str = None,
num_negatives: int = 3,
):
"""
Initialize the dataset with the provided data and external document corpus.
Expand All @@ -94,6 +95,7 @@ def __init__(
self.pos_target_column_name = pos_target_column_name
self.neg_target_column_name = neg_target_column_name

self.num_negatives = num_negatives
assert isinstance(
self.data,
(list, Dataset, HFDataset),
Expand Down Expand Up @@ -131,8 +133,8 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
pos_targets = [self.corpus.retrieve(doc_id) for doc_id in pos_targets]
if neg_targets is not None:
# to avoid oveflowing CPU memory
if len(neg_targets) > 5:
neg_targets = random.sample(neg_targets, 5)
if len(neg_targets) > self.num_negatives:
neg_targets = random.sample(neg_targets, self.num_negatives)
neg_targets = [self.corpus.retrieve(doc_id) for doc_id in neg_targets]

return {
Expand Down
2 changes: 2 additions & 0 deletions colpali_engine/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
BiNegativeCELoss,
BiPairwiseCELoss,
BiPairwiseNegativeCELoss,
BiSigmoidLoss,
)
from .late_interaction_losses import (
ColbertLoss,
ColbertModule,
ColbertNegativeCELoss,
ColbertPairwiseCELoss,
ColbertPairwiseNegativeCELoss,
ColbertSigmoidLoss,
)
Loading