1 change: 1 addition & 0 deletions colpali_engine/models/__init__.py
@@ -1,4 +1,5 @@
from .idefics3 import BiIdefics3, BiIdefics3Processor, ColIdefics3, ColIdefics3Processor
from .internvl3_5 import ColInternVL3_5, ColInternVL3_5_Processor
from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
1 change: 1 addition & 0 deletions colpali_engine/models/internvl3_5/__init__.py
@@ -0,0 +1 @@
from .colinternvl3_5 import ColInternVL3_5, ColInternVL3_5_Processor
2 changes: 2 additions & 0 deletions colpali_engine/models/internvl3_5/colinternvl3_5/__init__.py
@@ -0,0 +1,2 @@
from .modeling_colinternvl3_5 import ColInternVL3_5
from .processing_colinternvl3_5 import ColInternVL3_5_Processor
67 changes: 67 additions & 0 deletions colpali_engine/models/internvl3_5/colinternvl3_5/modeling_colinternvl3_5.py
@@ -0,0 +1,67 @@
from typing import ClassVar

import torch
from torch import nn
from transformers.models.internvl import InternVLConfig, InternVLModel

Check failure on line 5 (GitHub Actions / ruff): I001 Import block is un-sorted or un-formatted (modeling_colinternvl3_5.py:1:1)

class ColInternVL3_5(InternVLModel):

Check failure on line 10 (GitHub Actions / ruff): N801 Class name `ColInternVL3_5` should use CapWords convention (modeling_colinternvl3_5.py:10:7)
"""
    ColInternVL3_5 model implementation from the "ColPali: Efficient Document Retrieval with
    Vision Language Models" paper.

Check failure on line 12 (GitHub Actions / ruff): E501 Line too long (123 > 120) (modeling_colinternvl3_5.py:12:121)

Args:
config (InternVLConfig): The model configuration.
        mask_non_image_embeddings (Optional[bool]): Whether to ignore all token embeddings
            except those of the image at inference. Defaults to False (no embeddings are
            masked during the forward pass).
"""

main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related

def __init__(self, config: InternVLConfig, mask_non_image_embeddings: bool = False):
super().__init__(config=config)
self.dim = 128
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.dim)
self.padding_side = "left"
self.mask_non_image_embeddings = mask_non_image_embeddings
self.post_init()

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Fall back to the base model's checkpoint key mapping unless the caller overrides it.
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = super()._checkpoint_conversion_mapping
        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

def forward(self, *args, **kwargs) -> torch.Tensor:
kwargs.pop("return_dict", True)
kwargs.pop("output_hidden_states", None)
kwargs.pop("use_cache", None)
hidden_states = (
super()
.forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
.last_hidden_state
) # (batch_size, sequence_length, hidden_size)

proj = self.custom_text_proj(hidden_states) # (batch_size, sequence_length, dim)

# L2 normalization
proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim)

if "pixel_values" in kwargs and self.mask_non_image_embeddings:
# Pools only the image embeddings
image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
proj = proj * image_mask
return proj

    @property
    def patch_size(self) -> int:
        return self.config.vision_config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        # InternVLModel exposes no `visual` attribute, so the original
        # `self.visual.config.spatial_merge_size` lookup could not resolve. InternVL merges
        # patches via pixel unshuffle; the merge factor is assumed here to be the inverse of
        # the config's `downsample_ratio`.
        return int(1 / self.config.downsample_ratio)
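
Note for reviewers: a minimal end-to-end sketch of how the new class is meant to be used, following the usage pattern of the other ColPali-family models in this repo. The checkpoint path is hypothetical and the blank image stands in for a document page:

```python
import torch
from PIL import Image

from colpali_engine.models import ColInternVL3_5, ColInternVL3_5_Processor

# Hypothetical path to a trained checkpoint.
model_name = "./models/colinternvl3_5_1b"
model = ColInternVL3_5.from_pretrained(model_name).eval()
processor = ColInternVL3_5_Processor.from_pretrained(model_name)

images = [Image.new("RGB", (448, 448), "white")]  # stand-in for document pages
queries = ["What is shown in the document?"]

with torch.no_grad():
    image_embeddings = model(**processor.process_images(images))
    query_embeddings = model(**processor.process_texts(queries))

# Multi-vector MaxSim scores, shape (n_queries, n_images).
scores = processor.score(list(query_embeddings), list(image_embeddings))
```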
136 changes: 136 additions & 0 deletions colpali_engine/models/internvl3_5/colinternvl3_5/processing_colinternvl3_5.py
@@ -0,0 +1,136 @@
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature
from transformers.models.internvl import InternVLProcessor
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


class ColInternVL3_5_Processor(BaseVisualRetrieverProcessor, InternVLProcessor):

Check failure on line 12 (GitHub Actions / ruff): N801 Class name `ColInternVL3_5_Processor` should use CapWords convention (processing_colinternvl3_5.py:12:7)
"""
Processor for ColInternVL3_5.

Args:
*args: Variable length argument list to be passed to the parent `InternVLProcessor` class.
max_num_visual_tokens: The maximum number of visual tokens that can be processed by the model.
**kwargs: Arbitrary keyword arguments to be passed to the parent `InternVLProcessor` class.
"""

visual_prompt_prefix: ClassVar[str] = (
"<|im_start|>user\n<img><IMG_CONTEXT></img>\nDescribe the image.<|im_end|><|endoftext|>"
)
query_augmentation_token: ClassVar[str] = "<|endoftext|>"
image_token: ClassVar[str] = "<IMG_CONTEXT>"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.tokenizer.padding_side = "left"

    @classmethod
    def from_pretrained(
        cls,
        *args,
        max_num_visual_tokens: Optional[int] = None,
        **kwargs,
    ):
        instance = super().from_pretrained(*args, **kwargs)

        if max_num_visual_tokens is not None:
            # Cap the resolution so an image yields at most `max_num_visual_tokens`
            # visual tokens (28 * 28 pixels per merged patch, matching the
            # Qwen2-VL-style smart_resize used in get_n_patches below).
            instance.image_processor.max_pixels = max_num_visual_tokens * 28 * 28
            instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels

        return instance

def process_images(
self,
images: List[Image.Image],
) -> Union[BatchFeature, BatchEncoding]:
"""
Process images for ColInternVL3_5.

Args:
images: List of PIL images.
"""

images = [image.convert("RGB") for image in images]

batch_doc = self(
text=[self.visual_prompt_prefix] * len(images),
images=images,
padding="longest",
return_tensors="pt",
)

return batch_doc

def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
"""
Process texts for ColInternVL3_5.

Args:
texts: List of input texts.

Returns:
Union[BatchFeature, BatchEncoding]: Processed texts.
"""
return self(
text=texts,
return_tensors="pt",
padding="longest",
)

def score(
self,
qs: List[torch.Tensor],
ps: List[torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
**kwargs,
) -> torch.Tensor:
"""
Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
"""
return self.score_multi_vector(qs, ps, device=device, **kwargs)

def get_n_patches(
self,
image_size: Tuple[int, int],
spatial_merge_size: int,
) -> Tuple[int, int]:
"""
Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
size (height, width) with the given patch size.

The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored in
as a `InternVL3_5VLForConditionalGeneration` attribute under `model.spatial_merge_size`.
"""
patch_size = self.image_processor.patch_size

        # image_size is (height, width), per the docstring above.
        height_new, width_new = smart_resize(
            height=image_size[0],
            width=image_size[1],
            factor=patch_size * self.image_processor.merge_size,
            min_pixels=self.image_processor.size["shortest_edge"],
            max_pixels=self.image_processor.size["longest_edge"],
        )

n_patches_x = width_new // patch_size // spatial_merge_size
n_patches_y = height_new // patch_size // spatial_merge_size

return n_patches_x, n_patches_y

def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
"""
Get a tensor mask that identifies the image tokens in the batch.
"""
return batch_images.input_ids == self.image_token_id
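
Note: `score` delegates to `score_multi_vector` from `BaseVisualRetrieverProcessor`. For reference, here is a minimal sketch of the underlying ColBERT-style MaxSim between a single query and a single passage; the library version is batched and padded, so this is illustrative only:

```python
import torch


def maxsim(query_tokens: torch.Tensor, passage_tokens: torch.Tensor) -> torch.Tensor:
    """MaxSim between one query (n_q, dim) and one passage (n_p, dim) embedding matrix."""
    similarities = query_tokens @ passage_tokens.T  # (n_q, n_p) token-level dot products
    # For each query token keep its best-matching passage token, then sum over query tokens.
    return similarities.max(dim=1).values.sum()


q = torch.randn(16, 128)   # e.g. 16 query tokens, dim 128 (the model's `self.dim`)
p = torch.randn(700, 128)  # e.g. 700 document-patch tokens
print(maxsim(q, p))
```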
19 changes: 19 additions & 0 deletions scripts/configs/accelerate_configs/single_node_config.yml
@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_process_ip: ''
main_process_port: 29500
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
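
Usage note: assuming the paths in this PR, the training script below can be launched on the four configured GPUs with `accelerate launch --config_file scripts/configs/accelerate_configs/single_node_config.yml scripts/configs/internvl/train_colinternvl3_5_1b_model.py` (illustrative command; adjust `num_processes` to the available hardware).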
93 changes: 93 additions & 0 deletions scripts/configs/internvl/train_colinternvl3_5_1b_model.py
@@ -0,0 +1,93 @@
import argparse
import shutil
from pathlib import Path

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments

from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss

Check failure on line 11 (GitHub Actions / ruff): F401 `colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss` imported but unused (train_colinternvl3_5_1b_model.py:11:70)
from colpali_engine.models import ColInternVL3_5, ColInternVL3_5_Processor
from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
from colpali_engine.utils.dataset_transformation import load_train_set


def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function")
p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
return p.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Build the loss selected via --loss (the pairwise variant was previously
    # imported but never used, hence the F401 failure above).
    if args.loss == "ce":
        loss_func = ColbertLoss(
            temperature=args.tau,
            normalize_scores=True,
            use_smooth_max=False,
            pos_aware_negative_filtering=False,
        )
    else:
        loss_func = ColbertPairwiseCELoss()

config = ColModelTrainingConfig(
output_dir="models/colinternvl3_5_1b",
processor=ColInternVL3_5_Processor.from_pretrained(
pretrained_model_name_or_path="./models/base_models/colinternvl3_5-1b-base",
max_num_visual_tokens=768,
),
model=ColInternVL3_5.from_pretrained(
pretrained_model_name_or_path="./models/base_models/colinternvl3_5-1b-base",
dtype=torch.bfloat16,
# low_cpu_mem_usage=True,
# use_cache=False,
trust_remote_code=True,
attn_implementation="flash_attention_2",
),
train_dataset=load_train_set(),
eval_dataset=ColPaliEngineDataset(
load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
),
run_eval=True,
loss_func=loss_func,
tr_args=TrainingArguments(
output_dir=None,
overwrite_output_dir=True,
num_train_epochs=5,
per_device_train_batch_size=32,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
per_device_eval_batch_size=16,
eval_strategy="steps",
dataloader_num_workers=8,
save_steps=500,
logging_steps=10,
eval_steps=100,
warmup_steps=100,
learning_rate=1e-05,
save_total_limit=1,
ddp_find_unused_parameters=False,
# run_name="visual-colinternvl3_5-1b-test",
report_to=None,
),
peft_config=LoraConfig(
r=32,
lora_alpha=32,
lora_dropout=0.1,
init_lora_weights="gaussian",
bias="none",
task_type="FEATURE_EXTRACTION",
target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
)
)

# make sure output_dir exists and copy script for provenance
Path(config.output_dir).mkdir(parents=True, exist_ok=True)
shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name)

trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
trainer.train()
trainer.save()
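
Note on the `target_modules` regex: it targets the language-model attention and MLP projections while excluding the vision tower, plus `custom_text_proj`. PEFT matches a string `target_modules` with `re.fullmatch`, so the pattern can be sanity-checked offline; the module names below are hypothetical stand-ins, since the real names depend on the loaded checkpoint:

```python
import re

# The LoRA target regex from the config above.
pattern = (
    "(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
    "|.*(custom_text_proj).*$)"
)

# Hypothetical module names, for illustration only.
for name in [
    "model.language_model.layers.0.self_attn.q_proj",  # expected: matched
    "model.visual.blocks.0.attn.q_proj",               # expected: excluded by (?!.*visual)
    "custom_text_proj",                                # expected: matched by the second branch
]:
    print(name, "->", bool(re.fullmatch(pattern, name)))
```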