diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py
index 0f0c8118..c6e25e20 100644
--- a/colpali_engine/models/__init__.py
+++ b/colpali_engine/models/__init__.py
@@ -1,4 +1,5 @@
from .idefics3 import BiIdefics3, BiIdefics3Processor, ColIdefics3, ColIdefics3Processor
+from .internvl3_5 import ColInternVL3_5, ColInternVL3_5_Processor
from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
diff --git a/colpali_engine/models/internvl3_5/__init__.py b/colpali_engine/models/internvl3_5/__init__.py
new file mode 100644
index 00000000..b4fe9fc7
--- /dev/null
+++ b/colpali_engine/models/internvl3_5/__init__.py
@@ -0,0 +1 @@
+from .colinternvl3_5 import ColInternVL3_5, ColInternVL3_5_Processor
diff --git a/colpali_engine/models/internvl3_5/colinternvl3_5/__init__.py b/colpali_engine/models/internvl3_5/colinternvl3_5/__init__.py
new file mode 100644
index 00000000..320a7337
--- /dev/null
+++ b/colpali_engine/models/internvl3_5/colinternvl3_5/__init__.py
@@ -0,0 +1,2 @@
+from .modeling_colinternvl3_5 import ColInternVL3_5
+from .processing_colinternvl3_5 import ColInternVL3_5_Processor
diff --git a/colpali_engine/models/internvl3_5/colinternvl3_5/modeling_colinternvl3_5.py b/colpali_engine/models/internvl3_5/colinternvl3_5/modeling_colinternvl3_5.py
new file mode 100644
index 00000000..1778c5f9
--- /dev/null
+++ b/colpali_engine/models/internvl3_5/colinternvl3_5/modeling_colinternvl3_5.py
@@ -0,0 +1,67 @@
+from typing import ClassVar
+
+import torch
+from torch import nn
+from transformers.models.internvl import InternVLConfig, InternVLModel
+
+
+class ColInternVL3_5(InternVLModel):
+ """
+    ColInternVL3_5 model implementation from the "ColPali: Efficient Document Retrieval
+    with Vision Language Models" paper.
+
+ Args:
+ config (InternVLConfig): The model configuration.
+        mask_non_image_embeddings (bool): Whether to ignore all token embeddings
+            except those of the image at inference. Defaults to False, i.e. no
+            embeddings are masked during the forward pass.
+ """
+
+ main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
+
+ def __init__(self, config: InternVLConfig, mask_non_image_embeddings: bool = False):
+ super().__init__(config=config)
+ self.dim = 128
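+        # Project the language-model hidden states down to the low-dimensional
+        # multi-vector embedding space used for late interaction.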
+ self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.dim)
+ self.padding_side = "left"
+ self.mask_non_image_embeddings = mask_non_image_embeddings
+ self.post_init()
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
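+        # InternVLModel defines a private `_checkpoint_conversion_mapping` that renames
+        # legacy checkpoint keys; forward it so loading applies the same renames here.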
+ key_mapping = kwargs.pop("key_mapping", None)
+ if key_mapping is None:
+ key_mapping = super()._checkpoint_conversion_mapping
+ return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+ def forward(self, *args, **kwargs) -> torch.Tensor:
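+        """
+        Forward pass returning L2-normalized multi-vector embeddings of shape
+        (batch_size, sequence_length, dim).
+        """
+        # Drop any caller-supplied values so the explicit kwargs passed to
+        # super().forward() below do not collide.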
+ kwargs.pop("return_dict", True)
+ kwargs.pop("output_hidden_states", None)
+ kwargs.pop("use_cache", None)
+ hidden_states = (
+ super()
+ .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
+ .last_hidden_state
+ ) # (batch_size, sequence_length, hidden_size)
+
+ proj = self.custom_text_proj(hidden_states) # (batch_size, sequence_length, dim)
+
+ # L2 normalization
+ proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim)
+
+ if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Keep only the image-token embeddings; zero out all other positions
+ image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
+ proj = proj * image_mask
+ return proj
+
+    @property
+    def patch_size(self) -> int:
+        # InternVLModel exposes its vision encoder as `vision_tower` (there is no
+        # `visual` attribute); the patch size lives on the vision sub-config.
+        return self.config.vision_config.patch_size
+
+    @property
+    def spatial_merge_size(self) -> int:
+        # InternVL merges vision patches via pixel shuffle controlled by
+        # `downsample_ratio` (e.g. 0.5), so the per-side merge factor is its inverse.
+        return int(1 / self.config.downsample_ratio)
diff --git a/colpali_engine/models/internvl3_5/colinternvl3_5/processing_colinternvl3_5.py b/colpali_engine/models/internvl3_5/colinternvl3_5/processing_colinternvl3_5.py
new file mode 100644
index 00000000..968dd2bf
--- /dev/null
+++ b/colpali_engine/models/internvl3_5/colinternvl3_5/processing_colinternvl3_5.py
@@ -0,0 +1,136 @@
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchEncoding, BatchFeature
+from transformers.models.internvl import InternVLProcessor
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+class ColInternVL3_5_Processor(BaseVisualRetrieverProcessor, InternVLProcessor):
+ """
+ Processor for ColInternVL3_5.
+
+ Args:
+ *args: Variable length argument list to be passed to the parent `InternVLProcessor` class.
+ max_num_visual_tokens: The maximum number of visual tokens that can be processed by the model.
+ **kwargs: Arbitrary keyword arguments to be passed to the parent `InternVLProcessor` class.
+ """
+
+    visual_prompt_prefix: ClassVar[str] = (
+        "<|im_start|>user\n<IMG_CONTEXT>\nDescribe the image.<|im_end|><|endoftext|>"
+    )
+    query_augmentation_token: ClassVar[str] = "<|endoftext|>"
+    image_token: ClassVar[str] = "<IMG_CONTEXT>"
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
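+        # Left padding keeps real tokens at the end of every sequence in the batch,
+        # matching the model's `padding_side = "left"` setting in ColInternVL3_5.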
+ self.tokenizer.padding_side = "left"
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        *args,
+        device_map: Optional[str] = None,
+        **kwargs,
+    ):
+        # Pop `max_num_visual_tokens` so it is not forwarded to the underlying
+        # InternVLProcessor, which does not accept it.
+        max_num_visual_tokens = kwargs.pop("max_num_visual_tokens", None)
+
+        instance = super().from_pretrained(
+            *args,
+            device_map=device_map,
+            **kwargs,
+        )
+
+        if max_num_visual_tokens is not None and hasattr(instance.image_processor, "max_pixels"):
+            # Pixel-budget mapping borrowed from the Qwen-style image processors
+            # (one visual token per 28x28-pixel patch); guarded with hasattr because
+            # InternVL's tiling image processor may not expose these attributes.
+            instance.image_processor.max_pixels = max_num_visual_tokens * 28 * 28
+            instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
+
+        return instance
+
+ def process_images(
+ self,
+ images: List[Image.Image],
+ ) -> Union[BatchFeature, BatchEncoding]:
+ """
+ Process images for ColInternVL3_5.
+
+ Args:
+ images: List of PIL images.
+ """
+
+ images = [image.convert("RGB") for image in images]
+
+ batch_doc = self(
+ text=[self.visual_prompt_prefix] * len(images),
+ images=images,
+ padding="longest",
+ return_tensors="pt",
+ )
+
+ return batch_doc
+
+ def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+ """
+ Process texts for ColInternVL3_5.
+
+ Args:
+ texts: List of input texts.
+
+ Returns:
+ Union[BatchFeature, BatchEncoding]: Processed texts.
+ """
+ return self(
+ text=texts,
+ return_tensors="pt",
+ padding="longest",
+ )
+
+ def score(
+ self,
+ qs: List[torch.Tensor],
+ ps: List[torch.Tensor],
+ device: Optional[Union[str, torch.device]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+ """
+ return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+ def get_n_patches(
+ self,
+ image_size: Tuple[int, int],
+ spatial_merge_size: int,
+ ) -> Tuple[int, int]:
+ """
+ Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
+ size (height, width) with the given patch size.
+
+        The `spatial_merge_size` is the number of patches that are merged spatially. It is
+        exposed as a `ColInternVL3_5` attribute under `model.spatial_merge_size`.
+ """
+ patch_size = self.image_processor.patch_size
+
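+        # Compute the resized dimensions with Qwen2-VL's `smart_resize`. This assumes the
+        # image processor exposes Qwen-style `merge_size` and `size` pixel bounds, an
+        # assumption carried over from the ColQwen2 processor rather than InternVL's
+        # native tiling logic.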
+ height_new, width_new = smart_resize(
+            height=image_size[0],
+            width=image_size[1],
+ factor=patch_size * self.image_processor.merge_size,
+ min_pixels=self.image_processor.size["shortest_edge"],
+ max_pixels=self.image_processor.size["longest_edge"],
+ )
+
+ n_patches_x = width_new // patch_size // spatial_merge_size
+ n_patches_y = height_new // patch_size // spatial_merge_size
+
+ return n_patches_x, n_patches_y
+
+ def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
+ """
+ Get a tensor mask that identifies the image tokens in the batch.
+ """
+ return batch_images.input_ids == self.image_token_id
diff --git a/scripts/configs/accelerate_configs/single_node_config.yml b/scripts/configs/accelerate_configs/single_node_config.yml
new file mode 100644
index 00000000..9e8d85f8
--- /dev/null
+++ b/scripts/configs/accelerate_configs/single_node_config.yml
@@ -0,0 +1,19 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: all
+machine_rank: 0
+main_process_ip: ''
+main_process_port: 29500
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
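+# one process per GPU on a single node (matches --gres=gpu:4 in the slurm scripts)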
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/scripts/configs/internvl/train_colinternvl3_5_1b_model.py b/scripts/configs/internvl/train_colinternvl3_5_1b_model.py
new file mode 100644
index 00000000..cc221a53
--- /dev/null
+++ b/scripts/configs/internvl/train_colinternvl3_5_1b_model.py
@@ -0,0 +1,93 @@
+import argparse
+import shutil
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import TrainingArguments
+
+from colpali_engine.data.dataset import ColPaliEngineDataset
+from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss
+from colpali_engine.models import ColInternVL3_5, ColInternVL3_5_Processor
+from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
+from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
+from colpali_engine.utils.dataset_transformation import load_train_set
+
+
+def parse_args():
+ p = argparse.ArgumentParser()
+ p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function")
+ p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
+ p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
+ return p.parse_args()
+
+
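+# Typical launch (see slurms/colinternvl3_5/train_colinternvl3_5_1b.slurm):
+#   accelerate launch --config_file scripts/configs/accelerate_configs/single_node_config.yml \
+#     scripts/configs/internvl/train_colinternvl3_5_1b_model.py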
+if __name__ == "__main__":
+ args = parse_args()
+
+    # Wire the --loss flag: "ce" uses the in-batch cross-entropy ColbertLoss,
+    # "pairwise" the pairwise cross-entropy variant.
+    if args.loss == "ce":
+        loss_func = ColbertLoss(
+            temperature=args.tau,
+            normalize_scores=True,
+            use_smooth_max=False,
+            pos_aware_negative_filtering=False,
+        )
+    else:
+        loss_func = ColbertPairwiseCELoss()
+
+ config = ColModelTrainingConfig(
+ output_dir="models/colinternvl3_5_1b",
+ processor=ColInternVL3_5_Processor.from_pretrained(
+ pretrained_model_name_or_path="./models/base_models/colinternvl3_5-1b-base",
+ max_num_visual_tokens=768,
+ ),
+ model=ColInternVL3_5.from_pretrained(
+ pretrained_model_name_or_path="./models/base_models/colinternvl3_5-1b-base",
+ dtype=torch.bfloat16,
+ # low_cpu_mem_usage=True,
+ # use_cache=False,
+ trust_remote_code=True,
+ attn_implementation="flash_attention_2",
+ ),
+ train_dataset=load_train_set(),
+ eval_dataset=ColPaliEngineDataset(
+ load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
+ ),
+ run_eval=True,
+ loss_func=loss_func,
+ tr_args=TrainingArguments(
+ output_dir=None,
+ overwrite_output_dir=True,
+ num_train_epochs=5,
+ per_device_train_batch_size=32,
+ gradient_checkpointing=True,
+ gradient_checkpointing_kwargs={"use_reentrant": False},
+ per_device_eval_batch_size=16,
+ eval_strategy="steps",
+ dataloader_num_workers=8,
+ save_steps=500,
+ logging_steps=10,
+ eval_steps=100,
+ warmup_steps=100,
+ learning_rate=1e-05,
+ save_total_limit=1,
+ ddp_find_unused_parameters=False,
+ # run_name="visual-colinternvl3_5-1b-test",
+ report_to=None,
+ ),
+ peft_config=LoraConfig(
+ r=32,
+ lora_alpha=32,
+ lora_dropout=0.1,
+ init_lora_weights="gaussian",
+ bias="none",
+ task_type="FEATURE_EXTRACTION",
+ target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
+ )
+ )
+
+ # make sure output_dir exists and copy script for provenance
+ Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+ shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name)
+
+ trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
+ trainer.train()
+ trainer.save()
diff --git a/scripts/configs/internvl/train_colinternvl3_5_2b_model.py b/scripts/configs/internvl/train_colinternvl3_5_2b_model.py
new file mode 100644
index 00000000..8a556ada
--- /dev/null
+++ b/scripts/configs/internvl/train_colinternvl3_5_2b_model.py
@@ -0,0 +1,93 @@
+import argparse
+import shutil
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import TrainingArguments
+
+from colpali_engine.data.dataset import ColPaliEngineDataset
+from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss
+from colpali_engine.models import ColInternVL3_5, ColInternVL3_5_Processor
+from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
+from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
+from colpali_engine.utils.dataset_transformation import load_train_set
+
+
+def parse_args():
+ p = argparse.ArgumentParser()
+ p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function")
+ p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
+ p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
+ return p.parse_args()
+
+
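+# Typical launch (see slurms/colinternvl3_5/train_colinternvl3_5_2b.slurm):
+#   accelerate launch --config_file scripts/configs/accelerate_configs/single_node_config.yml \
+#     scripts/configs/internvl/train_colinternvl3_5_2b_model.py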
+if __name__ == "__main__":
+ args = parse_args()
+
+    # Wire the --loss flag: "ce" uses the in-batch cross-entropy ColbertLoss,
+    # "pairwise" the pairwise cross-entropy variant.
+    if args.loss == "ce":
+        loss_func = ColbertLoss(
+            temperature=args.tau,
+            normalize_scores=True,
+            use_smooth_max=False,
+            pos_aware_negative_filtering=False,
+        )
+    else:
+        loss_func = ColbertPairwiseCELoss()
+
+ config = ColModelTrainingConfig(
+ output_dir="models/colinternvl3_5_2b",
+ processor=ColInternVL3_5_Processor.from_pretrained(
+ pretrained_model_name_or_path="./models/base_models/colinternvl3_5-2b-base",
+ max_num_visual_tokens=768,
+ ),
+ model=ColInternVL3_5.from_pretrained(
+ pretrained_model_name_or_path="./models/base_models/colinternvl3_5-2b-base",
+ dtype=torch.bfloat16,
+ # low_cpu_mem_usage=True,
+ # use_cache=False,
+ trust_remote_code=True,
+ attn_implementation="flash_attention_2",
+ ),
+ train_dataset=load_train_set(),
+ eval_dataset=ColPaliEngineDataset(
+ load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
+ ),
+ run_eval=True,
+ loss_func=loss_func,
+ tr_args=TrainingArguments(
+ output_dir=None,
+ overwrite_output_dir=True,
+ num_train_epochs=5,
+ per_device_train_batch_size=32,
+ gradient_checkpointing=True,
+ gradient_checkpointing_kwargs={"use_reentrant": False},
+ per_device_eval_batch_size=16,
+ eval_strategy="steps",
+ dataloader_num_workers=8,
+ save_steps=500,
+ logging_steps=10,
+ eval_steps=100,
+ warmup_steps=100,
+ learning_rate=2e-05,
+ save_total_limit=1,
+ ddp_find_unused_parameters=False,
+ # run_name="visual-colinternvl3_5-2b-test",
+ report_to=None,
+ ),
+ peft_config=LoraConfig(
+ r=32,
+ lora_alpha=32,
+ lora_dropout=0.1,
+ init_lora_weights="gaussian",
+ bias="none",
+ task_type="FEATURE_EXTRACTION",
+ target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
+ )
+ )
+
+ # make sure output_dir exists and copy script for provenance
+ Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+ shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name)
+
+ trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
+ trainer.train()
+ trainer.save()
diff --git a/slurms/colinternvl3_5/train_colinternvl3_5_1b.slurm b/slurms/colinternvl3_5/train_colinternvl3_5_1b.slurm
new file mode 100644
index 00000000..e5412a58
--- /dev/null
+++ b/slurms/colinternvl3_5/train_colinternvl3_5_1b.slurm
@@ -0,0 +1,40 @@
+#!/bin/bash
+#SBATCH --output=slurm_logs/%x_%j.log
+#SBATCH --error=slurm_logs/%x_%j.log
+#SBATCH -A qjm@h100
+#SBATCH -C h100
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:4
+#SBATCH --time=20:00:00
+#SBATCH --job-name=train-colinternvl3_5-1b
+# #SBATCH --cpu-bind=none
+#SBATCH --mem-bind="local"
+#SBATCH --cpus-per-task=64
+#SBATCH --hint=nomultithread
+
+export OMP_NUM_THREADS=64
+export MKL_NUM_THREADS=64
+export NUMEXPR_NUM_THREADS=64
+
+module purge
+module load arch/h100 cuda/12.8.0
+
+export HF_DATASETS_CACHE=$SCRATCH/datasets
+export HF_HOME=$SCRATCH/.cache/huggingface
+export UV_CACHE_DIR=$SCRATCH
+export HF_HUB_ENABLE_HF_TRANSFER=1
+export WANDB_PROJECT="colinternvl3_5_1b"
+export WANDB_NAME="colinternvl3_5_1b"
+export WANDB_MODE=offline
+export HF_DATASETS_OFFLINE=1
+export HF_DATASETS_IN_MEMORY_MAX_SIZE=0
+export HF_HUB_OFFLINE=1
+
+source .venv/bin/activate
+wandb offline
+echo "launching training script"
+accelerate launch --config_file scripts/configs/accelerate_configs/single_node_config.yml scripts/configs/internvl/train_colinternvl3_5_1b_model.py
+cd ../../mteb
+source .venv/bin/activate
+python -m experiments.colinternvl.eval_colinternvl --model_name models/colinternvl3_5_1b
diff --git a/slurms/colinternvl3_5/train_colinternvl3_5_2b.slurm b/slurms/colinternvl3_5/train_colinternvl3_5_2b.slurm
new file mode 100644
index 00000000..a53d0a88
--- /dev/null
+++ b/slurms/colinternvl3_5/train_colinternvl3_5_2b.slurm
@@ -0,0 +1,40 @@
+#!/bin/bash
+#SBATCH --output=slurm_logs/%x_%j.log
+#SBATCH --error=slurm_logs/%x_%j.log
+#SBATCH -A qjm@h100
+#SBATCH -C h100
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:4
+#SBATCH --time=20:00:00
+#SBATCH --job-name=train-colinternvl3_5-2b
+# #SBATCH --cpu-bind=none
+#SBATCH --mem-bind="local"
+#SBATCH --cpus-per-task=64
+#SBATCH --hint=nomultithread
+
+export OMP_NUM_THREADS=64
+export MKL_NUM_THREADS=64
+export NUMEXPR_NUM_THREADS=64
+
+module purge
+module load arch/h100 cuda/12.8.0
+
+export HF_DATASETS_CACHE=$SCRATCH/datasets
+export HF_HOME=$SCRATCH/.cache/huggingface
+export UV_CACHE_DIR=$SCRATCH
+export HF_HUB_ENABLE_HF_TRANSFER=1
+export WANDB_PROJECT="train-colinternvl3_5-2b"
+export WANDB_NAME="train-colinternvl3_5-2b"
+export WANDB_MODE=offline
+export HF_DATASETS_OFFLINE=1
+export HF_DATASETS_IN_MEMORY_MAX_SIZE=0
+export HF_HUB_OFFLINE=1
+
+source .venv/bin/activate
+wandb offline
+echo "launching training script"
+accelerate launch --config_file scripts/configs/accelerate_configs/single_node_config.yml scripts/configs/internvl/train_colinternvl3_5_2b_model.py
+cd ../../mteb
+source .venv/bin/activate
+python -m experiments.colinternvl.eval_colinternvl --model_name models/colinternvl3_5_2b