Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@

import torch
from packaging.specifiers import SpecifierSet
from torch.nn.functional import normalize
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import __version__ as transformers_version

from mteb._requires_package import requires_package
from mteb.models.abs_encoder import AbsEncoder
from mteb.models.model_meta import ModelMeta
from mteb.types import PromptType

if TYPE_CHECKING:
from mteb.abstasks.task_metadata import TaskMetadata
from mteb.types import Array, BatchedInput, PromptType
from mteb.types import Array, BatchedInput

LLAMA_NEMORETRIEVER_CITATION = """@misc{xu2025llamanemoretrievercolembedtopperforming,
title={Llama Nemoretriever Colembed: Top-Performing Text-Image Retrieval Model},
Expand All @@ -36,6 +39,14 @@
url={https://arxiv.org/abs/2602.03992},
}"""

# BibTeX citation for llama-nemotron-embed-vl-1b-v2.
# FIX: BibTeX requires authors to be separated by " and " — a comma in the
# author field makes "Ronay Ak, Gabriel de Souza Pereira Moreira" parse as a
# single (wrong) author name.
NEMOTRON_EMBED_VL_1B_V2_CITATION = """
@misc{ronay2026smallyetmighty,
title={Small Yet Mighty: Improve Accuracy In Multimodal Search and Visual Document Retrieval with Llama Nemotron RAG Models},
author={Ronay Ak and Gabriel de Souza Pereira Moreira and Bo Liu},
year={2026},
howpublished = {Available at: https://huggingface.co/blog/nvidia/llama-nemotron-vl-1b},
}"""

# Transformers version constraints per extra.
# Keep in sync with pyproject.toml [project.optional-dependencies]
#
Expand All @@ -45,6 +56,7 @@
# Maps each optional-dependency "extra" name to the transformers version
# specifier it requires; consulted at model construction time so an
# incompatible environment fails fast with an install hint.
_TRANSFORMERS_CONSTRAINTS: dict[str, str] = {
    "llama-nemotron-colembed-vl": "==4.49.0", # llama-nemoretriever-colembed-*
    "nemotron-colembed-vl-v2": "==5.0.0", # nemotron-colembed-vl-4b-v2, nemotron-colembed-vl-8b-v2
    "llama-nemotron-embed-vl-1b-v2": ">=4.56.0", # llama-nemotron-embed-vl-1b-v2
}


Expand Down Expand Up @@ -343,3 +355,141 @@ def encode(
citation=NEMOTRON_COLEMBED_CITATION_V2,
model_type=["late-interaction"],
)


class LlamaNemotronEmbedVL(AbsEncoder):
    """Wrapper for NVIDIA Llama Nemotron Embed VL dense embedding models.

    Loads the model with ``transformers.AutoModel`` (custom remote code) and
    embeds queries via the model's ``encode_queries`` and documents via
    ``encode_documents`` (text, images, or both). Outputs are L2-normalized.
    """

    def __init__(
        self,
        model_name_or_path: str,
        revision: str,
        trust_remote_code: bool,
        extra_name: str = "llama-nemotron-embed-vl-1b-v2",
        device_map="cuda",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        **kwargs,
    ):
        """Load the model after validating the installed dependency versions.

        Args:
            model_name_or_path: HF hub id or local path of the model.
            revision: Model revision (commit hash / tag) to pin.
            trust_remote_code: Must be enabled for this model's custom code.
            extra_name: Key into ``_TRANSFORMERS_CONSTRAINTS`` naming the mteb
                optional-dependency extra that pins compatible versions.
            device_map: Passed to ``from_pretrained`` (default "cuda").
            torch_dtype: Model weight dtype (default bfloat16).
            attn_implementation: Attention backend (default flash-attention-2).
            **kwargs: Ignored; accepted for loader-interface compatibility.

        Raises:
            ValueError: If ``extra_name`` is not a known extra.
            RuntimeError: If the installed transformers version does not
                satisfy the pinned constraint.
        """
        install_hint = f"pip install 'mteb[{extra_name}]'"

        # Fail fast on an incompatible transformers install, pointing the user
        # at the extra that pins the right version.
        constraint = _TRANSFORMERS_CONSTRAINTS.get(extra_name)
        if constraint is None:
            raise ValueError(
                f"Unknown extra_name '{extra_name}'. "
                f"Must be one of: {list(_TRANSFORMERS_CONSTRAINTS.keys())}"
            )
        if transformers_version not in SpecifierSet(constraint):
            raise RuntimeError(
                f"Model `{model_name_or_path}` requires transformers{constraint}, "
                f"but {transformers_version} is installed. "
                f"Run: {install_hint}"
            )

        # Check required packages beyond transformers itself.
        for package in ("torchvision", "accelerate", "flash_attn"):
            requires_package(self, package, model_name_or_path, install_hint)

        from transformers import AutoModel

        self.model = AutoModel.from_pretrained(
            model_name_or_path,
            revision=revision,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            attn_implementation=attn_implementation,
        ).eval()

        # Sets the number of tiles the image can be split into
        self.model.processor.max_input_tiles = 4

    def encode(
        self,
        inputs: DataLoader[BatchedInput],
        *,
        task_metadata: TaskMetadata,
        hf_split: str,
        # FIX: was `gf_subset` — a typo. Framework callers pass `hf_subset=`,
        # which `**kwargs` silently absorbed while the misspelled required
        # keyword stayed missing, raising TypeError on every call.
        hf_subset: str,
        prompt_type: PromptType | None = None,
        show_progress_bar: bool = True,
        **kwargs: Any,
    ) -> Array:
        """Embed all batches in ``inputs`` and stack the results.

        Queries (``prompt_type == PromptType.query``) are encoded from their
        text; anything else is encoded as documents from image and/or text
        fields of the batch.

        Returns:
            A single tensor of L2-normalized embeddings, one row per input.

        Raises:
            ValueError: If a document batch has neither "image" nor "text",
                or if a batch produces non-finite / all-zero embeddings.
        """
        with torch.inference_mode():
            embeddings_list = []
            for batch in tqdm(
                inputs,
                desc=f"Extracting {prompt_type} embeddings...",
                disable=not show_progress_bar,
            ):
                if prompt_type == PromptType.query:
                    embeddings = self.model.encode_queries(batch["text"])
                else:
                    if "image" in batch and "text" in batch:
                        embeddings = self.model.encode_documents(
                            images=batch["image"], texts=batch["text"]
                        )
                    elif "image" in batch:
                        embeddings = self.model.encode_documents(images=batch["image"])
                    elif "text" in batch:
                        embeddings = self.model.encode_documents(texts=batch["text"])
                    else:
                        raise ValueError(
                            f"Could not find 'image' or 'text' in batch: {batch}"
                        )

                embeddings = normalize(embeddings, dim=-1)
                # Sanity-check the batch output. The previous
                # `assert sum not in [0.0, inf]` missed NaN (NaN compares
                # unequal to everything) and -inf, and asserts vanish under
                # `python -O`; raise explicitly instead.
                total = torch.sum(embeddings).float()
                if not torch.isfinite(total) or total.item() == 0.0:
                    raise ValueError(
                        f"Degenerate embeddings (sum={total.item()}) in batch"
                    )
                embeddings_list.append(embeddings)

            concatenated_embeddings = torch.vstack(embeddings_list)
            return concatenated_embeddings


# Datasets the llama-nemotron-embed-vl-1b-v2 model was trained on, passed to
# ModelMeta.training_datasets below.
# NOTE(review): some entries ("Cauldron (AI2D, OCRVQA, Websight)",
# "Tiger Math/Stack") are free-form names rather than MTEB task identifiers —
# confirm downstream consumers tolerate non-task strings.
TRAINING_DATA_EMBED_VL_1B_V2 = {
    "VidoreDocVQARetrieval",
    "VidoreInfoVQARetrieval",
    "VidoreTatdqaRetrieval",
    "VidoreArxivQARetrieval",
    "docmatix-ir",
    "wiki-ss-nq",
    "Cauldron (AI2D, OCRVQA, Websight)",
    "VDRMultilingualRetrieval",
    "HotpotQA",
    "MIRACLRetrieval",
    "NQ",
    "StackExchangeClustering",
    "SQuAD",
    "MultiLongDocRetrieval",
    "MLQARetrieval",
    "Tiger Math/Stack",
}

# Registration metadata for nvidia/llama-nemotron-embed-vl-1b-v2:
# a ~1.68B-parameter text+image dense encoder producing 2048-dim embeddings.
llama_nemotron_embed_vl_1b_v2 = ModelMeta(
    loader=LlamaNemotronEmbedVL,
    loader_kwargs=dict(
        trust_remote_code=True,  # model ships custom code in its HF repo
    ),
    name="nvidia/llama-nemotron-embed-vl-1b-v2",
    languages=["eng-Latn"],
    revision="859e1f2dac29c56c37a5279cf55f53f3e74efc6b",  # pinned model commit
    release_date="2026-01-06",
    modalities=["image", "text"],
    n_parameters=1_678_252_480,
    n_embedding_parameters=262_688_768,
    memory_usage_mb=6402,
    max_tokens=10240,
    embed_dim=2048,
    license="https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/",
    open_weights=True,
    public_training_code=None,
    public_training_data="https://huggingface.co/nvidia/llama-nemotron-embed-vl-1b-v2#training-dataset",
    framework=["PyTorch"],
    reference="https://huggingface.co/nvidia/llama-nemotron-embed-vl-1b-v2",
    similarity_fn_name="cosine",
    use_instructions=True,
    training_datasets=TRAINING_DATA_EMBED_VL_1B_V2,
    citation=NEMOTRON_EMBED_VL_1B_V2_CITATION,
)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ youtu = ["tencentcloud-sdk-python-common>=3.0.1454", "tencentcloud-sdk-python-lk
llama-embed-nemotron = ["transformers==4.51.0"]
llama-nemotron-colembed-vl = ["transformers[torch]==4.49.0", "torchvision>=0.22.0", "flash-attn>=2.6.3", "accelerate"]
nemotron-colembed-vl-v2 = ["transformers[torch]==5.0.0", "torchvision>=0.22.0", "flash-attn>=2.6.3", "accelerate"]
llama-nemotron-embed-vl-1b-v2 = ["transformers[torch]>=4.56.0", "torchvision>=0.22.0", "flash-attn>=2.6.3", "accelerate"]
faiss-cpu = ["faiss-cpu>=1.12.0"]
eager_embed = ["qwen_vl_utils>=0.0.14"]
speechbrain = ["speechbrain>=0.5.12"]
Expand Down Expand Up @@ -376,6 +377,7 @@ conflicts = [
{ extra = "llama-embed-nemotron" }, # conflicting versions of transformers
{ extra = "llama-nemotron-colembed-vl" }, # conflicting versions of transformers
{ extra = "nemotron-colembed-vl-v2" }, # conflicting versions of transformers
{ extra = "llama-nemotron-embed-vl-1b-v2" }, # conflicting versions of transformers
{ extra = "colpali_engine" },
{ extra = "colqwen3" },
{ extra = "jina-v4" },
Expand Down
Loading