diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py
index 0f0c8118..c6e25e20 100644
--- a/colpali_engine/models/__init__.py
+++ b/colpali_engine/models/__init__.py
@@ -1,4 +1,5 @@
 from .idefics3 import BiIdefics3, BiIdefics3Processor, ColIdefics3, ColIdefics3Processor
+from .internvl3_5 import ColInternVL3_5, ColInternVL3_5_Processor
 from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
 from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
 from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
diff --git a/colpali_engine/models/internvl3_5/__init__.py b/colpali_engine/models/internvl3_5/__init__.py
new file mode 100644
index 00000000..b4fe9fc7
--- /dev/null
+++ b/colpali_engine/models/internvl3_5/__init__.py
@@ -0,0 +1 @@
+from .colinternvl3_5 import ColInternVL3_5, ColInternVL3_5_Processor
diff --git a/colpali_engine/models/internvl3_5/colinternvl3_5/__init__.py b/colpali_engine/models/internvl3_5/colinternvl3_5/__init__.py
new file mode 100644
index 00000000..320a7337
--- /dev/null
+++ b/colpali_engine/models/internvl3_5/colinternvl3_5/__init__.py
@@ -0,0 +1,2 @@
+from .modeling_colinternvl3_5 import ColInternVL3_5
+from .processing_colinternvl3_5 import ColInternVL3_5_Processor
diff --git a/colpali_engine/models/internvl3_5/colinternvl3_5/modeling_colinternvl3_5.py b/colpali_engine/models/internvl3_5/colinternvl3_5/modeling_colinternvl3_5.py
new file mode 100644
index 00000000..1778c5f9
--- /dev/null
+++ b/colpali_engine/models/internvl3_5/colinternvl3_5/modeling_colinternvl3_5.py
@@ -0,0 +1,67 @@
+from typing import ClassVar
+
+import torch
+from torch import nn
+from transformers.models.internvl import InternVLConfig, InternVLModel
+
+
+class ColInternVL3_5(InternVLModel):
+    """
+    ColInternVL3_5 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
+
+    Args:
+        config (InternVLConfig): The model configuration.
+        mask_non_image_embeddings (bool): Whether to ignore all token embeddings except those
+            of the image at inference. Defaults to False, i.e. do not mask any embeddings
+            during the forward pass.
+    """
+
+    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+
+    def __init__(self, config: InternVLConfig, mask_non_image_embeddings: bool = False):
+        super().__init__(config=config)
+        self.dim = 128
+        self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.dim)
+        self.padding_side = "left"
+        self.mask_non_image_embeddings = mask_non_image_embeddings
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        key_mapping = kwargs.pop("key_mapping", None)
+        if key_mapping is None:
+            key_mapping = super()._checkpoint_conversion_mapping
+        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        # Drop any caller-supplied values for the flags that are forced below.
+        kwargs.pop("return_dict", None)
+        kwargs.pop("output_hidden_states", None)
+        kwargs.pop("use_cache", None)
+        hidden_states = (
+            super()
+            .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
+            .last_hidden_state
+        )  # (batch_size, sequence_length, hidden_size)
+
+        proj = self.custom_text_proj(hidden_states)  # (batch_size, sequence_length, dim)
+
+        # L2 normalization
+        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)
+
+        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Keep only the image-token embeddings
+            image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
+            proj = proj * image_mask
+        return proj
+
+    @property
+    def patch_size(self) -> int:
+        # InternVLModel has no `visual` attribute; read the patch size from the vision
+        # config instead (depending on the transformers version this may be an int or
+        # an (h, w) pair).
+        return self.config.vision_config.patch_size
+
+    @property
+    def spatial_merge_size(self) -> int:
+        # InternVL exposes no `spatial_merge_size`; its pixel-shuffle `downsample_ratio`
+        # (0.5 by default) plays the same role, so the merge factor is derived from it.
+        return int(1 / self.config.downsample_ratio)
+ """ + + visual_prompt_prefix: ClassVar[str] = ( + "<|im_start|>user\n\nDescribe the image.<|im_end|><|endoftext|>" + ) + query_augmentation_token: ClassVar[str] = "<|endoftext|>" + image_token: ClassVar[str] = "" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.tokenizer.padding_side = "left" + + # @classmethod + # def from_pretrained( + # cls, + # *args, + # device_map: Optional[str] = None, + # **kwargs, + # ): + # instance = super().from_pretrained( + # *args, + # device_map=device_map, + # **kwargs, + # ) + + # if "max_num_visual_tokens" in kwargs: + # instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * 28 * 28 + # instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels + + # return instance + + def process_images( + self, + images: List[Image.Image], + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process images for ColInternVL3_5. + + Args: + images: List of PIL images. + """ + + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=[self.visual_prompt_prefix] * len(images), + images=images, + padding="longest", + return_tensors="pt", + ) + + return batch_doc + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for ColInternVL3_5. + + Args: + texts: List of input texts. + + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + spatial_merge_size: int, + ) -> Tuple[int, int]: + """ + Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of + size (height, width) with the given patch size. + + The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored in + as a `InternVL3_5VLForConditionalGeneration` attribute under `model.spatial_merge_size`. + """ + patch_size = self.image_processor.patch_size + + height_new, width_new = smart_resize( + width=image_size[0], + height=image_size[1], + factor=patch_size * self.image_processor.merge_size, + min_pixels=self.image_processor.size["shortest_edge"], + max_pixels=self.image_processor.size["longest_edge"], + ) + + n_patches_x = width_new // patch_size // spatial_merge_size + n_patches_y = height_new // patch_size // spatial_merge_size + + return n_patches_x, n_patches_y + + def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: + """ + Get a tensor mask that identifies the image tokens in the batch. 
+ """ + return batch_images.input_ids == self.image_token_id diff --git a/scripts/configs/accelerate_configs/single_node_config.yml b/scripts/configs/accelerate_configs/single_node_config.yml new file mode 100644 index 00000000..9e8d85f8 --- /dev/null +++ b/scripts/configs/accelerate_configs/single_node_config.yml @@ -0,0 +1,19 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: all +machine_rank: 0 +main_process_ip: '' +main_process_port: 29500 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/configs/internvl/train_colinternvl3_5_1b_model.py b/scripts/configs/internvl/train_colinternvl3_5_1b_model.py new file mode 100644 index 00000000..cc221a53 --- /dev/null +++ b/scripts/configs/internvl/train_colinternvl3_5_1b_model.py @@ -0,0 +1,93 @@ +import argparse +import shutil +from pathlib import Path + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import TrainingArguments + +from colpali_engine.data.dataset import ColPaliEngineDataset +from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss +from colpali_engine.models import ColInternVL3_5, ColInternVL3_5_Processor +from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining +from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig +from colpali_engine.utils.dataset_transformation import load_train_set + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function") + p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use") + p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use") + return p.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + loss_func = ColbertLoss( + temperature=args.tau, + normalize_scores=True, + use_smooth_max=False, + pos_aware_negative_filtering=False, + ) + + config = ColModelTrainingConfig( + output_dir="models/colinternvl3_5_1b", + processor=ColInternVL3_5_Processor.from_pretrained( + pretrained_model_name_or_path="./models/base_models/colinternvl3_5-1b-base", + max_num_visual_tokens=768, + ), + model=ColInternVL3_5.from_pretrained( + pretrained_model_name_or_path="./models/base_models/colinternvl3_5-1b-base", + dtype=torch.bfloat16, + # low_cpu_mem_usage=True, + # use_cache=False, + trust_remote_code=True, + attn_implementation="flash_attention_2", + ), + train_dataset=load_train_set(), + eval_dataset=ColPaliEngineDataset( + load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image" + ), + run_eval=True, + loss_func=loss_func, + tr_args=TrainingArguments( + output_dir=None, + overwrite_output_dir=True, + num_train_epochs=5, + per_device_train_batch_size=32, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + per_device_eval_batch_size=16, + eval_strategy="steps", + dataloader_num_workers=8, + save_steps=500, + logging_steps=10, + eval_steps=100, + warmup_steps=100, + learning_rate=1e-05, + save_total_limit=1, + ddp_find_unused_parameters=False, + # run_name="visual-colinternvl3_5-1b-test", + report_to=None, + ), + peft_config=LoraConfig( + r=32, + 
diff --git a/scripts/configs/internvl/train_colinternvl3_5_2b_model.py b/scripts/configs/internvl/train_colinternvl3_5_2b_model.py
new file mode 100644
index 00000000..8a556ada
--- /dev/null
+++ b/scripts/configs/internvl/train_colinternvl3_5_2b_model.py
@@ -0,0 +1,93 @@
+import argparse
+import shutil
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import TrainingArguments
+
+from colpali_engine.data.dataset import ColPaliEngineDataset
+from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss
+from colpali_engine.models import ColInternVL3_5, ColInternVL3_5_Processor
+from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
+from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
+from colpali_engine.utils.dataset_transformation import load_train_set
+
+
+def parse_args():
+    p = argparse.ArgumentParser()
+    p.add_argument("--tau", type=float, default=0.02, help="temperature for the loss function")
+    p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
+    p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
+    return p.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # Honor the --loss flag (ColbertPairwiseCELoss was imported but previously never selected).
+    if args.loss == "ce":
+        loss_func = ColbertLoss(
+            temperature=args.tau,
+            normalize_scores=True,
+            use_smooth_max=False,
+            pos_aware_negative_filtering=False,
+        )
+    else:
+        loss_func = ColbertPairwiseCELoss()
+
+    config = ColModelTrainingConfig(
+        output_dir="models/colinternvl3_5_2b",
+        processor=ColInternVL3_5_Processor.from_pretrained(
+            pretrained_model_name_or_path="./models/base_models/colinternvl3_5-2b-base",
+            max_num_visual_tokens=768,
+        ),
+        model=ColInternVL3_5.from_pretrained(
+            pretrained_model_name_or_path="./models/base_models/colinternvl3_5-2b-base",
+            dtype=torch.bfloat16,
+            # low_cpu_mem_usage=True,
+            # use_cache=False,
+            trust_remote_code=True,
+            attn_implementation="flash_attention_2",
+        ),
+        train_dataset=load_train_set(),
+        eval_dataset=ColPaliEngineDataset(
+            load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
+        ),
+        run_eval=True,
+        loss_func=loss_func,
+        tr_args=TrainingArguments(
+            output_dir=None,
+            overwrite_output_dir=True,
+            num_train_epochs=5,
+            per_device_train_batch_size=32,
+            gradient_checkpointing=True,
+            gradient_checkpointing_kwargs={"use_reentrant": False},
+            per_device_eval_batch_size=16,
+            eval_strategy="steps",
+            dataloader_num_workers=8,
+            save_steps=500,
+            logging_steps=10,
+            eval_steps=100,
+            warmup_steps=100,
+            learning_rate=2e-05,
+            save_total_limit=1,
+            ddp_find_unused_parameters=False,
+            # run_name="visual-colinternvl3_5-2b-test",
+            report_to=None,
+        ),
+        peft_config=LoraConfig(
+            r=32,
+            lora_alpha=32,
+            lora_dropout=0.1,
+            init_lora_weights="gaussian",
+            bias="none",
+            task_type="FEATURE_EXTRACTION",
+            # LoRA on the language-model projections (excluding the vision tower) plus
+            # the custom multi-vector projection head.
+            target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
+        ),
+    )
+
+    # make sure output_dir exists and copy the script for provenance
+    Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+    shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name)
+
+    trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
+    trainer.train()
+    trainer.save()
diff --git a/slurms/colinternvl3_5/train_colinternvl3_5_1b.slurm b/slurms/colinternvl3_5/train_colinternvl3_5_1b.slurm
new file mode 100644
index 00000000..e5412a58
--- /dev/null
+++ b/slurms/colinternvl3_5/train_colinternvl3_5_1b.slurm
@@ -0,0 +1,40 @@
+#!/bin/bash
+#SBATCH --output=slurm_logs/%x_%j.log
+#SBATCH --error=slurm_logs/%x_%j.log
+#SBATCH -A qjm@h100
+#SBATCH -C h100
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:4
+#SBATCH --time=20:00:00
+#SBATCH --job-name=train-colinternvl3_5-1b
+# #SBATCH --cpu-bind=none
+#SBATCH --mem-bind="local"
+#SBATCH --cpus-per-task=64
+#SBATCH --hint=nomultithread
+
+export OMP_NUM_THREADS=64
+export MKL_NUM_THREADS=64
+export NUMEXPR_NUM_THREADS=64
+
+module purge
+module load arch/h100 cuda/12.8.0
+
+export HF_DATASETS_CACHE=$SCRATCH/datasets
+export HF_HOME=$SCRATCH/.cache/huggingface
+export UV_CACHE_DIR=$SCRATCH
+export HF_HUB_ENABLE_HF_TRANSFER=1
+export WANDB_PROJECT="colinternvl3_5_1b"
+export WANDB_NAME="colinternvl3_5_1b"
+export WANDB_MODE=offline
+export HF_DATASETS_OFFLINE=1
+export HF_DATASETS_IN_MEMORY_MAX_SIZE=0
+export HF_HUB_OFFLINE=1
+
+source .venv/bin/activate
+wandb offline
+echo "launching training script"
+accelerate launch --config_file scripts/configs/accelerate_configs/single_node_config.yml scripts/configs/internvl/train_colinternvl3_5_1b_model.py
+cd ../../mteb
+source .venv/bin/activate
+python -m experiments.colinternvl.eval_colinternvl --model_name models/colinternvl3_5_1b
diff --git a/slurms/colinternvl3_5/train_colinternvl3_5_2b.slurm b/slurms/colinternvl3_5/train_colinternvl3_5_2b.slurm
new file mode 100644
index 00000000..a53d0a88
--- /dev/null
+++ b/slurms/colinternvl3_5/train_colinternvl3_5_2b.slurm
@@ -0,0 +1,40 @@
+#!/bin/bash
+#SBATCH --output=slurm_logs/%x_%j.log
+#SBATCH --error=slurm_logs/%x_%j.log
+#SBATCH -A qjm@h100
+#SBATCH -C h100
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:4
+#SBATCH --time=20:00:00
+#SBATCH --job-name=train-colinternvl3_5-2b
+# #SBATCH --cpu-bind=none
+#SBATCH --mem-bind="local"
+#SBATCH --cpus-per-task=64
+#SBATCH --hint=nomultithread
+
+export OMP_NUM_THREADS=64
+export MKL_NUM_THREADS=64
+export NUMEXPR_NUM_THREADS=64
+
+module purge
+module load arch/h100 cuda/12.8.0
+
+export HF_DATASETS_CACHE=$SCRATCH/datasets
+export HF_HOME=$SCRATCH/.cache/huggingface
+export UV_CACHE_DIR=$SCRATCH
+export HF_HUB_ENABLE_HF_TRANSFER=1
+export WANDB_PROJECT="train-colinternvl3_5-2b"
+export WANDB_NAME="train-colinternvl3_5-2b"
+export WANDB_MODE=offline
+export HF_DATASETS_OFFLINE=1
+export HF_DATASETS_IN_MEMORY_MAX_SIZE=0
+export HF_HUB_OFFLINE=1
+
+source .venv/bin/activate
+wandb offline
+echo "launching training script"
+accelerate launch --config_file scripts/configs/accelerate_configs/single_node_config.yml scripts/configs/internvl/train_colinternvl3_5_2b_model.py
+cd ../../mteb
+source .venv/bin/activate
+python -m experiments.colinternvl.eval_colinternvl --model_name models/colinternvl3_5_2b