diff --git a/api/src/nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py b/api/src/nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py
index ce7f35d3e..dc3d791b3 100644
--- a/api/src/nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py
+++ b/api/src/nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py
@@ -62,3 +62,22 @@ def _coerce_none_to_empty(cls, values):
         if isinstance(values, dict) and values.get("api_key") is None:
             values["api_key"] = ""
         return values
+
+    @model_validator(mode="after")
+    def _reject_non_text_modalities(self):
+        """Reject image/text_image modalities — only text modality is supported."""
+        _NON_TEXT = frozenset({"image", "text_image", "image_text"})
+        for field_name in (
+            "text_elements_modality",
+            "image_elements_modality",
+            "structured_elements_modality",
+            "audio_elements_modality",
+        ):
+            value = getattr(self, field_name, "text")
+            if value in _NON_TEXT:
+                raise ValueError(
+                    f"{field_name}={value!r} is not supported. "
+                    f"Only 'text' modality is supported for embedding. "
+                    f"Image and multimodal embedding support has been removed."
+                )
+        return self
diff --git a/api/src/nv_ingest_api/internal/transform/embed_text.py b/api/src/nv_ingest_api/internal/transform/embed_text.py
index 88147e017..aa10985e9 100644
--- a/api/src/nv_ingest_api/internal/transform/embed_text.py
+++ b/api/src/nv_ingest_api/internal/transform/embed_text.py
@@ -800,6 +800,16 @@ def _has_page_images(df):
         ContentTypeEnum.VIDEO: lambda x: None,  # Not supported yet.
     }
 
+    # Reject non-text modalities — image/multimodal embedding support has been removed.
+    _NON_TEXT_MODALITIES = frozenset({"image", "text_image", "image_text"})
+    for _ct, _mod in task_type_to_modality.items():
+        if isinstance(_mod, str) and _mod in _NON_TEXT_MODALITIES:
+            raise ValueError(
+                f"Modality {_mod!r} for {_ct.value} elements is not supported. "
+                f"Only 'text' modality is supported for embedding. "
+                f"Image and multimodal embedding support has been removed."
+            )
+
     # Determine which content types to embed
     # When aggregating page content, automatically skip TEXT and STRUCTURED unless explicitly set
     def _get_embed_flag(content_type: ContentTypeEnum) -> bool:
diff --git a/retriever/pyproject.toml b/retriever/pyproject.toml
index a8e58d4a5..04b49e823 100644
--- a/retriever/pyproject.toml
+++ b/retriever/pyproject.toml
@@ -46,7 +46,9 @@ dependencies = [
     "python-multipart>=0.0.9",
     # NOTE: `llama_nemotron_embed_1b_v2` currently expects `HybridCache` from
    # `transformers.cache_utils`, which is not present in transformers>=5.
-    "transformers>=4.49.0,<5.0.0",
+    # Versions 4.54.0-4.55.x have a flash attention bug that produces incorrect
+    # image embeddings in the VL model; exclude that range.
+ "transformers>=4.49.0,<5.0.0,!=4.54.*,!=4.55.*", "tokenizers>=0.20.3,<0.22.0", "torch~=2.9.1", "torchvision>=0.24,<0.25", diff --git a/retriever/src/retriever/examples/batch_pipeline.py b/retriever/src/retriever/examples/batch_pipeline.py index 7919b1fbb..5a9ad955e 100644 --- a/retriever/src/retriever/examples/batch_pipeline.py +++ b/retriever/src/retriever/examples/batch_pipeline.py @@ -489,6 +489,11 @@ def main( "--embed-invoke-url", help="Optional remote endpoint URL for embedding model inference.", ), + embed_model_name: str = typer.Option( + "nvidia/llama-3.2-nv-embedqa-1b-v2", + "--embed-model-name", + help="Embedding model name passed to .embed().", + ), runtime_metrics_dir: Optional[Path] = typer.Option( None, "--runtime-metrics-dir", @@ -565,7 +570,7 @@ def main( ingestor = ( ingestor.files(glob_pattern) .extract_txt(TextChunkParams(max_tokens=512, overlap_tokens=0)) - .embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url)) + .embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url)) .vdb_upload( VdbUploadParams( lancedb={ @@ -586,7 +591,7 @@ def main( ingestor = ( ingestor.files(glob_pattern) .extract_html(TextChunkParams(max_tokens=512, overlap_tokens=0)) - .embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url)) + .embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url)) .vdb_upload( VdbUploadParams( lancedb={ @@ -635,7 +640,7 @@ def main( ) .embed( EmbedParams( - model_name="nemo_retriever_v1", + model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url, batch_tuning={ "embed_workers": int(embed_workers), @@ -691,7 +696,7 @@ def main( ) .embed( EmbedParams( - model_name="nemo_retriever_v1", + model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url, batch_tuning={ "embed_workers": int(embed_workers), @@ -767,10 +772,16 @@ def main( unique_basenames = table.to_pandas()["pdf_basename"].unique() print(f"Unique basenames: {unique_basenames}") + # Resolve the HF model ID for recall query embedding so aliases + # (e.g. "nemo_retriever_v1") map to the correct model. 
+    from retriever.model import resolve_embed_model
+
+    _recall_model = resolve_embed_model(str(embed_model_name))
+
     cfg = RecallConfig(
         lancedb_uri=str(lancedb_uri),
         lancedb_table=str(LANCEDB_TABLE),
-        embedding_model="nvidia/llama-3.2-nv-embedqa-1b-v2",
+        embedding_model=_recall_model,
         top_k=10,
         ks=(1, 5, 10),
     )
diff --git a/retriever/src/retriever/examples/inprocess_pipeline.py b/retriever/src/retriever/examples/inprocess_pipeline.py
index b1b953d02..9eb61875f 100644
--- a/retriever/src/retriever/examples/inprocess_pipeline.py
+++ b/retriever/src/retriever/examples/inprocess_pipeline.py
@@ -150,6 +150,11 @@ def main(
         "--embed-invoke-url",
         help="Optional remote endpoint URL for embedding model inference.",
     ),
+    embed_model_name: str = typer.Option(
+        "nvidia/llama-3.2-nv-embedqa-1b-v2",
+        "--embed-model-name",
+        help="Embedding model name passed to .embed().",
+    ),
 ) -> None:
     _ = input_type
 
@@ -169,7 +174,7 @@ def main(
     ingestor = (
         ingestor.files(glob_pattern)
         .extract_txt(TextChunkParams(max_tokens=512, overlap_tokens=0))
-        .embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
+        .embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
         .vdb_upload(
             VdbUploadParams(
                 lancedb={
@@ -187,7 +192,7 @@ def main(
     ingestor = (
         ingestor.files(glob_pattern)
         .extract_html(TextChunkParams(max_tokens=512, overlap_tokens=0))
-        .embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
+        .embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
         .vdb_upload(
             VdbUploadParams(
                 lancedb={
@@ -216,7 +221,7 @@ def main(
                 ocr_invoke_url=ocr_invoke_url,
             )
         )
-        .embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
+        .embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
         .vdb_upload(
             VdbUploadParams(
                 lancedb={
@@ -244,7 +249,7 @@ def main(
                 ocr_invoke_url=ocr_invoke_url,
             )
         )
-        .embed(EmbedParams(model_name="nemo_retriever_v1", embed_invoke_url=embed_invoke_url))
+        .embed(EmbedParams(model_name=str(embed_model_name), embed_invoke_url=embed_invoke_url))
         .vdb_upload(
             VdbUploadParams(
                 lancedb={
@@ -285,10 +290,16 @@ def main(
     unique_basenames = table.to_pandas()["pdf_basename"].unique()
     print(f"Unique basenames: {unique_basenames}")
 
+    # Resolve the HF model ID for recall query embedding so aliases
+    # (e.g. "nemo_retriever_v1") map to the correct model.
+    from retriever.model import resolve_embed_model
+
+    _recall_model = resolve_embed_model(str(embed_model_name))
+
     cfg = RecallConfig(
         lancedb_uri=str(LANCEDB_URI),
         lancedb_table=str(LANCEDB_TABLE),
-        embedding_model="nvidia/llama-3.2-nv-embedqa-1b-v2",
+        embedding_model=_recall_model,
         embedding_http_endpoint=embed_invoke_url,
         top_k=10,
         ks=(1, 5, 10),
diff --git a/retriever/src/retriever/ingest_modes/batch.py b/retriever/src/retriever/ingest_modes/batch.py
index f4339175e..769180cfb 100644
--- a/retriever/src/retriever/ingest_modes/batch.py
+++ b/retriever/src/retriever/ingest_modes/batch.py
@@ -252,25 +252,34 @@ def __init__(self, params: EmbedParams) -> None:
             self._model = None
             return
 
-        from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder
-
         device = self._kwargs.get("device")
         hf_cache_dir = self._kwargs.get("hf_cache_dir")
         normalize = bool(self._kwargs.get("normalize", True))
         max_length = int(self._kwargs.get("max_length", 8192))
-        # model_name may be a NIM alias (e.g. "nemo_retriever_v1") or a real HF
-        # repo ID (e.g. "nvidia/llama-3.2-nv-embedqa-1b-v2"). Only forward it as
-        # model_id when it looks like an HF repo (contains "/").
         model_name_raw = self._kwargs.get("model_name")
-        model_id = model_name_raw if (isinstance(model_name_raw, str) and "/" in model_name_raw) else None
-
-        self._model = LlamaNemotronEmbed1BV2Embedder(
-            device=str(device) if device else None,
-            hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None,
-            normalize=normalize,
-            max_length=max_length,
-            model_id=model_id,
-        )
+
+        from retriever.model import is_vl_embed_model, resolve_embed_model
+
+        model_id = resolve_embed_model(model_name_raw)
+
+        if is_vl_embed_model(model_name_raw):
+            from retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import LlamaNemotronEmbedVL1BV2Embedder
+
+            self._model = LlamaNemotronEmbedVL1BV2Embedder(
+                device=str(device) if device else None,
+                hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None,
+                model_id=model_id,
+            )
+        else:
+            from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder
+
+            self._model = LlamaNemotronEmbed1BV2Embedder(
+                device=str(device) if device else None,
+                hf_cache_dir=str(hf_cache_dir) if hf_cache_dir else None,
+                normalize=normalize,
+                max_length=max_length,
+                model_id=model_id,
+            )
 
     def __call__(self, batch_df: Any) -> Any:
         from retriever.ingest_modes.inprocess import embed_text_main_text_embed
diff --git a/retriever/src/retriever/ingest_modes/inprocess.py b/retriever/src/retriever/ingest_modes/inprocess.py
index 436f8b9c5..f3d620e4c 100644
--- a/retriever/src/retriever/ingest_modes/inprocess.py
+++ b/retriever/src/retriever/ingest_modes/inprocess.py
@@ -27,8 +27,7 @@
 
 import pandas as pd
 
-from retriever.chart.chart_detection import detect_graphic_elements_v1_from_page_elements_v3  # noqa: F401
-from retriever.model.local import NemotronGraphicElementsV1, NemotronOCRV1, NemotronPageElementsV3  # noqa: F401
+from retriever.model.local import NemotronOCRV1, NemotronPageElementsV3
 from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder
 from retriever.page_elements import detect_page_elements_v3
 from retriever.ocr.ocr import ocr_page_elements
@@ -209,20 +208,23 @@ def embed_text_main_text_embed(
     if _endpoint is None and model is None:
         raise ValueError("Either a local model or an embedding_endpoint must be provided.")
 
-    # Map NIM aliases to the actual model ID expected by the remote endpoint.
-    _NIM_MODEL_ALIASES = {
-        "nemo_retriever_v1": "nvidia/llama-3.2-nv-embedqa-1b-v2",
-    }
-    _resolved_model_name = _NIM_MODEL_ALIASES.get(model_name, model_name) if model_name else model_name
+    # Resolve NIM aliases to the actual HF model ID.
+    from retriever.model import resolve_embed_model
+
+    _resolved_model_name = resolve_embed_model(model_name)
 
     # Build an embedder callable compatible with `create_text_embeddings_for_df`.
     # Only used when running with a local model (no NIM endpoint).
     _embed = None
     if _endpoint is None and model is not None:
 
+        # VL model handles formatting internally via encode_documents();
+        # embedqa needs an explicit "passage: " prefix.
+        _skip_prefix = hasattr(model, "embed_queries")
+
         def _embed(texts: Sequence[str]) -> Sequence[Sequence[float]]:  # noqa: F811
-            prefixed = [f"passage: {t}" for t in texts]
-            vecs = model.embed(prefixed, batch_size=int(inference_batch_size))
+            batch = texts if _skip_prefix else [f"passage: {t}" for t in texts]
+            vecs = model.embed(batch, batch_size=int(inference_batch_size))
             tolist = getattr(vecs, "tolist", None)
             if callable(tolist):
                 return tolist()
@@ -1196,20 +1198,30 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI
         normalize = bool(embed_kwargs.pop("normalize", True))
         max_length = int(embed_kwargs.pop("max_length", 8192))
 
-        # model_name may be a NIM alias (e.g. "nemo_retriever_v1") or a real HF
-        # repo ID (e.g. "nvidia/llama-3.2-nv-embedqa-1b-v2"). Only forward it as
-        # model_id when it looks like an HF repo (contains "/").
         model_name_raw = embed_kwargs.pop("model_name", None)
-        model_id = model_name_raw if (isinstance(model_name_raw, str) and "/" in model_name_raw) else None
+
+        from retriever.model import is_vl_embed_model, resolve_embed_model
+
+        model_id = resolve_embed_model(model_name_raw)
 
         embed_kwargs.setdefault("input_type", "passage")
-        embed_kwargs["model"] = LlamaNemotronEmbed1BV2Embedder(
-            device=str(device) if device is not None else None,
-            hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
-            normalize=normalize,
-            max_length=max_length,
-            model_id=model_id,
-        )
+
+        if is_vl_embed_model(model_name_raw):
+            from retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import LlamaNemotronEmbedVL1BV2Embedder
+
+            embed_kwargs["model"] = LlamaNemotronEmbedVL1BV2Embedder(
+                device=str(device) if device is not None else None,
+                hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
+                model_id=model_id,
+            )
+        else:
+            embed_kwargs["model"] = LlamaNemotronEmbed1BV2Embedder(
+                device=str(device) if device is not None else None,
+                hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
+                normalize=normalize,
+                max_length=max_length,
+                model_id=model_id,
+            )
 
         self._tasks.append((embed_text_main_text_embed, embed_kwargs))
         return self
diff --git a/retriever/src/retriever/model/__init__.py b/retriever/src/retriever/model/__init__.py
index 6aa2e3d5b..dc763d548 100644
--- a/retriever/src/retriever/model/__init__.py
+++ b/retriever/src/retriever/model/__init__.py
@@ -1,3 +1,35 @@
 # SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+_VL_EMBED_MODEL_IDS = frozenset(
+    {
+        "nvidia/llama-nemotron-embed-vl-1b-v2",
+        "llama-nemotron-embed-vl-1b-v2",
+    }
+)
+
+# Short name → full HF repo ID.
+_EMBED_MODEL_ALIASES: dict[str, str] = {
+    "nemo_retriever_v1": "nvidia/llama-3.2-nv-embedqa-1b-v2",
+    "llama-nemotron-embed-vl-1b-v2": "nvidia/llama-nemotron-embed-vl-1b-v2",
+}
+
+_DEFAULT_EMBED_MODEL = "nvidia/llama-3.2-nv-embedqa-1b-v2"
+
+
+def resolve_embed_model(model_name: str | None) -> str:
+    """Resolve a model name/alias to a full HF repo ID.
+
+    Returns ``_DEFAULT_EMBED_MODEL`` when *model_name* is ``None`` or empty.
+ """ + if not model_name: + return _DEFAULT_EMBED_MODEL + return _EMBED_MODEL_ALIASES.get(model_name, model_name) + + +def is_vl_embed_model(model_name: str | None) -> bool: + """Return True if *model_name* refers to the VL embedding model.""" + return resolve_embed_model(model_name) in _VL_EMBED_MODEL_IDS diff --git a/retriever/src/retriever/model/local/llama_nemotron_embed_1b_v2_embedder.py b/retriever/src/retriever/model/local/llama_nemotron_embed_1b_v2_embedder.py index a01e4d0e1..dfd5a2a3a 100644 --- a/retriever/src/retriever/model/local/llama_nemotron_embed_1b_v2_embedder.py +++ b/retriever/src/retriever/model/local/llama_nemotron_embed_1b_v2_embedder.py @@ -82,8 +82,23 @@ def _embed_local(self, texts: List[str], *, batch_size: int) -> torch.Tensor: max_length=max(1, int(self.max_length)), return_tensors="pt", ).to(dev) - out = self._model(**batch) - lhs = out.last_hidden_state # [B, S, D] + out = self._model(**batch, output_hidden_states=True) + # The bidirectional model returns BaseModelOutputWithPast + # (last_hidden_state), but some transformers versions or + # model revisions return CausalLMOutputWithPast (hidden_states). + lhs = getattr(out, "last_hidden_state", None) + if lhs is None: + # CausalLMOutputWithPast: use the last layer's hidden state. + hs = getattr(out, "hidden_states", None) + if hs is not None: + lhs = hs[-1] + else: + raise AttributeError( + f"Model output ({type(out).__name__}) has neither " + "'last_hidden_state' nor 'hidden_states'. " + "Ensure the model is loaded with trust_remote_code=True." + ) + # lhs shape: [B, S, D] mask = batch["attention_mask"].unsqueeze(-1) # [B, S, 1] vec = (lhs * mask).sum(dim=1) / mask.sum(dim=1) # [B, D] vec = vec.detach().to("cpu") diff --git a/retriever/src/retriever/model/local/llama_nemotron_embed_vl_1b_v2_embedder.py b/retriever/src/retriever/model/local/llama_nemotron_embed_vl_1b_v2_embedder.py new file mode 100644 index 000000000..fe2735ed3 --- /dev/null +++ b/retriever/src/retriever/model/local/llama_nemotron_embed_vl_1b_v2_embedder.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional, Sequence + +import torch + + +@dataclass +class LlamaNemotronEmbedVL1BV2Embedder: + """ + Text-only embedder wrapper for ``nvidia/llama-nemotron-embed-vl-1b-v2``. + + The VL model exposes ``encode_queries()`` and ``encode_documents()`` + instead of the standard tokenizer + forward pass used by the embedqa + model. This class uses only the text modality. + """ + + device: Optional[str] = None + hf_cache_dir: Optional[str] = None + model_id: Optional[str] = None + + # Populated in __post_init__ + _model: Any = field(default=None, init=False, repr=False) + _device: Any = field(default=None, init=False, repr=False) + + def __post_init__(self) -> None: + from transformers import AutoModel + + model_id = self.model_id or "nvidia/llama-nemotron-embed-vl-1b-v2" + dev = torch.device(self.device or ("cuda" if torch.cuda.is_available() else "cpu")) + hf_cache_dir = self.hf_cache_dir or str(Path.home() / ".cache" / "huggingface") + + # flash_attention_2 requires the model on GPU at init time, so use + # device_map when requesting it. Fall back to sdpa/eager on CPU or + # when flash-attn is not installed. 
+        use_gpu = dev.type == "cuda"
+        for attn_impl in ("flash_attention_2", "sdpa", "eager"):
+            try:
+                kwargs: dict[str, Any] = {
+                    "trust_remote_code": True,
+                    "torch_dtype": torch.bfloat16,
+                    "attn_implementation": attn_impl,
+                    "cache_dir": hf_cache_dir,
+                }
+                if attn_impl == "flash_attention_2" and use_gpu:
+                    kwargs["device_map"] = dev
+                self._model = AutoModel.from_pretrained(model_id, **kwargs)
+                break
+            except (ValueError, ImportError):
+                if attn_impl == "eager":
+                    raise
+                continue
+
+        if not hasattr(self._model, "hf_device_map"):  # set only when loaded with device_map
+            self._model = self._model.to(dev)
+        self._model.eval()
+        self._device = dev
+
+    @property
+    def is_remote(self) -> bool:
+        return False
+
+    def embed(self, texts: Sequence[str], *, batch_size: int = 64) -> torch.Tensor:
+        """Embed document texts. Returns CPU tensor ``[N, 2048]``."""
+        texts_list = [str(t) for t in texts if str(t).strip()]
+        if not texts_list:
+            return torch.empty((0, 2048), dtype=torch.float32)
+        with torch.inference_mode():
+            out = self._model.encode_documents(texts=texts_list)
+        if isinstance(out, torch.Tensor):
+            return out.detach().cpu().float()
+        return torch.as_tensor(out, dtype=torch.float32).cpu()
+
+    def embed_queries(self, texts: Sequence[str], *, batch_size: int = 64) -> torch.Tensor:
+        """Embed query strings. Returns CPU tensor ``[N, 2048]``."""
+        texts_list = [str(t) for t in texts]
+        if not texts_list:
+            return torch.empty((0, 2048), dtype=torch.float32)
+        with torch.inference_mode():
+            out = self._model.encode_queries(texts_list)
+        if isinstance(out, torch.Tensor):
+            return out.detach().cpu().float()
+        return torch.as_tensor(out, dtype=torch.float32).cpu()
diff --git a/retriever/src/retriever/recall/core.py b/retriever/src/retriever/recall/core.py
index 9d8ed8ce9..75cc3b296 100644
--- a/retriever/src/retriever/recall/core.py
+++ b/retriever/src/retriever/recall/core.py
@@ -138,12 +138,26 @@ def _embed_queries_local_hf(
     device: Optional[str],
     cache_dir: Optional[str],
     batch_size: int,
+    model_name: Optional[str] = None,
 ) -> List[List[float]]:
     # Lazy import: only load torch/HF when needed.
-    from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder
+    from retriever.model import is_vl_embed_model, resolve_embed_model
 
-    embedder = LlamaNemotronEmbed1BV2Embedder(device=device, hf_cache_dir=cache_dir, normalize=True)
-    vecs = embedder.embed(["query: " + q for q in queries], batch_size=int(batch_size))
+    model_id = resolve_embed_model(model_name)
+
+    if is_vl_embed_model(model_name):
+        from retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import LlamaNemotronEmbedVL1BV2Embedder
+
+        embedder = LlamaNemotronEmbedVL1BV2Embedder(device=device, hf_cache_dir=cache_dir, model_id=model_id)
+        # VL model handles query formatting internally via encode_queries().
+        vecs = embedder.embed_queries(queries, batch_size=int(batch_size))
+    else:
+        from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder
+
+        embedder = LlamaNemotronEmbed1BV2Embedder(
+            device=device, hf_cache_dir=cache_dir, normalize=True, model_id=model_id
+        )
+        vecs = embedder.embed(["query: " + q for q in queries], batch_size=int(batch_size))
     # Ensure list-of-list floats.
     return vecs.detach().to("cpu").tolist()
 
@@ -275,6 +289,7 @@ def retrieve_and_score(
             device=cfg.local_hf_device,
             cache_dir=cfg.local_hf_cache_dir,
             batch_size=int(cfg.local_hf_batch_size),
+            model_name=cfg.embedding_model,
         )
         raw_hits = _search_lancedb(
             lancedb_uri=cfg.lancedb_uri,
diff --git a/retriever/src/retriever/text_embed/text_embed.py b/retriever/src/retriever/text_embed/text_embed.py
index 2650d1873..20e9b9ee7 100644
--- a/retriever/src/retriever/text_embed/text_embed.py
+++ b/retriever/src/retriever/text_embed/text_embed.py
@@ -104,7 +104,11 @@ def embed_text_1b_v2(
                 # Keep placeholder but mark as "no text".
                 payloads[i] = {"embedding": None, "error": None}
                 continue
-            texts.append(f"{input_type}: {txt}" if input_type else txt)
+            # VL model handles formatting internally; skip prefix.
+            if input_type and not hasattr(model, "embed_queries"):
+                texts.append(f"{input_type}: {txt}")
+            else:
+                texts.append(txt)
             text_row_idxs.append(i)
         except BaseException as e:
             payloads[i] = _error_payload(stage="extract_text", exc=e)
@@ -181,19 +185,31 @@ class TextEmbedActor:
     def __init__(self, **detect_kwargs: Any) -> None:
         self.detect_kwargs = dict(detect_kwargs)
-        from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder
 
         device = self.detect_kwargs.pop("device", None)
         hf_cache_dir = self.detect_kwargs.pop("hf_cache_dir", None)
         normalize = bool(self.detect_kwargs.pop("normalize", True))
         max_length = self.detect_kwargs.pop("max_length", 4096)
+        model_name = self.detect_kwargs.get("model_name")
 
-        self._model = LlamaNemotronEmbed1BV2Embedder(
-            device=str(device) if device is not None else None,
-            hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
-            normalize=normalize,
-            max_length=int(max_length),
-        )
+        from retriever.model import is_vl_embed_model
+
+        if is_vl_embed_model(model_name):
+            from retriever.model.local.llama_nemotron_embed_vl_1b_v2_embedder import LlamaNemotronEmbedVL1BV2Embedder
+
+            self._model = LlamaNemotronEmbedVL1BV2Embedder(
+                device=str(device) if device is not None else None,
+                hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
+            )
+        else:
+            from retriever.model.local.llama_nemotron_embed_1b_v2_embedder import LlamaNemotronEmbed1BV2Embedder
+
+            self._model = LlamaNemotronEmbed1BV2Embedder(
+                device=str(device) if device is not None else None,
+                hf_cache_dir=str(hf_cache_dir) if hf_cache_dir is not None else None,
+                normalize=normalize,
+                max_length=int(max_length),
+            )
 
     def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
         try:
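
Note: a minimal usage sketch of the alias helpers added in
retriever/src/retriever/model/__init__.py. The behavior shown follows directly
from the mapping in that file; "org/custom-model" is a hypothetical name used
only to illustrate pass-through of unknown IDs.

    from retriever.model import is_vl_embed_model, resolve_embed_model

    # None or empty falls back to the default text embedding model.
    assert resolve_embed_model(None) == "nvidia/llama-3.2-nv-embedqa-1b-v2"
    # NIM alias resolves to the full HF repo ID.
    assert resolve_embed_model("nemo_retriever_v1") == "nvidia/llama-3.2-nv-embedqa-1b-v2"
    # The short VL name resolves to the VL repo ID and routes to the VL embedder.
    assert resolve_embed_model("llama-nemotron-embed-vl-1b-v2") == "nvidia/llama-nemotron-embed-vl-1b-v2"
    assert is_vl_embed_model("llama-nemotron-embed-vl-1b-v2")
    assert not is_vl_embed_model("nemo_retriever_v1")
    # Unknown names pass through unchanged (hypothetical ID for illustration).
    assert resolve_embed_model("org/custom-model") == "org/custom-model"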