Skip to content

Commit 044a445

Browse files
committed
First pass of embedding vLLM backend
1 parent bb62146 commit 044a445

File tree

8 files changed

+356
-14
lines changed

8 files changed

+356
-14
lines changed

nemo_retriever/README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,45 @@ To stop and remove both stacks:
122122
docker compose -p ingest-gpu0 down
123123
docker compose -p ingest-gpu1 down
124124
```
125+
126+
## Embedding backends
127+
128+
Embeddings can be served by a **remote HTTP endpoint** (NIM, vLLM, or any OpenAI-compatible server) or by a **local HuggingFace model** when no endpoint is configured.
129+
130+
- **Config**: Set `embedding_nim_endpoint` in `ingest-config.yaml` or stage config (e.g. `http://localhost:8000/v1`). Leave empty or null to use the local HF embedder.
131+
- **CLI**: Use `--embed-invoke-url` (inprocess/batch pipelines) or `--embedding-endpoint` / `--embedding-http-endpoint` (recall CLI) to point at a remote server.
132+
133+
### Using vLLM for embeddings
134+
135+
You can serve an embedding model with [vLLM](https://docs.vllm.ai/) and point the retriever at it. vLLM exposes an OpenAI-compatible `/v1/embeddings` API. Set the embedding endpoint to the vLLM base URL (e.g. `http://localhost:8000/v1`).
136+
137+
**vLLM compatibility**: The default NIM-style client sends `input_type` and `truncate` in the request body; some vLLM versions or configs may not accept these. When using a **vLLM** server, enable the vLLM-compatible payload:
138+
139+
- **Ingest**: `--embed-use-vllm-compat` (inprocess pipeline) or set `embed_use_vllm_compat: true` in `EmbedParams`.
140+
- **Recall**: `--embedding-use-vllm-compat` (recall CLI).
141+
142+
This sends only `model`, `input`, and `encoding_format` (minimal OpenAI-compatible payload).
143+
144+
### llama-nemotron-embed-1b-v2 with vLLM
145+
146+
For **nvidia/llama-nemotron-embed-1b-v2**, follow the model’s official vLLM instructions:
147+
148+
1. Use **vllm==0.11.0**.
149+
2. Clone the [model repo](https://huggingface.co/nvidia/llama-nemotron-embed-1b-v2) and **overwrite `config.json` with `config_vllm.json`** from that repo.
150+
3. Start the server (replace `<path_to_the_cloned_repository>` and `<num_gpus_to_use>`):
151+
152+
```bash
153+
vllm serve \
154+
<path_to_the_cloned_repository> \
155+
--trust-remote-code \
156+
--runner pooling \
157+
--model-impl vllm \
158+
--override-pooler-config '{"pooling_type": "MEAN"}' \
159+
--data-parallel-size <num_gpus_to_use> \
160+
--dtype float32 \
161+
--port 8000
162+
```
163+
164+
4. Set the retriever embedding endpoint to `http://localhost:8000/v1` and use `--embed-use-vllm-compat` / `--embedding-use-vllm-compat` as above.
165+
166+
See the [model README](https://huggingface.co/nvidia/llama-nemotron-embed-1b-v2) for the canonical vLLM setup and client example.

nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,12 @@ def main(
173173
min=0.0,
174174
help="Parse stage batch size (enables Parse-only path when > 0.0 with parse workers/GPU).",
175175
),
176+
embed_use_vllm_compat: bool = typer.Option(
177+
False,
178+
"--embed-use-vllm-compat/--no-embed-use-vllm-compat",
179+
help="Use vLLM-compatible HTTP payload for embeddings (no input_type/truncate). "
180+
"Set when --embed-invoke-url is a vLLM server.",
181+
),
176182
embed_modality: str = typer.Option(
177183
"text",
178184
"--embed-modality",
@@ -212,6 +218,7 @@ def main(
212218
EmbedParams(
213219
model_name=str(embed_model_name),
214220
embed_invoke_url=embed_invoke_url,
221+
embed_use_vllm_compat=embed_use_vllm_compat,
215222
embed_modality=embed_modality,
216223
text_elements_modality=text_elements_modality,
217224
structured_elements_modality=structured_elements_modality,
@@ -238,6 +245,7 @@ def main(
238245
EmbedParams(
239246
model_name=str(embed_model_name),
240247
embed_invoke_url=embed_invoke_url,
248+
embed_use_vllm_compat=embed_use_vllm_compat,
241249
embed_modality=embed_modality,
242250
text_elements_modality=text_elements_modality,
243251
structured_elements_modality=structured_elements_modality,
@@ -280,6 +288,7 @@ def main(
280288
EmbedParams(
281289
model_name=str(embed_model_name),
282290
embed_invoke_url=embed_invoke_url,
291+
embed_use_vllm_compat=embed_use_vllm_compat,
283292
embed_modality=embed_modality,
284293
text_elements_modality=text_elements_modality,
285294
structured_elements_modality=structured_elements_modality,
@@ -321,6 +330,7 @@ def main(
321330
EmbedParams(
322331
model_name=str(embed_model_name),
323332
embed_invoke_url=embed_invoke_url,
333+
embed_use_vllm_compat=embed_use_vllm_compat,
324334
embed_modality=embed_modality,
325335
text_elements_modality=text_elements_modality,
326336
structured_elements_modality=structured_elements_modality,
@@ -379,6 +389,7 @@ def main(
379389
embedding_http_endpoint=embed_invoke_url,
380390
top_k=10,
381391
ks=(1, 5, 10),
392+
embedding_use_vllm_compat=bool(embed_use_vllm_compat),
382393
)
383394

384395
_df_query, _gold, _raw_hits, _retrieved_keys, metrics = retrieve_and_score(query_csv=query_csv, cfg=cfg)

nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def _embed_group(
240240
inference_batch_size: int,
241241
output_column: str,
242242
resolved_model_name: str,
243+
use_vllm_compat: bool = False,
243244
) -> pd.DataFrame:
244245
"""Embed a single modality group via ``create_text_embeddings_for_df``.
245246
@@ -285,14 +286,17 @@ def _embed(texts: Sequence[str]) -> Sequence[Sequence[float]]: # noqa: F811
285286
embed_modality=group_modality,
286287
)
287288

289+
task_config = {
290+
"embedder": _embed,
291+
"multimodal_embedder": _multimodal_embedder,
292+
"endpoint_url": endpoint,
293+
"local_batch_size": int(inference_batch_size),
294+
}
295+
if use_vllm_compat:
296+
task_config["use_vllm_compat"] = True
288297
out_df, _ = create_text_embeddings_for_df(
289298
group_df,
290-
task_config={
291-
"embedder": _embed,
292-
"multimodal_embedder": _multimodal_embedder,
293-
"endpoint_url": endpoint,
294-
"local_batch_size": int(inference_batch_size),
295-
},
299+
task_config=task_config,
296300
transform_config=cfg,
297301
)
298302
return out_df
@@ -307,6 +311,7 @@ def embed_text_main_text_embed(
307311
model_name: Optional[str] = None,
308312
embedding_endpoint: Optional[str] = None,
309313
embed_invoke_url: Optional[str] = None,
314+
embed_use_vllm_compat: bool = False,
310315
text_column: str = "text",
311316
inference_batch_size: int = 16,
312317
output_column: str = "text_embeddings_1b_v2",
@@ -372,6 +377,7 @@ def embed_text_main_text_embed(
372377
inference_batch_size=inference_batch_size,
373378
output_column=output_column,
374379
resolved_model_name=_resolved_model_name,
380+
use_vllm_compat=bool(embed_use_vllm_compat),
375381
)
376382
else:
377383
# Multiple modalities: group, embed each, reassemble in original order.
@@ -390,6 +396,7 @@ def embed_text_main_text_embed(
390396
inference_batch_size=inference_batch_size,
391397
output_column=output_column,
392398
resolved_model_name=_resolved_model_name,
399+
use_vllm_compat=bool(embed_use_vllm_compat),
393400
)
394401
parts.append(part)
395402
out_df = pd.concat(parts).sort_index()

nemo_retriever/src/nemo_retriever/params/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class EmbedParams(_ParamsModel):
181181
model_name: Optional[str] = None
182182
embedding_endpoint: Optional[str] = None
183183
embed_invoke_url: Optional[str] = None
184+
embed_use_vllm_compat: bool = False # Use vLLM-compatible HTTP payload when using remote endpoint
184185
input_type: str = "passage"
185186
embed_modality: str = "text" # "text", "image", or "text_image" — default for all element types
186187
text_elements_modality: Optional[str] = None # per-type override for page-text rows

nemo_retriever/src/nemo_retriever/recall/core.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class RecallConfig:
4646
local_hf_device: Optional[str] = None
4747
local_hf_cache_dir: Optional[str] = None
4848
local_hf_batch_size: int = 64
49+
# When True and an HTTP embedding endpoint is set, use vLLM-compatible minimal
50+
# payload (no input_type/truncate). Set this when the endpoint is a vLLM server.
51+
embedding_use_vllm_compat: bool = False
4952

5053

5154
def _normalize_query_df(df: pd.DataFrame) -> pd.DataFrame:
@@ -106,6 +109,28 @@ def _resolve_embedding_endpoint(cfg: RecallConfig) -> Tuple[Optional[str], Optio
106109
return None, None
107110

108111

112+
def _embed_queries_vllm_http(
113+
queries: List[str],
114+
*,
115+
endpoint: str,
116+
model: str,
117+
api_key: str,
118+
batch_size: int = 256,
119+
) -> List[List[float]]:
120+
"""Embed queries via vLLM-compatible HTTP (minimal payload, no input_type/truncate)."""
121+
from nemo_retriever.text_embed.vllm_http import embed_via_vllm_http
122+
123+
# llama-nemotron-embed-1b-v2 expects "query: " prefix for queries (see model README).
124+
return embed_via_vllm_http(
125+
queries,
126+
endpoint_url=endpoint,
127+
model_name=model,
128+
api_key=(api_key or "").strip() or None,
129+
batch_size=batch_size,
130+
prefix="query: ",
131+
)
132+
133+
109134
def _embed_queries_nim(
110135
queries: List[str],
111136
*,
@@ -297,13 +322,22 @@ def retrieve_and_score(
297322

298323
endpoint, use_grpc = _resolve_embedding_endpoint(cfg)
299324
if endpoint is not None and use_grpc is not None:
300-
vectors = _embed_queries_nim(
301-
queries,
302-
endpoint=endpoint,
303-
model=cfg.embedding_model,
304-
api_key=cfg.embedding_api_key,
305-
grpc=bool(use_grpc),
306-
)
325+
if bool(cfg.embedding_use_vllm_compat) and not use_grpc:
326+
vectors = _embed_queries_vllm_http(
327+
queries,
328+
endpoint=endpoint,
329+
model=cfg.embedding_model,
330+
api_key=cfg.embedding_api_key,
331+
batch_size=256,
332+
)
333+
else:
334+
vectors = _embed_queries_nim(
335+
queries,
336+
endpoint=endpoint,
337+
model=cfg.embedding_model,
338+
api_key=cfg.embedding_api_key,
339+
grpc=bool(use_grpc),
340+
)
307341
else:
308342
vectors = _embed_queries_local_hf(
309343
queries,

nemo_retriever/src/nemo_retriever/recall/vdb_recall.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ def recall_with_main(
131131
min=1,
132132
help="Batch size for local HF embedding inference.",
133133
),
134+
embedding_use_vllm_compat: bool = typer.Option(
135+
False,
136+
"--embedding-use-vllm-compat/--no-embedding-use-vllm-compat",
137+
help="Use vLLM-compatible HTTP payload (no input_type/truncate). Set when endpoint is a vLLM server.",
138+
),
134139
) -> None:
135140
query_csv = _resolve_query_csv(Path(query_csv))
136141

@@ -155,6 +160,7 @@ def recall_with_main(
155160
local_hf_device=_coerce_endpoint_str(local_hf_device),
156161
local_hf_cache_dir=(str(local_hf_cache_dir) if local_hf_cache_dir is not None else None),
157162
local_hf_batch_size=int(local_hf_batch_size),
163+
embedding_use_vllm_compat=bool(embedding_use_vllm_compat),
158164
)
159165

160166
print("Reading and normalizing query CSV...")
@@ -251,6 +257,11 @@ def run(
251257
min=1,
252258
help="Batch size for local HF embedding inference.",
253259
),
260+
embedding_use_vllm_compat: bool = typer.Option(
261+
False,
262+
"--embedding-use-vllm-compat/--no-embedding-use-vllm-compat",
263+
help="Use vLLM-compatible HTTP payload (no input_type/truncate). Set when endpoint is a vLLM server.",
264+
),
254265
print_hits: bool = typer.Option(True, "--print-hits/--no-print-hits", help="Print top-k hits per query."),
255266
) -> None:
256267
"""
@@ -282,6 +293,7 @@ def run(
282293
local_hf_device=_coerce_endpoint_str(local_hf_device),
283294
local_hf_cache_dir=(str(local_hf_cache_dir) if local_hf_cache_dir is not None else None),
284295
local_hf_batch_size=int(local_hf_batch_size),
296+
embedding_use_vllm_compat=bool(embedding_use_vllm_compat),
285297
)
286298

287299
df_query, gold, raw_hits, retrieved_keys, metrics = retrieve_and_score(

nemo_retriever/src/nemo_retriever/text_embed/main_text_embed.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,42 @@ def _async_runner(
455455
return flat_results
456456

457457

458+
def _vllm_compat_runner(
459+
prompts: List[List[str]],
460+
api_key: Optional[str],
461+
endpoint_url: str,
462+
embedding_model: str,
463+
encoding_format: str,
464+
dimensions: Optional[int] = None,
465+
batch_size: int = 256,
466+
) -> dict:
467+
"""
468+
Request embeddings using vLLM-compatible minimal payload (no input_type/truncate).
469+
Returns the same {"embeddings": [...], "info_msgs": [...]} shape as _async_runner.
470+
"""
471+
from nemo_retriever.text_embed.vllm_http import embed_via_vllm_http
472+
473+
flat_prompts: List[str] = []
474+
for batch in prompts:
475+
flat_prompts.extend(batch)
476+
if not flat_prompts:
477+
return {"embeddings": [], "info_msgs": []}
478+
# llama-nemotron-embed-1b-v2 expects "passage: " for documents (see model README).
479+
vectors = embed_via_vllm_http(
480+
flat_prompts,
481+
endpoint_url=endpoint_url,
482+
model_name=embedding_model,
483+
api_key=api_key,
484+
dimensions=dimensions,
485+
encoding_format=encoding_format,
486+
batch_size=batch_size,
487+
prefix="passage: ",
488+
)
489+
# Normalize to list of list (or None for missing)
490+
embeddings = [v if v else None for v in vectors]
491+
return {"embeddings": embeddings, "info_msgs": [None] * len(embeddings)}
492+
493+
458494
def _callable_runner(
459495
prompts: List[List[str]],
460496
*,
@@ -656,7 +692,17 @@ def _text_image_content(r: pd.Series) -> Optional[str]:
656692
filtered_content_list, batch_size=int(transform_config.batch_size)
657693
)
658694

659-
if endpoint_url:
695+
if endpoint_url and task_config.get("use_vllm_compat"):
696+
content_embeddings = _vllm_compat_runner(
697+
filtered_content_batches,
698+
api_key=api_key,
699+
endpoint_url=str(endpoint_url),
700+
embedding_model=str(model_name),
701+
encoding_format=str(transform_config.encoding_format),
702+
dimensions=dimensions,
703+
batch_size=int(transform_config.batch_size),
704+
)
705+
elif endpoint_url:
660706
content_embeddings = _async_runner(
661707
filtered_content_batches,
662708
api_key,

0 commit comments

Comments
 (0)