Vbert #339
Changes from 39 commits
```diff
@@ -1,18 +1,22 @@
 from .models import (
     BiPali,
     BiPaliProj,
     BiQwen2,
     BiQwen2_5,
     BiQwen2_5_Processor,
     BiQwen2Processor,
+    BiModernVBert,
+    BiModernVBertProcessor,
     ColIdefics3,
     ColIdefics3Processor,
     ColPali,
     ColPaliProcessor,
     ColQwen2,
     ColQwen2_5,
     ColQwen2_5_Processor,
-    ColQwen2_5Omni,
-    ColQwen2_5OmniProcessor,
+    # ColQwen2_5Omni,
+    # ColQwen2_5OmniProcessor,
     ColQwen2Processor,
+    ColModernVBert,
+    ColModernVBertProcessor,
 )
```
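As a usage sketch (not part of this PR): the diff above exports the new ModernVBert classes and comments out the ColQwen2.5-Omni ones. Loading the new classes would look roughly like this; the checkpoint id is a hypothetical assumption.

```python
# Hedged sketch: loading the ModernVBert retriever classes exported above.
# "ModernVBERT/colmodernvbert" is a hypothetical checkpoint id, not taken from this PR.
from colpali_engine.models import ColModernVBert, ColModernVBertProcessor

model = ColModernVBert.from_pretrained("ModernVBERT/colmodernvbert")
processor = ColModernVBertProcessor.from_pretrained("ModernVBERT/colmodernvbert")
```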
@@ -0,0 +1,142 @@ (new file)

```python
import random
import torch
from typing import Any, Dict, List, Union

from PIL.Image import Image

from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.models.paligemma import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """
    Prefix all keys in a dictionary with the given prefix.
    """
    return {f"{prefix}{k}": v for k, v in data.items()}


class VisualRetrieverCollator:
    """
    Collator for training vision retrieval models.
    """

    # Prefixes
    query_prefix = "query_"
    pos_doc_prefix = "doc_"
    neg_doc_prefix = "neg_doc_"

    def __init__(
        self,
        processor: BaseVisualRetrieverProcessor,
        max_length: int = 2048,
    ):
        self.processor = processor
        self.max_length = max_length
        self.image_token_id = None

        # If the processor is one of the supported types, extract the <image> token id.
        if isinstance(self.processor, (ColPaliProcessor,)):
            image_token = "<image>"
            try:
                idx = self.processor.tokenizer.additional_special_tokens.index(image_token)
                self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[idx]
            except ValueError:
                self.image_token_id = None

        # Force padding to be on the right for ColPaliProcessor.
        if isinstance(self.processor, ColPaliProcessor) and self.processor.tokenizer.padding_side != "right":
            print("Setting padding side to right")
            self.processor.tokenizer.padding_side = "right"

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        queries: List[Union[None, str, Image]] = []
        pos_targets: List[Union[str, Image]] = []
        neg_targets: List[Union[str, Image]] = []
        selected_ids: List[int] = []

        # Parse the examples. Unfilled slots in the id tensor are marked with -1.
        positive_ids_tensor = -torch.ones((len(examples), 100), dtype=torch.long)
        for i, example in enumerate(examples):
            assert ColPaliEngineDataset.QUERY_KEY in example, f"Missing {ColPaliEngineDataset.QUERY_KEY} in example."
            query = example[ColPaliEngineDataset.QUERY_KEY]
            sampled_query = random.choice(query) if isinstance(query, list) else query
            queries.append(sampled_query)

            assert ColPaliEngineDataset.POS_TARGET_KEY in example, (
                f"Missing {ColPaliEngineDataset.POS_TARGET_KEY} in example."
            )
            pos_tgt = example[ColPaliEngineDataset.POS_TARGET_KEY]
            positive_ids = example.get("positive_ids", None)
            if isinstance(pos_tgt, list):
                # Sample one positive target and remember its id.
                # Assumes `positive_ids` is provided whenever `pos_tgt` is a list.
                sample_pos, sampled_id = random.choice(list(zip(pos_tgt, positive_ids)))
                selected_ids.append(sampled_id)
            else:
                sample_pos = pos_tgt
            pos_targets.append(sample_pos)
            if positive_ids is not None:
                positive_ids_tensor[i, : len(positive_ids)] = torch.tensor(positive_ids)

            neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None)
            if neg_tgt is not None:
                # Negatives are kept as-is (possibly a list per example);
                # per-example sampling is left to `auto_collate`.
                neg_targets.append(neg_tgt)

        # Ensure all queries are strings.
        assert all(isinstance(q, str) for q in queries), (
            "All queries must be strings, this collator does not support images in queries."
        )

        # Process queries.
        queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
        batch_query = self.auto_collate(queries, key_prefix=self.query_prefix)

        # Process targets.
        batch_pos_target = self.auto_collate(pos_targets, key_prefix=self.pos_doc_prefix)
        batch_neg_target = self.auto_collate(neg_targets, key_prefix=self.neg_doc_prefix) if neg_targets else {}

        return {
            **batch_query,
            **batch_pos_target,
            **batch_neg_target,
            "selected_ids": torch.Tensor(selected_ids),
            "positive_ids_tensor": positive_ids_tensor,
        }

    def auto_collate(
        self, batch: List[Union[str, Image, List[str], List[Image]]], key_prefix: str = ""
    ) -> Dict[str, Any]:
        """Automatically collate a batch of documents."""
        # If types are mixed across the batch, raise an error.
        all_types = set(type(item) for item in batch)
        if str in all_types and Image in all_types:
            raise ValueError(f"Batch contains mixed types: {all_types}. Expected all items to be of the same type.")
        if isinstance(batch[0], str):
            proc_batch = self.processor.process_texts(texts=batch)
        elif isinstance(batch[0], Image):
            proc_batch = self.processor.process_images(images=batch)
        elif isinstance(batch[0], list):
            # Flatten the per-example lists, process them in a single call, then
            # reshape tensors back to (batch_size, num_negatives, ...).
            # Assumes every example carries the same number of items.
            batch_size = len(batch)
            if isinstance(batch[0][0], str):
                all_texts = [text for texts in batch for text in texts]
                num_negatives = len(all_texts) // batch_size
                proc_batch = self.processor.process_texts(texts=all_texts)
            elif isinstance(batch[0][0], Image):
                all_imgs = [img for imgs in batch for img in imgs]
                num_negatives = len(all_imgs) // batch_size
                proc_batch = self.processor.process_images(images=all_imgs)
            else:
                raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.")
            for k, v in proc_batch.items():
                if isinstance(v, torch.Tensor):
                    proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:])
        else:
            raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.")

        return prefix_keys(proc_batch, key_prefix)
```
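Not part of the diff, but as a hedged sketch of how this collator is typically wired into a PyTorch DataLoader; `processor` and `train_dataset` are assumed to exist and are not defined in this PR:

```python
# Hypothetical wiring sketch (assumes `processor` and `train_dataset` exist).
from torch.utils.data import DataLoader

collator = VisualRetrieverCollator(processor=processor, max_length=2048)
loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
# Keys come back prefixed by the collator, e.g. "query_input_ids",
# "doc_pixel_values", "neg_doc_input_ids", plus "selected_ids" and
# "positive_ids_tensor".
```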
```diff
@@ -1,11 +1,12 @@
 import random
 from typing import Any, Dict, List, Union
+import torch
 
 from PIL.Image import Image
 
 from colpali_engine.data.dataset import ColPaliEngineDataset
 from colpali_engine.models.paligemma import ColPaliProcessor
 from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
 
 
 def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
```

```diff
@@ -69,16 +70,18 @@
             neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None)
             if neg_tgt is not None:
-                sampled_neg = random.choice(neg_tgt) if isinstance(neg_tgt, list) else neg_tgt
-                neg_targets.append(sampled_neg)
+                neg_targets.append(neg_tgt)
 
         # Ensure all queries are strings or images.
         assert all(isinstance(q, str) for q in queries), (
             "All queries must be strings, this collator does not support images in queries."
         )
 
+        is_str = isinstance(queries[0], str)
+
         # Process queries.
-        queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
+        # queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
+        queries = [q + self.processor.query_augmentation_token * 10 for q in queries] if is_str else queries
         batch_query = self.auto_collate(queries, key_prefix=self.query_prefix)
 
         # Process targets.
```

Review comment (on the commented-out query line): remove commented lines if not useful
Reply: actually useful, in modernvbert

Review comment (on the query augmentation line): put 10 into a constant (e.g. …)
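A possible shape for the reviewer's suggestion, sketched under the assumption that the truncated name is something like the constant below (hypothetical, not from the PR):

```python
# Hypothetical refactor sketch: hoist the magic number into a module-level
# constant. The name is an assumption; the reviewer's suggested name is cut off above.
NUM_QUERY_AUGMENTATION_TOKENS = 10

def augment_queries(queries, augmentation_token: str, is_str: bool):
    """Append the augmentation token a fixed number of times to each query."""
    if not is_str:
        return queries
    return [q + augmentation_token * NUM_QUERY_AUGMENTATION_TOKENS for q in queries]
```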
```diff
@@ -102,6 +105,26 @@
             proc_batch = self.processor.process_texts(texts=batch)
         elif isinstance(batch[0], Image):
             proc_batch = self.processor.process_images(images=batch)
+        elif isinstance(batch[0], list):
+            if isinstance(batch[0][0], str):
+                proc_texts_batch = []
+                batch_size = len(batch)
+                all_texts = [text for texts in batch for text in texts]
+                num_negatives = len(all_texts) // batch_size
+                proc_batch = self.processor.process_texts(texts=all_texts)
+            elif isinstance(batch[0][0], Image):
+                proc_imgs_batch = []
+                batch_size = len(batch)
+                all_imgs = [img for imgs in batch for img in imgs]
+                num_negatives = len(all_imgs) // batch_size
+                proc_batch = self.processor.process_images(images=all_imgs)
+            else:
+                raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.")
+            for k, v in proc_batch.items():
+                if isinstance(v, torch.Tensor):
+                    proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:])
+                else:
+                    proc_batch[k] = v
         else:
             raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.")
         return prefix_keys(proc_batch, key_prefix)
```

Review comment (on lines +126 to +127): unnecessary
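To make the new list branch concrete, here is a small self-contained illustration of its flatten-process-reshape pattern; the tensor values are invented for the example:

```python
# Toy illustration of the flatten/reshape pattern in `auto_collate`:
# per-example lists are flattened, processed once, and the resulting
# tensors are viewed back to (batch_size, num_negatives, ...).
import torch

batch_size, num_negatives, seq_len = 2, 3, 5

# Stand-in for a processed flat batch, e.g. tokenized input_ids.
flat = torch.arange(batch_size * num_negatives * seq_len).view(-1, seq_len)

grouped = flat.view(batch_size, num_negatives, *flat.shape[1:])
assert grouped.shape == (batch_size, num_negatives, seq_len)
# Note: this assumes every example contributes the same number of items;
# otherwise the integer division and `view` would silently misalign rows.
```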
Review comment: Add a comment to the README if ColQwen 2.5 Omni is not supported anymore.