56 changes: 54 additions & 2 deletions nemo_curator/stages/text/embedders/base.py
@@ -18,6 +18,8 @@
import pandas as pd
import torch
import torch.nn.functional as F # noqa: N812
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoModel

from nemo_curator.backends.base import WorkerMetadata
@@ -98,9 +100,49 @@ def _get_last_token(self, model_output: torch.Tensor, attention_mask: torch.Tens
        return F.normalize(last_token_embeddings, dim=1)


class SentenceTransformerEmbeddingModelStage(EmbeddingModelStage):
    def __init__(  # noqa: PLR0913
        self,
        model_identifier: str,
        embedding_field: str = "embeddings",
        hf_token: str | None = None,
        model_inference_batch_size: int = 1024,
        has_seq_order: bool = True,
        padding_side: Literal["left", "right"] = "right",
        autocast: bool = True,
    ):
        super().__init__(
            model_identifier=model_identifier,
            hf_token=hf_token,
            model_inference_batch_size=model_inference_batch_size,
            has_seq_order=has_seq_order,
            padding_side=padding_side,
            autocast=autocast,
        )
Comment on lines +114 to +121
Contributor: The parent class EmbeddingModelStage.__init__() is called without passing a pooling parameter, which means the parent's default pooling value will be used. While this doesn't cause issues since SentenceTransformerEmbeddingModelStage doesn't use the pooling attribute, consider adding a comment to clarify this is intentional.
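For illustration, a minimal sketch of the clarifying comment the reviewer asks for (the comment wording is an assumption, not part of this PR):

        super().__init__(
            # NOTE: pooling is intentionally left at the parent's default; this subclass
            # delegates pooling to SentenceTransformer and never reads self.pooling.
            model_identifier=model_identifier,
            hf_token=hf_token,
            model_inference_batch_size=model_inference_batch_size,
            has_seq_order=has_seq_order,
            padding_side=padding_side,
            autocast=autocast,
        )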

        # Override unpack_inference_batch to False as SentenceTransformer expects a dictionary input
        self.unpack_inference_batch = False
        self.embedding_field = embedding_field

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.embedding_field]

    def setup(self, _: WorkerMetadata | None = None) -> None:
        """Load the model for inference."""
        self.model = SentenceTransformer(self.model_identifier, local_files_only=True)
        self.model.eval().to("cuda")

    def process_model_output(
        self,
        outputs: torch.Tensor,
        model_input_batch: dict[str, torch.Tensor] | None = None,  # noqa: ARG002
    ) -> torch.Tensor:
        return outputs["sentence_embedding"].cpu()


@dataclass(kw_only=True)
class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
    model_identifier: str = "sentence-transformers/all-MiniLM-L6-v2"
    use_sentence_transformer: bool = True
    text_field: str = "text"
    embedding_field: str = "embeddings"
    max_chars: int | None = None
@@ -115,6 +157,16 @@ class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
    def __post_init__(self) -> None:
        super().__init__()

        model_class = SentenceTransformerEmbeddingModelStage if self.use_sentence_transformer else EmbeddingModelStage

        if self.use_sentence_transformer:
            logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
Contributor: Typo in warning message: "ignoring" should be a complete sentence. Consider:

Suggested change
            logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
            logger.warning("Using SentenceTransformer for embedding model; ignoring embedding_pooling parameter")

            model_additional_kwargs = {}
        else:
            model_additional_kwargs = {
                "pooling": self.embedding_pooling,
            }

        self.stages = [
            TokenizerStage(
                model_identifier=self.model_identifier,
@@ -125,15 +177,15 @@ def __post_init__(self) -> None:
                padding_side=self.padding_side,
                sort_by_length=self.sort_by_length,
            ),
            EmbeddingModelStage(
            model_class(
                model_identifier=self.model_identifier,
                embedding_field=self.embedding_field,
                pooling=self.embedding_pooling,
                hf_token=self.hf_token,
                model_inference_batch_size=self.model_inference_batch_size,
                has_seq_order=self.sort_by_length,
                padding_side=self.padding_side,
                autocast=self.autocast,
                **model_additional_kwargs,
            ),
        ]

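For reference, a minimal usage sketch of the flag wired up above (the values are illustrative and the surrounding pipeline setup is omitted):

stage = EmbeddingCreatorStage(
    model_identifier="sentence-transformers/all-MiniLM-L6-v2",
    use_sentence_transformer=True,  # routes embedding to SentenceTransformerEmbeddingModelStage
    embedding_pooling="mean_pooling",  # ignored, with a warning, when use_sentence_transformer=True
    model_inference_batch_size=256,
)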
138 changes: 138 additions & 0 deletions nemo_curator/stages/text/embedders/vllm.py
Contributor: I think the file shouldn't be called vllm.py, to avoid possible import weirdness with the vllm library.

@@ -0,0 +1,138 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import TYPE_CHECKING, Any

from huggingface_hub import snapshot_download
from vllm import LLM

from nemo_curator.backends.base import NodeInfo, WorkerMetadata
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.stages.text.models.utils import format_name_with_suffix
from nemo_curator.tasks import DocumentBatch

if TYPE_CHECKING:
    from transformers import AutoTokenizer


class VLLMEmbeddingModelStage(ProcessingStage[DocumentBatch, DocumentBatch]):
Contributor: I don't understand why we have a separate VLLMEmbeddingModelStage. From a user workflow perspective, should we not always use the same interface and let the user choose the backend?

Contributor: Like, when the eventual switch happens for vLLM as default, this will make users/tutorials go and change their workflow.

Contributor: Also, should we not import this logic for ModelStage here?
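A hypothetical sketch of the single-interface direction these comments describe; the backend field and the dispatch below are assumptions for illustration, not part of this PR:

@dataclass(kw_only=True)
class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
    model_identifier: str = "sentence-transformers/all-MiniLM-L6-v2"
    backend: Literal["sentence_transformers", "pytorch", "vllm"] = "sentence_transformers"

    def __post_init__(self) -> None:
        super().__init__()
        if self.backend == "vllm":
            # vLLM tokenizes internally, so no TokenizerStage is needed.
            self.stages = [VLLMEmbeddingModelStage(model_identifier=self.model_identifier)]
        else:
            model_class = (
                SentenceTransformerEmbeddingModelStage
                if self.backend == "sentence_transformers"
                else EmbeddingModelStage
            )
            self.stages = [
                TokenizerStage(model_identifier=self.model_identifier),
                model_class(model_identifier=self.model_identifier),
            ]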

    def __init__(  # noqa: PLR0913
        self,
        model_identifier: str,
        vllm_init_kwargs: dict[str, Any] | None = None,
        text_field: str = "text",
        pretokenize: bool = False,
        embedding_field: str = "embeddings",
        cache_dir: str | None = None,
        hf_token: str | None = None,
        verbose: bool = False,
    ):
        self.model_identifier = model_identifier
        self.vllm_init_kwargs = vllm_init_kwargs or {}

        self.text_field = text_field
        self.pretokenize = pretokenize
        self.embedding_field = embedding_field

        self.cache_dir = cache_dir
        self.hf_token = hf_token

        self.verbose = verbose
        # after setup
        self.model: None | LLM = None
        self.tokenizer: None | AutoTokenizer = None
        # stage setup
        self.resources = Resources(
            cpus=1,
            gpus=1,
        )
        self.name = format_name_with_suffix(model_identifier, suffix="_vllm")

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.text_field]

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.text_field, self.embedding_field]

    def _initialize_vllm(self) -> None:
        vllm_init_kwargs = self.vllm_init_kwargs.copy()
        # set defaults here
        if "enforce_eager" not in vllm_init_kwargs:
            vllm_init_kwargs["enforce_eager"] = False
        if "runner" not in vllm_init_kwargs:
            vllm_init_kwargs["runner"] = "pooling"
        if "model_impl" not in vllm_init_kwargs:
            # TODO: Once transformers is bumpted to 5.0 then we should also support transformers
Contributor: syntax: 'bumpted' should be 'bumped'

Contributor: Typo in comment: "bumpted" should be "bumped":

Suggested change
            # TODO: Once transformers is bumpted to 5.0 then we should also support transformers
            # TODO: Once transformers is bumped to 5.0 then we should also support transformers

vllm_init_kwargs["model_impl"] = "vllm"

# Reduce verbosity when not in verbose mode
if not self.verbose and "disable_log_stats" not in vllm_init_kwargs:
vllm_init_kwargs["disable_log_stats"] = True

self.model = LLM(model=self.model_identifier, **vllm_init_kwargs)
Contributor (@sarahyurick, Jan 7, 2026): These are the parameters I was experimenting with:

self.model = vllm.LLM(
    model=self.model_identifier,
    max_num_seqs=self.model_inference_batch_size,
    max_num_batched_tokens=self.max_num_batched_tokens,
    gpu_memory_utilization=self.gpu_memory_utilization,
    # task="embed",
    tensor_parallel_size=1,
    enforce_eager=False,
    disable_log_stats=True,
    # runner="pooling",
)

I am wondering if we should include this logic too:

if "max_num_seqs" not in vllm_init_kwargs:
    vllm_init_kwargs["max_num_seqs"] = self.model_inference_batch_size


    def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None:  # noqa: ARG002
        if not self.verbose:
            from huggingface_hub.utils import disable_progress_bars

            disable_progress_bars()

        snapshot_download(self.model_identifier, cache_dir=self.cache_dir, token=self.hf_token, local_files_only=False)
        self._initialize_vllm()

    def setup(self, worker_metadata: WorkerMetadata | None = None) -> None:  # noqa: ARG002
        if self.model is None:
            self._initialize_vllm()
        if self.pretokenize:
            from transformers import AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_identifier)

    def process(self, batch: DocumentBatch) -> DocumentBatch:
        df = batch.to_pandas()
        input_data = df[self.text_field].tolist()
Contributor:

Suggested change
        input_data = df[self.text_field].tolist()
        input_data = df[self.text_field].tolist()
        if self.max_chars is not None:
            input_data = [text[:self.max_chars] for text in input_data]

?

        metrics = {}

        if self.pretokenize:
            from vllm.inputs import TokensPrompt
Comment on lines +99 to +110
Contributor: Move to top-level imports?
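For illustration, the module-top placement the comment suggests (whether the import should instead stay lazy for optional-dependency reasons is left open here):

from vllm import LLM
from vllm.inputs import TokensPrompt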


            if self.tokenizer is None:
                msg = (
                    "Tokenizer is not initialized. Please call setup() before processing or set pretokenize to False."
                )
                raise ValueError(msg)

            t0 = time.perf_counter()
            max_model_len = self.model.model_config.max_model_len
            tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)
Comment on lines +119 to +120
Contributor: Verify that max_model_len is set to a reasonable value before using it for truncation. If the model config has an unusually large or unset max_model_len, the tokenization could produce unexpected results.
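A minimal sketch of the kind of guard the comment describes (the fallback cap of 512 is an assumption, not part of this PR):

            max_model_len = getattr(self.model.model_config, "max_model_len", None)
            if not max_model_len or max_model_len <= 0:
                # Fall back to a conservative cap when the config does not expose a usable limit.
                max_model_len = 512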

            input_data = [TokensPrompt(prompt_token_ids=ids) for ids in tokenized_data.input_ids]
            metrics["tokenization_time"] = time.perf_counter() - t0

        t0 = time.perf_counter()
        vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
Contributor: Is there an advantage to doing tokenization and then the model forward pass within the same stage? I was thinking that we should either break tokenization into a separate CPU-only stage or always let self.model.embed(...) handle it, like:

Suggested change
        vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
        vllm_output = self.model.embed(input_data, truncate_prompt_tokens=self.max_seq_length, use_tqdm=self.verbose)

Contributor: The parameter value -1 for truncate_prompt_tokens may cause unexpected behavior. According to vLLM documentation, this should be either a positive integer or None for no truncation. Consider using None if truncation should be disabled, or verify that -1 is the intended value for this vLLM version.

metrics["vllm_embedding_time"] = time.perf_counter() - t0

df[self.embedding_field] = [e.outputs.embedding for e in vllm_output]
Comment on lines +124 to +128
Contributor: Missing error handling for the model.embed() call. If the model fails to generate embeddings (e.g., due to malformed inputs, resource exhaustion, or vLLM errors), the stage will crash without a helpful error message. Consider wrapping this in a try-except block with informative error messaging.
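A minimal sketch of the wrapping the comment describes (the error message wording is an assumption):

        try:
            vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
        except Exception as e:
            # Surface batch context so failures are easier to trace in pipeline logs.
            msg = (
                f"vLLM embedding failed for batch {batch.task_id} "
                f"({len(input_data)} inputs, model={self.model_identifier})"
            )
            raise RuntimeError(msg) from e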


        self._log_metrics(metrics)

        return DocumentBatch(
            task_id=batch.task_id,
            dataset_name=batch.dataset_name,
            data=df,
            _metadata=batch._metadata,
            _stage_perf=batch._stage_perf,
        )
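A rough usage sketch of the stage above, outside a full pipeline (the sample batch construction, the defaulted DocumentBatch fields, and the model choice are assumptions):

import pandas as pd

stage = VLLMEmbeddingModelStage(
    model_identifier="sentence-transformers/all-MiniLM-L6-v2",
    pretokenize=True,
)
stage.setup()  # loads the vLLM engine and, because pretokenize=True, the tokenizer

batch = DocumentBatch(
    task_id="demo-0",
    dataset_name="demo",
    data=pd.DataFrame({"text": ["hello world", "embedding with vLLM"]}),
)
result = stage.process(batch)
print(len(result.to_pandas()["embeddings"].iloc[0]))  # embedding dimensionality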
12 changes: 8 additions & 4 deletions pyproject.toml
@@ -68,11 +68,13 @@ dependencies = [

[project.optional-dependencies]
cuda12 = ["gpustat", "nvidia-ml-py"]
vllm = ["vllm>=0.13; (platform_machine == 'x86_64' and platform_system != 'Darwin')"]

# Installs CPU + GPU text curation modules
deduplication_cuda12 = [
"cudf-cu12==25.10.*",
"cuml-cu12==25.10.*",
"scikit-learn<1.8.0", # cuml 25.10.0 is incompatible with scikit-learn 1.8.0
"pylibcugraph-cu12==25.10.*",
"pylibraft-cu12==25.10.*",
"raft-dask-cu12==25.10.*",
@@ -124,6 +126,8 @@ text_cuda12 = [
"nemo_curator[cuda12]",
"nemo_curator[deduplication_cuda12]",
"nemo_curator[text_cpu]",
"nemo_curator[vllm]",
"sentence-transformers"
]

# Video Curation Dependencies
@@ -138,13 +142,13 @@ video_cpu = [
video_cuda12 = [
"nemo_curator[video_cpu]",
"nemo_curator[cuda12]",
"nemo_curator[vllm]",
"cvcuda_cu12",
"flash-attn<=2.8.3; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
"pycuda",
"PyNvVideoCodec==2.0.2; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
"torch<=2.9.0",
"torch<=2.9.1",
"torchaudio",
"vllm==0.11.1; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
]

# All dependencies
@@ -156,7 +160,7 @@ all = [
]

[dependency-groups]
build = ["setuptools", "torch<=2.9.0"]
build = ["setuptools", "torch<=2.9.1"]
dev = ["jupyter"]
linting = ["pre-commit", "ruff==0.11.4"]
test = [
@@ -166,7 +170,7 @@ test = [
"pytest-asyncio",
"pytest-cov",
"pytest-loguru",
"scikit-learn",
"scikit-learn<1.8.0", # cuml 25.10.0 is incompatible with scikit-learn 1.8.0
"s3fs", # added for testing cloud fs
]

9 changes: 7 additions & 2 deletions tests/backends/test_integration.py
@@ -202,13 +202,18 @@ def test_ray_data_execution_plan(self):
        execution_plan_stages = [stage.strip() for stage in stages]
        # Tasks can get fused with Actors, but Actors can't get fused with Tasks or Actors
        # StreamingRepartition should never get fused

        if ray.__version__ >= "2.53.0":
            streaming_repartition = "StreamingRepartition[num_rows_per_block=1]"
        else:
            streaming_repartition = "StreamingRepartition"
        expected_stages = [
            "InputDataBuffer[Input]",
            "TaskPoolMapOperator[MapBatches(FilePartitioningStageTask)]",
            "TaskPoolMapOperator[StreamingRepartition]",
            f"TaskPoolMapOperator[{streaming_repartition}]",
            "ActorPoolMapOperator[MapBatches(JsonlReaderStageTask)->MapBatches(AddLengthStageActor)]",
            "ActorPoolMapOperator[MapBatches(SplitIntoRowsStageActor)]",
            "TaskPoolMapOperator[StreamingRepartition]",
            f"TaskPoolMapOperator[{streaming_repartition}]",
            "ActorPoolMapOperator[MapBatches(AddLengthStageActor)]",
            "ActorPoolMapOperator[MapBatches(StageWithSetupActor)]",
            "TaskPoolMapOperator[MapBatches(JsonlWriterTask)]",
11 changes: 10 additions & 1 deletion tests/stages/text/embedders/test_base.py
@@ -217,6 +217,7 @@ def sample_data(self) -> DocumentBatch:
    def test_embedding_creator_stage_initialization_and_decomposition(self) -> None:
        """Test initialization, decomposition, and parameter passing to decomposed stages."""
        # Test with custom parameters including hf_token and unk_token
        # Note: use_sentence_transformer=False is required to test EmbeddingModelStage with custom pooling
        stage = EmbeddingCreatorStage(
            model_identifier="test-model",
            text_field="content",
@@ -228,6 +229,7 @@ def test_embedding_creator_stage_initialization_and_decomposition(self) -> None:
            model_inference_batch_size=128,
            sort_by_length=False,
            hf_token="test-token",  # noqa:S106
            use_sentence_transformer=False,
        )

        # Test decomposition and stage types
@@ -295,15 +297,22 @@ def test_embedding_creator_stage_process_integration(self) -> None:

@pytest.mark.parametrize("pooling_strategy", ["mean_pooling", "last_token"])
@pytest.mark.parametrize("autocast", [True, False])
@pytest.mark.parametrize("use_sentence_transformer", [True, False])
@pytest.mark.gpu
def test_embedding_creator_stage_with_reference_embeddings(
self, pooling_strategy: str, sample_data: DocumentBatch, autocast: bool
self, pooling_strategy: str, sample_data: DocumentBatch, autocast: bool, use_sentence_transformer: bool
) -> None:
"""Test embeddings match reference implementation (requires GPU and model download)."""
if use_sentence_transformer and pooling_strategy != "mean_pooling":
pytest.skip(
"Ignoring last_token strategy for sentence transformer as behavior for miniLM is mean pooling "
)
stage = EmbeddingCreatorStage(
model_identifier="sentence-transformers/all-MiniLM-L6-v2",
embedding_pooling=pooling_strategy,
model_inference_batch_size=32,
autocast=autocast,
use_sentence_transformer=use_sentence_transformer,
)

        # Decompose and setup stages