@@ -62,3 +62,22 @@ def _coerce_none_to_empty(cls, values):
if isinstance(values, dict) and values.get("api_key") is None:
values["api_key"] = ""
return values

@model_validator(mode="after")
def _reject_non_text_modalities(self):
"""Reject image/text_image modalities — only text modality is supported."""
_NON_TEXT = frozenset({"image", "text_image", "image_text"})
for field_name in (
"text_elements_modality",
"image_elements_modality",
"structured_elements_modality",
"audio_elements_modality",
):
value = getattr(self, field_name, "text")
if value in _NON_TEXT:
raise ValueError(
f"{field_name}={value!r} is not supported. "
f"Only 'text' modality is supported for embedding. "
f"Image and multimodal embedding support has been removed."
)
return self
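
For reference, a minimal sketch of how a pydantic v2 after-validator of this shape behaves; the schema below is a stand-in, not the real config class:

from pydantic import BaseModel, ValidationError, model_validator

class _DemoSchema(BaseModel):
    # Stand-in schema: only one modality field is modeled here.
    text_elements_modality: str = "text"

    @model_validator(mode="after")
    def _reject_non_text_modalities(self):
        if self.text_elements_modality in {"image", "text_image", "image_text"}:
            raise ValueError(f"text_elements_modality={self.text_elements_modality!r} is not supported.")
        return self

_DemoSchema(text_elements_modality="text")  # constructs fine
try:
    _DemoSchema(text_elements_modality="image")
except ValidationError as exc:
    print(exc)  # pydantic wraps the ValueError raised by the validator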
10 changes: 10 additions & 0 deletions api/src/nv_ingest_api/internal/transform/embed_text.py
@@ -800,6 +800,16 @@ def _has_page_images(df):
ContentTypeEnum.VIDEO: lambda x: None, # Not supported yet.
}

# Reject non-text modalities — image/multimodal embedding support has been removed.
_NON_TEXT_MODALITIES = frozenset({"image", "text_image", "image_text"})
for _ct, _mod in task_type_to_modality.items():
if isinstance(_mod, str) and _mod in _NON_TEXT_MODALITIES:
raise ValueError(
f"Modality {_mod!r} for {_ct.value} elements is not supported. "
f"Only 'text' modality is supported for embedding. "
f"Image and multimodal embedding support has been removed."
)

# Determine which content types to embed
# When aggregating page content, automatically skip TEXT and STRUCTURED unless explicitly set
def _get_embed_flag(content_type: ContentTypeEnum) -> bool:
4 changes: 3 additions & 1 deletion retriever/pyproject.toml
@@ -46,7 +46,9 @@ dependencies = [
"python-multipart>=0.0.9",
# NOTE: `llama_nemotron_embed_1b_v2` currently expects `HybridCache` from
# `transformers.cache_utils`, which is not present in transformers>=5.
"transformers>=4.49.0,<5.0.0",
# Versions 4.54.0-4.55.x have a flash attention bug that produces incorrect
# image embeddings in the VL model; exclude that range.
"transformers>=4.49.0,<5.0.0,!=4.54.*,!=4.55.*",
"tokenizers>=0.20.3,<0.22.0",
"torch~=2.9.1",
"torchvision>=0.24,<0.25",
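
As a side note, the excluded range can also be asserted at runtime; a hedged sketch, where the specifier string mirrors the pyproject pin (the guard itself is illustrative, not part of this PR):

import transformers
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Mirrors the pyproject constraint; 4.54.x-4.55.x carry the flash attention bug.
_SUPPORTED = SpecifierSet(">=4.49.0,<5.0.0,!=4.54.*,!=4.55.*")

if Version(transformers.__version__) not in _SUPPORTED:
    raise RuntimeError(
        f"transformers=={transformers.__version__} is outside the supported range {_SUPPORTED}"
    )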
21 changes: 16 additions & 5 deletions retriever/src/retriever/examples/batch_pipeline.py
@@ -489,6 +489,11 @@ def main(
"--embed-invoke-url",
help="Optional remote endpoint URL for embedding model inference.",
),
embed_model_name: str = typer.Option(
"nvidia/llama-3.2-nv-embedqa-1b-v2",
"--embed-model-name",
help="Embedding model name passed to .embed().",
),
runtime_metrics_dir: Optional[Path] = typer.Option(
None,
"--runtime-metrics-dir",
@@ -565,7 +570,7 @@ def main(
ingestor = (
ingestor.files(glob_pattern)
.extract_txt(TextChunkParams(max_tokens=512, overlap_tokens=0))
.embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
.embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
.vdb_upload(
VdbUploadParams(
lancedb={
@@ -586,7 +591,7 @@
ingestor = (
ingestor.files(glob_pattern)
.extract_html(TextChunkParams(max_tokens=512, overlap_tokens=0))
.embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
.embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
.vdb_upload(
VdbUploadParams(
lancedb={
@@ -635,7 +640,7 @@ def main(
)
.embed(
EmbedParams(
model_name="nemo_retriever_v1",
model_name=str(embed_model_name),
embed_invoke_url=embed_invoke_url,
batch_tuning={
"embed_workers": int(embed_workers),
@@ -691,7 +696,7 @@ def main(
)
.embed(
EmbedParams(
model_name="nemo_retriever_v1",
model_name=str(embed_model_name),
embed_invoke_url=embed_invoke_url,
batch_tuning={
"embed_workers": int(embed_workers),
@@ -767,10 +772,16 @@ def main(
unique_basenames = table.to_pandas()["pdf_basename"].unique()
print(f"Unique basenames: {unique_basenames}")

# Resolve the HF model ID for recall query embedding so aliases
# (e.g. "nemo_retriever_v1") map to the correct model.
from retriever.model import resolve_embed_model

_recall_model = resolve_embed_model(str(embed_model_name))

cfg = RecallConfig(
lancedb_uri=str(lancedb_uri),
lancedb_table=str(LANCEDB_TABLE),
embedding_model="nvidia/llama-3.2-nv-embedqa-1b-v2",
embedding_model=_recall_model,
top_k=10,
ks=(1, 5, 10),
)
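
For orientation, the new flag in isolation as a self-contained typer sketch (the function body is a stand-in for the pipeline wiring shown above):

import typer

app = typer.Typer()

@app.command()
def main(
    embed_model_name: str = typer.Option(
        "nvidia/llama-3.2-nv-embedqa-1b-v2",
        "--embed-model-name",
        help="Embedding model name passed to .embed().",
    ),
) -> None:
    # Stand-in body: the same value drives .embed(...) during ingestion and,
    # after resolve_embed_model(), the recall query embedder.
    print(f"embedding with: {embed_model_name}")

if __name__ == "__main__":
    app()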
21 changes: 16 additions & 5 deletions retriever/src/retriever/examples/inprocess_pipeline.py
@@ -150,6 +150,11 @@ def main(
"--embed-invoke-url",
help="Optional remote endpoint URL for embedding model inference.",
),
embed_model_name: str = typer.Option(
"nvidia/llama-3.2-nv-embedqa-1b-v2",
"--embed-model-name",
help="Embedding model name passed to .embed().",
),
) -> None:
_ = input_type

@@ -169,7 +174,7 @@ def main(
ingestor = (
ingestor.files(glob_pattern)
.extract_txt(TextChunkParams(max_tokens=512, overlap_tokens=0))
.embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
.embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
.vdb_upload(
VdbUploadParams(
lancedb={
@@ -187,7 +192,7 @@ def main(
ingestor = (
ingestor.files(glob_pattern)
.extract_html(TextChunkParams(max_tokens=512, overlap_tokens=0))
.embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
.embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
.vdb_upload(
VdbUploadParams(
lancedb={
@@ -216,7 +221,7 @@ def main(
ocr_invoke_url=ocr_invoke_url,
)
)
.embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
.embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
.vdb_upload(
VdbUploadParams(
lancedb={
@@ -244,7 +249,7 @@ def main(
ocr_invoke_url=ocr_invoke_url,
)
)
.embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
.embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
.vdb_upload(
VdbUploadParams(
lancedb={
@@ -285,10 +290,16 @@ def main(
unique_basenames = table.to_pandas()["pdf_basename"].unique()
print(f"Unique basenames: {unique_basenames}")

# Resolve the HF model ID for recall query embedding so aliases
# (e.g. "nemo_retriever_v1") map to the correct model.
from retriever.model import resolve_embed_model

_recall_model = resolve_embed_model(str(embed_model_name))

cfg = RecallConfig(
lancedb_uri=str(LANCEDB_URI),
lancedb_table=str(LANCEDB_TABLE),
embedding_model="nvidia/llama-3.2-nv-embedqa-1b-v2",
embedding_model=_recall_model,
embedding_http_endpoint=embed_invoke_url,
top_k=10,
ks=(1, 5, 10),
37 changes: 23 additions & 14 deletions retriever/src/retriever/ingest_modes/batch.py
@@ -252,25 +252,34 @@ def __init__(self, params: EmbedParams) -> None:
self._model = None
return

from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder

device = self._kwargs.get("device")
hf_cache_dir = self._kwargs.get("hf_cache_dir")
normalize = bool(self._kwargs.get("normalize", True))
max_length = int(self._kwargs.get("max_length", 8192))
# model_name may be a NIM alias (e.g. "nemo_retriever_v1") or a real HF
# repo ID (e.g. "nvidia/llama-3.2-nv-embedqa-1b-v2"). Only forward it as
# model_id when it looks like an HF repo (contains "/").
model_name_raw = self._kwargs.get("model_name")
model_id = model_name_raw if (isinstance(model_name_raw, str) and "/" in model_name_raw) else None

self._model = LlamaNemotronEmbed1BV2Embedder(
device=str(device) if device else None,
hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None,
normalize=normalize,
max_length=max_length,
model_id=model_id,
)

from retriever.model import is_vl_embed_model, resolve_embed_model

model_id = resolve_embed_model(model_name_raw)

if is_vl_embed_model(model_name_raw):
from retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import LlamaNemotronEmbedVL1BV2Embedder

self._model = LlamaNemotronEmbedVL1BV2Embedder(
device=str(device) if device else None,
hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None,
model_id=model_id,
)
else:
from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder

self._model = LlamaNemotronEmbed1BV2Embedder(
device=str(device) if device else None,
hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None,
normalize=normalize,
max_length=max_length,
model_id=model_id,
)

def __call__(self, batch_df: Any) -> Any:
from retriever.ingest_modes.inprocess import embed_text_main_text_embed
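
The branch above reduces to a dispatch on the alias helpers added in retriever/src/retriever/model/__init__.py (shown later in this diff); a quick sketch of which embedder class each name selects:

from retriever.model import is_vl_embed_model, resolve_embed_model

for name in ("nemo_retriever_v1", "llama-nemotron-embed-vl-1b-v2"):
    cls = (
        "LlamaNemotronEmbedVL1BV2Embedder"
        if is_vl_embed_model(name)
        else "LlamaNemotronEmbed1BV2Embedder"
    )
    print(f"{name} -> {resolve_embed_model(name)} -> {cls}")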
52 changes: 32 additions & 20 deletions retriever/src/retriever/ingest_modes/inprocess.py
@@ -27,8 +27,7 @@


import pandas as pd
from retriever.chart.chart_detection import detect_graphic_elements_v1_from_page_elements_v3 # noqa: F401
from retriever.model.local import NemotronGraphicElementsV1, NemotronOCRV1, NemotronPageElementsV3 # noqa: F401
from retriever.model.local import NemotronOCRV1, NemotronPageElementsV3
from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder
from retriever.page_elements import detect_page_elements_v3
from retriever.ocr.ocr import ocr_page_elements
@@ -209,20 +208,23 @@ def embed_text_main_text_embed(
if _endpoint is None and model is None:
raise ValueError("Either a local model or an embedding_endpoint must be provided.")

# Map NIM aliases to the actual model ID expected by the remote endpoint.
_NIM_MODEL_ALIASES = {
"nemo_retriever_v1": "nvidia/llama-3.2-nv-embedqa-1b-v2",
}
_resolved_model_name = _NIM_MODEL_ALIASES.get(model_name, model_name) if model_name else model_name
# Resolve NIM aliases to the actual HF model ID.
from retriever.model import resolve_embed_model

_resolved_model_name = resolve_embed_model(model_name)

# Build an embedder callable compatible with `create_text_embeddings_for_df`.
# Only used when running with a local model (no NIM endpoint).
_embed = None
if _endpoint is None and model is not None:

# VL model handles formatting internally via encode_documents();
# embedqa needs an explicit "passage: " prefix.
_skip_prefix = hasattr(model, "embed_queries")

def _embed(texts: Sequence[str]) -> Sequence[Sequence[float]]: # noqa: F811
prefixed = [f"passage: {t}" for t in texts]
vecs = model.embed(prefixed, batch_size=int(inference_batch_size))
batch = texts if _skip_prefix else [f"passage: {t}" for t in texts]
vecs = model.embed(batch, batch_size=int(inference_batch_size))
tolist = getattr(vecs, "tolist", None)
if callable(tolist):
return tolist()
@@ -1196,20 +1198,30 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI
normalize = bool(embed_kwargs.pop("normalize", True))
max_length = int(embed_kwargs.pop("max_length", 8192))

# model_name may be a NIM alias (e.g. "nemo_retriever_v1") or a real HF
# repo ID (e.g. "nvidia/llama-3.2-nv-embedqa-1b-v2"). Only forward it as
# model_id when it looks like an HF repo (contains "/").
model_name_raw = embed_kwargs.pop("model_name", None)
model_id = model_name_raw if (isinstance(model_name_raw, str) and "/" in model_name_raw) else None

from retriever.model import is_vl_embed_model, resolve_embed_model

model_id = resolve_embed_model(model_name_raw)

embed_kwargs.setdefault("input_type", "passage")
embed_kwargs["model"] = LlamaNemotronEmbed1BV2Embedder(
device=str(device) if device is not None else None,
hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
normalize=normalize,
max_length=max_length,
model_id=model_id,
)

if is_vl_embed_model(model_name_raw):
from retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import LlamaNemotronEmbedVL1BV2Embedder

embed_kwargs["model"] = LlamaNemotronEmbedVL1BV2Embedder(
device=str(device) if device is not None else None,
hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
model_id=model_id,
)
else:
embed_kwargs["model"] = LlamaNemotronEmbed1BV2Embedder(
device=str(device) if device is not None else None,
hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
normalize=normalize,
max_length=max_length,
model_id=model_id,
)
self._tasks.append((embed_text_main_text_embed, embed_kwargs))
return self

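
A minimal sketch of the duck-typing convention assumed by _skip_prefix above: the VL embedder exposes embed_queries() and formats documents itself, while the text embedder expects an explicit "passage: " prefix (the classes below are stand-ins):

class _TextEmbedder:
    def embed(self, texts, batch_size=8):
        return [[0.0] for _ in texts]  # placeholder vectors

class _VLEmbedder(_TextEmbedder):
    def embed_queries(self, texts, batch_size=8):  # marker method
        return self.embed(texts, batch_size=batch_size)

def _prepare(model, texts):
    skip_prefix = hasattr(model, "embed_queries")
    return list(texts) if skip_prefix else [f"passage: {t}" for t in texts]

assert _prepare(_VLEmbedder(), ["a"]) == ["a"]
assert _prepare(_TextEmbedder(), ["a"]) == ["passage: a"]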
32 changes: 32 additions & 0 deletions retriever/src/retriever/model/__init__.py
@@ -1,3 +1,35 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

_VL_EMBED_MODEL_IDS = frozenset(
{
"nvidia/llama-nemotron-embed-vl-1b-v2",
"llama-nemotron-embed-vl-1b-v2",
}
)

# Short name → full HF repo ID.
_EMBED_MODEL_ALIASES: dict[str, str] = {
"nemo_retriever_v1": "nvidia/llama-3.2-nv-embedqa-1b-v2",
"llama-nemotron-embed-vl-1b-v2": "nvidia/llama-nemotron-embed-vl-1b-v2",
}

_DEFAULT_EMBED_MODEL = "nvidia/llama-3.2-nv-embedqa-1b-v2"


def resolve_embed_model(model_name: str | None) -> str:
"""Resolve a model name/alias to a full HF repo ID.

Returns ``_DEFAULT_EMBED_MODEL`` when *model_name* is ``None`` or empty.
"""
if not model_name:
return _DEFAULT_EMBED_MODEL
return _EMBED_MODEL_ALIASES.get(model_name, model_name)


def is_vl_embed_model(model_name: str | None) -> bool:
"""Return True if *model_name* refers to the VL embedding model."""
return resolve_embed_model(model_name) in _VL_EMBED_MODEL_IDS
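
Expected behavior of the two helpers, as a quick usage sketch (values follow directly from the tables above):

from retriever.model import is_vl_embed_model, resolve_embed_model

assert resolve_embed_model(None) == "nvidia/llama-3.2-nv-embedqa-1b-v2"        # default
assert resolve_embed_model("nemo_retriever_v1") == "nvidia/llama-3.2-nv-embedqa-1b-v2"
assert resolve_embed_model("some/custom-model") == "some/custom-model"         # pass-through
assert is_vl_embed_model("llama-nemotron-embed-vl-1b-v2")
assert not is_vl_embed_model("nemo_retriever_v1")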
@@ -82,8 +82,23 @@ def _embed_local(self, texts: List[str], *, batch_size: int) -> torch.Tensor:
max_length=max(1, int(self.max_length)),
return_tensors="pt",
).to(dev)
out = self._model(**batch)
lhs = out.last_hidden_state # [B, S, D]
out = self._model(**batch, output_hidden_states=True)
# The bidirectional model returns BaseModelOutputWithPast
# (last_hidden_state), but some transformers versions or
# model revisions return CausalLMOutputWithPast (hidden_states).
lhs = getattr(out, "last_hidden_state", None)
if lhs is None:
# CausalLMOutputWithPast: use the last layer's hidden state.
hs = getattr(out, "hidden_states", None)
if hs is not None:
lhs = hs[-1]
else:
raise AttributeError(
f"Model output ({type(out).__name__}) has neither "
"'last_hidden_state' nor 'hidden_states'. "
"Ensure the model is loaded with trust_remote_code=True."
)
# lhs shape: [B, S, D]
mask = batch["attention_mask"].unsqueeze(-1) # [B, S, 1]
vec = (lhs * mask).sum(dim=1) / mask.sum(dim=1) # [B, D]
vec = vec.detach().to("cpu")
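
For clarity, a self-contained sketch of the masked mean pooling this hunk performs, using dummy tensors (the hidden-state fallback above only changes where lhs comes from):

import torch

# Dummy stand-ins: B=2 sequences, S=4 tokens, D=8 dims.
lhs = torch.randn(2, 4, 8)                       # [B, S, D]
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])    # [B, S]; 0 = padding

mask = attention_mask.unsqueeze(-1)              # [B, S, 1]
vec = (lhs * mask).sum(dim=1) / mask.sum(dim=1)  # mean over real tokens -> [B, D]
print(vec.shape)                                 # torch.Size([2, 8])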