Merged
56 changes: 54 additions & 2 deletions nemo_curator/stages/text/embedders/base.py
@@ -18,6 +18,8 @@
import pandas as pd
import torch
import torch.nn.functional as F # noqa: N812
from loguru import logger
from sentence_transformers import SentenceTransformer
Contributor: [P0] This unconditional import will cause an ImportError for users who install without the text_cuda12 extra. The import should be conditional or guarded, since sentence-transformers is an optional dependency. Consider importing inside the class methods or using a try/except block.

Contributor: logic: the unconditional import breaks users without the text_cpu/text_cuda12 extras. Since __init__.py imports EmbeddingCreatorStage from this file, importing the embedders package will fail with ImportError: No module named 'sentence_transformers'.

Wrap it in a try/except or use a conditional import:

Suggested change
from sentence_transformers import SentenceTransformer
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    SentenceTransformer = None

then add validation in SentenceTransformerEmbeddingModelStage.__init__:

if SentenceTransformer is None:
    raise ImportError("sentence-transformers required. Install with: pip install nemo-curator[text_cpu]")

from transformers import AutoModel

from nemo_curator.backends.base import WorkerMetadata
@@ -98,9 +100,49 @@ def _get_last_token(self, model_output: torch.Tensor, attention_mask: torch.Tens
        return F.normalize(last_token_embeddings, dim=1)


class SentenceTransformerEmbeddingModelStage(EmbeddingModelStage):
    def __init__(  # noqa: PLR0913
        self,
        model_identifier: str,
        embedding_field: str = "embeddings",
        hf_token: str | None = None,
        model_inference_batch_size: int = 1024,
        has_seq_order: bool = True,
        padding_side: Literal["left", "right"] = "right",
        autocast: bool = True,
    ):
        super().__init__(
            model_identifier=model_identifier,
            hf_token=hf_token,
            model_inference_batch_size=model_inference_batch_size,
            has_seq_order=has_seq_order,
            padding_side=padding_side,
            autocast=autocast,
        )
Comment on lines 117 to 125
Contributor: The parent class EmbeddingModelStage.__init__() is called without passing a pooling parameter, so the parent's default pooling value will be used. This doesn't cause issues, since SentenceTransformerEmbeddingModelStage doesn't use the pooling attribute, but consider adding a comment to clarify that this is intentional.

Comment on lines 117 to 125
Contributor: The super().__init__() call omits the pooling parameter defined by the parent EmbeddingModelStage. The parent's __init__ at line 40 declares pooling with a default value, so the implicit default will be used. While SentenceTransformer doesn't use pooling (it handles pooling internally), relying on the silent default is confusing for any parent class method that expects this attribute to be set deliberately.

The super().__init__() call should include pooling="mean_pooling" to match the parent class signature.

        # Override unpack_inference_batch to False as SentenceTransformer expects a dictionary input
        self.unpack_inference_batch = False
        self.embedding_field = embedding_field
Comment on lines 117 to 128
Contributor: The parent EmbeddingModelStage.__init__ already sets self.embedding_field using the default value, and line 124 then sets it again with the same value from the parameter, which is redundant. Additionally, the parent sets self.pooling = "mean_pooling" (the default), which SentenceTransformer never uses, creating a misleading attribute.

While this doesn't cause runtime errors, it hurts maintainability. Consider either:

  1. Passing embedding_field and pooling to the parent constructor explicitly, or
  2. Not setting self.embedding_field again after the parent call.

Comment on lines 117 to 128
Contributor: The initialization pattern here is fragile. super().__init__() sets self.unpack_inference_batch = True (via the parent's init at line 53), and line 123 immediately overrides it to False. While this works, it's error-prone because:

  1. the parent class initialization does unnecessary work that is immediately overridden,
  2. if the parent's __init__ implementation changes, this could break, and
  3. it's not immediately clear to future maintainers why the override is needed.

Consider either:

  • not passing the parameter to the parent if it will be overridden, or
  • better yet, modifying the parent class to accept this as a parameter so the override isn't needed.


    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.embedding_field]

    def setup(self, _: WorkerMetadata | None = None) -> None:
        """Load the model for inference."""
        self.model = SentenceTransformer(self.model_identifier, local_files_only=True)
Contributor: The SentenceTransformer initialization doesn't support a custom cache directory. The SentenceTransformer class accepts a cache_folder parameter, but this stage doesn't accept or use it. Consider adding cache_dir support for consistency with other model stages:

Suggested change
self.model = SentenceTransformer(self.model_identifier, local_files_only=True)
self.model = SentenceTransformer(
    self.model_identifier,
    cache_folder=self.cache_dir if hasattr(self, "cache_dir") and self.cache_dir else None,
    local_files_only=True,
)

and update __init__ to accept and store a cache_dir parameter.

        self.model.eval().to("cuda")
Comment on lines 133 to 141
Contributor: SentenceTransformerEmbeddingModelStage has the same issue as EmbeddingModelStage: it overrides setup() but doesn't implement the _setup(local_files_only) method expected by the parent ModelStage class. This will cause issues with model downloading in distributed settings, where setup_on_node() is called first to download models to each node.

The parent ModelStage's setup_on_node() method (model.py, lines 85-102) expects subclasses to implement _setup so it can be called with local_files_only=False during download and local_files_only=True during worker setup.
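For illustration, a minimal sketch of the missing hook, assuming the _setup(local_files_only) contract described above (the ModelStage internals are not shown in this diff, so names and defaults are assumptions):

    def _setup(self, local_files_only: bool = True) -> None:
        # Called with local_files_only=False from setup_on_node() to trigger the download,
        # then with local_files_only=True during worker setup to load from the local cache.
        self.model = SentenceTransformer(self.model_identifier, local_files_only=local_files_only)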

Comment on lines 133 to 141
Contributor: [P2] SentenceTransformerEmbeddingModelStage.setup() also overrides the parent without implementing _setup(). Same distributed setup issue as EmbeddingModelStage: the model won't load when setup_on_node() is called in distributed environments.


Comment on lines 133 to 142
Contributor: The SentenceTransformerEmbeddingModelStage class overrides setup() but doesn't override setup_on_node(). The parent ModelStage.setup_on_node() (defined in nemo_curator/stages/text/models/model.py, lines 85-102) will attempt to call _setup(local_files_only=False), which doesn't exist in this class, resulting in the warning "Subclass SentenceTransformerEmbeddingModelStage does not implement _setup method" at line 98 of model.py.

To properly support node-level setup (which downloads the model), this class should override setup_on_node() to handle SentenceTransformer model downloads, similar to how the parent class handles it for AutoModel. This should use snapshot_download to download the model before calling the existing setup() method.
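A rough sketch of that suggestion, reusing the snapshot_download pattern from vllm.py in this PR; the exact signature and which attributes the parent stores (e.g. hf_token) are assumptions rather than the final implementation:

    def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None:
        from huggingface_hub import snapshot_download

        # Download once per node so setup() can later load with local_files_only=True.
        snapshot_download(self.model_identifier, token=self.hf_token, local_files_only=False)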

    def process_model_output(
        self,
        outputs: torch.Tensor,
        model_input_batch: dict[str, torch.Tensor] | None = None,  # noqa: ARG002
    ) -> torch.Tensor:
        return outputs["sentence_embedding"].cpu()


@dataclass(kw_only=True)
class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
    model_identifier: str = "sentence-transformers/all-MiniLM-L6-v2"
    use_sentence_transformer: bool = True
    text_field: str = "text"
    embedding_field: str = "embeddings"
    max_chars: int | None = None
@@ -115,6 +157,16 @@ class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
    def __post_init__(self) -> None:
        super().__init__()

        model_class = SentenceTransformerEmbeddingModelStage if self.use_sentence_transformer else EmbeddingModelStage

        if self.use_sentence_transformer:
            logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
Contributor: The warning message is missing punctuation; as written, "ignoring" runs on from the previous clause. Consider:

Suggested change
logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
logger.warning("Using SentenceTransformer for embedding model; ignoring embedding_pooling parameter")

Comment on lines +172 to +173
Contributor: When use_sentence_transformer=True, the code logs a warning that embedding_pooling is ignored, but there's no validation that the user left pooling at its default. If a user explicitly sets embedding_pooling="last_token" with use_sentence_transformer=True, they'll only get a warning, which they might miss. Consider validating that embedding_pooling has its default value when use_sentence_transformer=True, or raising an error if it's been explicitly changed.
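A sketch of that stricter check, assuming the dataclass default for embedding_pooling is "mean_pooling" (the default is not visible in this diff):

        if self.use_sentence_transformer and self.embedding_pooling != "mean_pooling":
            msg = "embedding_pooling is ignored when use_sentence_transformer=True; unset it or set use_sentence_transformer=False."
            raise ValueError(msg)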

Contributor: The warning message says "ignoring embedding_pooling", but it should be clearer about the expected behavior. Consider rephrasing to: "SentenceTransformer uses its own internal pooling configuration; the embedding_pooling parameter will be ignored".

Comment on lines +172 to +173
Contributor: The warning message says "ignoring embedding_pooling" with incorrect grammar; it should read "ignoring the embedding_pooling parameter".

            model_additional_kwargs = {}
        else:
            model_additional_kwargs = {
                "pooling": self.embedding_pooling,
            }

        self.stages = [
            TokenizerStage(
                model_identifier=self.model_identifier,
@@ -125,15 +177,15 @@ def __post_init__(self) -> None:
                padding_side=self.padding_side,
                sort_by_length=self.sort_by_length,
            ),
            EmbeddingModelStage(
            model_class(
                model_identifier=self.model_identifier,
                embedding_field=self.embedding_field,
                pooling=self.embedding_pooling,
                hf_token=self.hf_token,
                model_inference_batch_size=self.model_inference_batch_size,
                has_seq_order=self.sort_by_length,
                padding_side=self.padding_side,
                autocast=self.autocast,
                **model_additional_kwargs,
            ),
        ]

138 changes: 138 additions & 0 deletions nemo_curator/stages/text/embedders/vllm.py
Contributor: I think the file shouldn't be called vllm.py, to avoid possible import weirdness with the vllm library.

Contributor Author: Do you have thoughts on what it should be? Given that the import needs to be from x.y.z.vllm import vLLMEmbeddingStage, I don't see it as an issue.

Contributor: vllm_model.py or vllm_embedder.py could work.

@@ -0,0 +1,138 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import TYPE_CHECKING, Any

from huggingface_hub import snapshot_download
from vllm import LLM

from nemo_curator.backends.base import NodeInfo, WorkerMetadata
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.stages.text.models.utils import format_name_with_suffix
from nemo_curator.tasks import DocumentBatch

if TYPE_CHECKING:
    from transformers import AutoTokenizer


class VLLMEmbeddingModelStage(ProcessingStage[DocumentBatch, DocumentBatch]):
Contributor: I don't understand why we have a separate VLLMEmbeddingModelStage.

From a user workflow perspective, should we not always use the same interface and let the user choose the backend?

Contributor: Also, when the eventual switch to vLLM as the default happens, this will make users and tutorials go and change their workflow.

Contributor Author: Good question. I am very reluctant to have a monolithic stage where everything is seemingly controlled by "backend=vllm/sentence_transformers/nn_module", because in reality it doesn't end up like that. Here is my reasoning:

  • The backends have materially different surfaces (pooling options for nn_module, vLLM-specific kwargs like pretokenize or vllm_init_kwargs, sentence-transformers lacking pooling and requiring model_inference_batch_size, etc.). A single "backend" flag forces one stage to carry many backend-only parameters that are dead or invalid in other paths, making validation, docs, UX, and developer experience worse.
  • We've seen this pattern go wrong before (e.g., in our Dask code base we had one read(..., format=...)): the code became convoluted, and users still had to learn format-specific caveats. Changing a default later still required users to adjust arguments because of backend-only options.
  • Separate stages keep each interface tight, validated, and testable; errors stay specific to the backend. It also keeps tutorials and maintenance simpler, with no giant monolith of conditional logic.

This bloat, however, will show up in TextSemanticWorkflow, which does many things including embedding generation. If we want to support both vLLM and sentence-transformers, we'll have the same argument bloat in our workflow construction (and there is already a lot of bloat because it also does kmeans / pairwise / removal, so all those arguments). In my opinion, we should support one backend by default; users who need another backend can compose/override the embedding stage explicitly.

Contributor: I understand the point that vLLM / sentence-transformers / nn_module have materially different parameter surfaces (pooling semantics, vLLM init kwargs, pretokenization options, etc.). Where I disagree is the conclusion that this forces multiple public stages or forces users to manually re-compose workflows.

I'd prefer we optimize for user effort by keeping a single default "Embedding" stage as the public entry point, but avoid the "giant backend-specific kwargs soup" by structuring it as a facade + typed backend configs:

  • EmbeddingStage(model_identifier=..., backend="vllm", backend_config=VllmConfig(...))
  • EmbeddingStage(..., backend="sentence_transformers", backend_config=SentenceTransformersConfig(...))

This gives us:

  • a stable, easy tutorial/workflow story ("use EmbeddingStage") with reasonable defaults,
  • backend-specific validation and docs (each config is tight and testable),
  • and it avoids a single constructor accumulating dead/invalid params across backends.

It also reduces migration burden in workflows: new backends become new config types rather than new stage classes + new tutorial branches. Users with advanced needs can still instantiate the backend-specific stage/config directly, but the default path stays unified.
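A hypothetical sketch of the facade-plus-typed-config shape being proposed; EmbeddingStage, VllmConfig, and SentenceTransformersConfig do not exist in this PR, and the fields shown are illustrative only (the two stage classes dispatched to are the ones added in this PR):

from dataclasses import dataclass, field
from typing import Any

from nemo_curator.stages.text.embedders.base import SentenceTransformerEmbeddingModelStage
from nemo_curator.stages.text.embedders.vllm import VLLMEmbeddingModelStage


@dataclass
class VllmConfig:
    vllm_init_kwargs: dict[str, Any] = field(default_factory=dict)
    pretokenize: bool = False


@dataclass
class SentenceTransformersConfig:
    model_inference_batch_size: int = 1024


@dataclass(kw_only=True)
class EmbeddingStage:
    model_identifier: str
    backend: str = "sentence_transformers"
    backend_config: VllmConfig | SentenceTransformersConfig | None = None

    def build(self) -> object:
        # Each backend keeps its own small, independently validated config; the facade only dispatches.
        if self.backend == "vllm":
            cfg = self.backend_config or VllmConfig()
            return VLLMEmbeddingModelStage(model_identifier=self.model_identifier, **vars(cfg))
        if self.backend == "sentence_transformers":
            cfg = self.backend_config or SentenceTransformersConfig()
            return SentenceTransformerEmbeddingModelStage(model_identifier=self.model_identifier, **vars(cfg))
        msg = f"unknown backend: {self.backend}"
        raise ValueError(msg)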

Contributor: We synced offline. The goal of making users use the best/fastest framework for 90% of the cases will be met by a fast follow to this PR, where the text embedding/semdedup workflows will move to the vLLM backend.

Contributor: Also, should we not reuse the ModelStage logic here?

Contributor Author: Our ModelStage is currently more like a PreTokenizedSortedTextModelStage, since it carries a lot of helpers for that flow. We don't really reuse many functions from there, so there isn't a need to extend it.

However, if the intent is that we should unify the model interfaces of our repos: 100%, we're not doing great on that and we should. That warrants a bigger overhaul, though, one that also considers image / audio / video modalities.

Contributor: We discussed this; the goal indeed is to enforce unification (same signature and some functional contracts) between all the backend stages.

We are aligned as long as we achieve that. My opinion is to have it in this PR for the text stages; other stages can do a fast follow/unification.

    def __init__(  # noqa: PLR0913
        self,
        model_identifier: str,
        vllm_init_kwargs: dict[str, Any] | None = None,
        text_field: str = "text",
        pretokenize: bool = False,
        embedding_field: str = "embeddings",
        cache_dir: str | None = None,
        hf_token: str | None = None,
        verbose: bool = False,
    ):
        self.model_identifier = model_identifier
        self.vllm_init_kwargs = vllm_init_kwargs or {}

        self.text_field = text_field
        self.pretokenize = pretokenize
        self.embedding_field = embedding_field

        self.cache_dir = cache_dir
        self.hf_token = hf_token

        self.verbose = verbose
        # after setup
        self.model: None | LLM = None
        self.tokenizer: None | AutoTokenizer = None
        # stage setup
        self.resources = Resources(
            cpus=1,
            gpus=1,
        )
        self.name = format_name_with_suffix(model_identifier, suffix="_vllm")

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.text_field]

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.text_field, self.embedding_field]

    def _initialize_vllm(self) -> None:
        vllm_init_kwargs = self.vllm_init_kwargs.copy()
        # set defaults here
        if "enforce_eager" not in vllm_init_kwargs:
            vllm_init_kwargs["enforce_eager"] = False
        if "runner" not in vllm_init_kwargs:
            vllm_init_kwargs["runner"] = "pooling"
        if "model_impl" not in vllm_init_kwargs:
            # TODO: Once transformers is bumpted to 5.0 then we should also support transformers
Contributor: Typo in comment: "bumpted" should be "bumped":

Suggested change
# TODO: Once transformers is bumpted to 5.0 then we should also support transformers
# TODO: Once transformers is bumped to 5.0 then we should also support transformers

            vllm_init_kwargs["model_impl"] = "vllm"

        # Reduce verbosity when not in verbose mode
        if not self.verbose and "disable_log_stats" not in vllm_init_kwargs:
            vllm_init_kwargs["disable_log_stats"] = True

        self.model = LLM(model=self.model_identifier, **vllm_init_kwargs)
Contributor (@sarahyurick, Jan 7, 2026): These are the parameters I was experimenting with:

self.model = vllm.LLM(
    model=self.model_identifier,
    max_num_seqs=self.model_inference_batch_size,
    max_num_batched_tokens=self.max_num_batched_tokens,
    gpu_memory_utilization=self.gpu_memory_utilization,
    # task="embed",
    tensor_parallel_size=1,
    enforce_eager=False,
    disable_log_stats=True,
    # runner="pooling",
)

I am wondering if we should include this logic too:

if "max_num_seqs" not in vllm_init_kwargs:
    vllm_init_kwargs["max_num_seqs"] = self.model_inference_batch_size

Contributor Author: I haven't benchmarked the perf for these kwargs, so I feel reluctant to add them; I'd prefer users pass them through model_kwargs.

Contributor: For most kwargs I agree, yes. I guess it depends on how closely we want to stay aligned with the ModelStage/EmbeddingModelStage init parameters, like model_inference_batch_size.

Contributor: The cache_dir parameter from __init__ is not passed to the LLM initialization. vLLM's LLM class accepts a download_dir parameter that should be used here. Currently, vLLM will use its default cache location, which may differ from where snapshot_download cached the model in setup_on_node(), potentially causing duplicate downloads or cache misses.

Suggested change
self.model = LLM(model=self.model_identifier, **vllm_init_kwargs)
self.model = LLM(model=self.model_identifier, download_dir=self.cache_dir, **vllm_init_kwargs)


    def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None:  # noqa: ARG002
        if not self.verbose:
            from huggingface_hub.utils import disable_progress_bars

            disable_progress_bars()

        snapshot_download(self.model_identifier, cache_dir=self.cache_dir, token=self.hf_token, local_files_only=False)
        self._initialize_vllm()

    def setup(self, worker_metadata: WorkerMetadata | None = None) -> None:  # noqa: ARG002
        if self.model is None:
            self._initialize_vllm()
        if self.pretokenize:
            from transformers import AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_identifier)
Contributor: [P2] The tokenizer loading should respect the cache directory and authentication settings passed to the stage constructor. Currently only model_identifier is used, which may cause inconsistent behavior with model downloads.

Contributor: The tokenizer loading doesn't pass cache_dir, which means it will look in the default HuggingFace cache directory instead of the self.cache_dir specified in the constructor. This will fail when a custom cache directory is used, since setup_on_node downloads to self.cache_dir but the tokenizer loads from the default location.

Suggested change
self.tokenizer = AutoTokenizer.from_pretrained(self.model_identifier)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_identifier, cache_dir=self.cache_dir)

Contributor: The tokenizer is loaded without the cache directory, authentication parameter, or local-only flag used elsewhere in the codebase (see line 95 for snapshot_download). If the tokenizer isn't already cached or requires authentication, this will fail or make unexpected network calls during processing. The tokenizer loading should include cache_dir=self.cache_dir and local_files_only=True for consistency.

Contributor: The tokenizer is loaded without proper parameters for distributed execution. It's missing local_files_only=True, which could cause failures if model files aren't already cached locally. It should match the pattern used in setup_on_node() at line 95, where snapshot_download is called with local_files_only=False to download, and then this setup should use local_files_only=True.
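Pulling the three comments above together, the load might look like this (a sketch of the suggested change, not the PR's code as written):

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_identifier,
                cache_dir=self.cache_dir,  # match where setup_on_node() downloaded the snapshot
                token=self.hf_token,  # respect the stage's auth setting
                local_files_only=True,  # avoid network calls during worker setup
            )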


    def process(self, batch: DocumentBatch) -> DocumentBatch:
        df = batch.to_pandas()
        input_data = df[self.text_field].tolist()
        metrics = {}

        if self.pretokenize:
            from vllm.inputs import TokensPrompt
Comment on lines 103 to 121
Contributor: Move to top-level imports?

Contributor Author: I'm not a fan of top-level imports, especially for imports that don't need to be resolved. For instance, in the case of pretokenize=False there is no need for the user to resolve from transformers import AutoTokenizer, which pulls in its own plethora of imports.


            if self.tokenizer is None:
                msg = (
                    "Tokenizer is not initialized. Please call setup() before processing or set pretokenize to False."
                )
                raise ValueError(msg)

            t0 = time.perf_counter()
            max_model_len = self.model.model_config.max_model_len
Contributor: The code accesses self.model.model_config.max_model_len without checking whether self.model is initialized. If process() is called before setup(), this will raise AttributeError: 'NoneType' object has no attribute 'model_config'.

Suggested change
max_model_len = self.model.model_config.max_model_len
if self.model is None:
    msg = "Model is not initialized. Please call setup() before processing."
    raise ValueError(msg)
max_model_len = self.model.model_config.max_model_len

Contributor: [P1] Accessing model_config when model is None will cause an AttributeError. This can happen if process() is called before setup() completes. Add a check before line 123 to ensure the model is initialized.

Contributor: Accessing model.model_config without checking whether model is None. When pretokenize is True and process is called before setup, this line will crash with an AttributeError. Validate that self.model is initialized, like the tokenizer check above.

Contributor: Accessing self.model.model_config before checking if self.model is None could cause an AttributeError if setup() wasn't called. Add a check:

Suggested change
max_model_len = self.model.model_config.max_model_len
if self.model is None:
    msg = "Model is not initialized. Please call setup() before processing."
    raise ValueError(msg)
max_model_len = self.model.model_config.max_model_len

Contributor: logic: validate model initialization before accessing the config property, similar to the tokenizer validation at line 123.

Contributor: logic: accessing model.model_config without checking if the model is initialized. If process() is called before setup(), this raises AttributeError: 'NoneType' object has no attribute 'model_config'.

Suggested change
max_model_len = self.model.model_config.max_model_len
if self.model is None:
    raise ValueError("Model is not initialized. Please call setup() before processing.")
max_model_len = self.model.model_config.max_model_len

            tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)
Comment on lines +130 to +131
Contributor: Verify that max_model_len is set to a reasonable value before using it for truncation. If the model config has an unusually large or unset max_model_len, the tokenization could produce unexpected results.
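One possible guard for that concern (the check itself is illustrative and not part of the PR):

            if max_model_len is None or max_model_len <= 0:
                msg = f"Unexpected max_model_len={max_model_len!r} in the vLLM model config; cannot truncate safely."
                raise ValueError(msg)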

            input_data = [TokensPrompt(prompt_token_ids=ids) for ids in tokenized_data.input_ids]
            metrics["tokenization_time"] = time.perf_counter() - t0
Comment on lines +120 to +133
Contributor: The code checks if self.tokenizer is None at line 116 but doesn't check if self.model is None before accessing self.model.model_config.max_model_len at line 123. If setup() was not called or failed to initialize the model, this will raise an AttributeError.

While setup() should normally be called before process(), defensive programming suggests adding a check for self.model as well, similar to the tokenizer check.

Suggested change
if self.pretokenize:
    from vllm.inputs import TokensPrompt
    if self.tokenizer is None:
        msg = (
            "Tokenizer is not initialized. Please call setup() before processing or set pretokenize to False."
        )
        raise ValueError(msg)
    t0 = time.perf_counter()
    max_model_len = self.model.model_config.max_model_len
    tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)
    input_data = [TokensPrompt(prompt_token_ids=ids) for ids in tokenized_data.input_ids]
    metrics["tokenization_time"] = time.perf_counter() - t0
if self.tokenizer is None:
    msg = (
        "Tokenizer is not initialized. Please call setup() before processing or set pretokenize to False."
    )
    raise ValueError(msg)
if self.model is None:
    msg = "Model is not initialized. Please call setup() before processing."
    raise ValueError(msg)


        t0 = time.perf_counter()
        vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
Contributor: Is there an advantage to doing tokenization and then the model forward pass within the same stage? I was thinking that we should either break tokenization into a separate CPU-only stage or always let self.model.embed(...) handle it, like:

Suggested change
vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
vllm_output = self.model.embed(input_data, truncate_prompt_tokens=self.max_seq_length, use_tqdm=self.verbose)

Contributor Author: "Is there an advantage to doing tokenization then the model forward pass within the same stage?"

Yup. If you look at the description comparing vllm_tokens, vllm_text, and vllm_pretokenize, vllm_pretokenize comes in furthest ahead, followed by vllm_text. vllm_tokens is also close, but we have no reason to add an extra stage that incurs serialization overhead and increases dev maintenance.

Contributor: I don't really see a strong reason for allowing multiple options unless we would expect advantages when switching models (i.e., model A performs better with pretokenize=True but model B performs better with pretokenize=False).

The precedent from before was to delegate tokenization to a CPU-only stage, with serialization overhead not being a huge issue (due to max_chars and max_seq_len). In the current implementation here, tokenization is done during a GPU stage no matter what.

Contributor: The parameter value -1 for truncate_prompt_tokens may cause unexpected behavior. According to the vLLM documentation, this should be either a positive integer or None for no truncation. Consider using None if truncation should be disabled, or verify that -1 is the intended value for this vLLM version.

Contributor: [P1] When self.model is None at line 129, calling self.model.embed() will cause an AttributeError. Add a model initialization check before this line to fail fast with a clear error message.

        metrics["vllm_embedding_time"] = time.perf_counter() - t0
Comment on lines +135 to +137
Contributor: The self.model.embed() call lacks error handling. If the embedding operation fails (GPU issues, invalid input, model errors), it will crash the entire processing pipeline. Wrap this in a try/except block to provide better error messages and allow for graceful failure recovery.


        df[self.embedding_field] = [e.outputs.embedding for e in vllm_output]
Comment on lines +135 to +139
Contributor: Missing error handling for the model.embed() call. If the model fails to generate embeddings (e.g., due to malformed inputs, resource exhaustion, or vLLM errors), the stage will crash without a helpful error message. Consider wrapping this in a try/except block with informative error messaging.
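For reference, one shape the requested error handling could take (the message text and exception type are illustrative, not part of the PR):

        try:
            vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
        except Exception as e:
            msg = f"vLLM embedding failed for batch {batch.task_id} ({len(input_data)} prompts): {e}"
            raise RuntimeError(msg) from e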


        self._log_metrics(metrics)

        return DocumentBatch(
            task_id=batch.task_id,
            dataset_name=batch.dataset_name,
            data=df,
            _metadata=batch._metadata,
            _stage_perf=batch._stage_perf,
        )
12 changes: 8 additions & 4 deletions pyproject.toml
@@ -68,11 +68,13 @@ dependencies = [

[project.optional-dependencies]
cuda12 = ["gpustat", "nvidia-ml-py"]
vllm = ["vllm>=0.13; (platform_machine == 'x86_64' and platform_system != 'Darwin')"]
Comment on lines 70 to +73
Contributor: The vLLM version constraint change from ==0.11.1 (in video_cuda12) to >=0.13 may cause compatibility issues for video curation features that depend on vLLM 0.11.1 specifically. The video_cuda12 extra previously pinned vLLM to 0.11.1, but it will now pull in 0.13+.

Before merging, please verify that:

  1. Video curation features work correctly with vLLM 0.13+.
  2. There are no breaking API changes between vLLM 0.11.1 and 0.13 that affect video processing.

If video curation requires vLLM 0.11.1 specifically, consider creating a separate vLLM version constraint for video vs. text embedding use cases.


# Installs CPU + GPU text curation modules
deduplication_cuda12 = [
"cudf-cu12==25.10.*",
"cuml-cu12==25.10.*",
"scikit-learn<1.8.0", # cuml 25.10.0 is incompatible with scikit-learn 1.8.0
"pylibcugraph-cu12==25.10.*",
"pylibraft-cu12==25.10.*",
"raft-dask-cu12==25.10.*",
@@ -124,6 +126,8 @@ text_cuda12 = [
"nemo_curator[cuda12]",
"nemo_curator[deduplication_cuda12]",
"nemo_curator[text_cpu]",
"nemo_curator[vllm]",
"sentence-transformers"
]

# Video Curation Dependencies
@@ -138,13 +142,13 @@ video_cpu = [
video_cuda12 = [
"nemo_curator[video_cpu]",
"nemo_curator[cuda12]",
"nemo_curator[vllm]",
"cvcuda_cu12",
"flash-attn<=2.8.3; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
"pycuda",
"PyNvVideoCodec==2.0.2; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
"torch<=2.9.0",
"torch<=2.9.1",
"torchaudio",
"vllm==0.11.1; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
]

# All dependencies
Expand All @@ -156,7 +160,7 @@ all = [
]

[dependency-groups]
build = ["setuptools", "torch<=2.9.0"]
build = ["setuptools", "torch<=2.9.1"]
dev = ["jupyter"]
linting = ["pre-commit", "ruff==0.11.4"]
test = [
@@ -166,7 +170,7 @@ test = [
"pytest-asyncio",
"pytest-cov",
"pytest-loguru",
"scikit-learn",
"scikit-learn<1.8.0", # cuml 25.10.0 is incompatible with scikit-learn 1.8.0
"s3fs", # added for testing cloud fs
]

9 changes: 7 additions & 2 deletions tests/backends/test_integration.py
@@ -202,13 +202,18 @@ def test_ray_data_execution_plan(self):
        execution_plan_stages = [stage.strip() for stage in stages]
        # Tasks can get fused with Actors, but Actors can't get fused with Tasks or Actors
        # StreamingRepartition should never get fused

        if ray.__version__ >= "2.53.0":
            streaming_repartition = "StreamingRepartition[num_rows_per_block=1]"
        else:
            streaming_repartition = "StreamingRepartition"
        expected_stages = [
            "InputDataBuffer[Input]",
            "TaskPoolMapOperator[MapBatches(FilePartitioningStageTask)]",
            "TaskPoolMapOperator[StreamingRepartition]",
            f"TaskPoolMapOperator[{streaming_repartition}]",
            "ActorPoolMapOperator[MapBatches(JsonlReaderStageTask)->MapBatches(AddLengthStageActor)]",
            "ActorPoolMapOperator[MapBatches(SplitIntoRowsStageActor)]",
            "TaskPoolMapOperator[StreamingRepartition]",
            f"TaskPoolMapOperator[{streaming_repartition}]",
            "ActorPoolMapOperator[MapBatches(AddLengthStageActor)]",
            "ActorPoolMapOperator[MapBatches(StageWithSetupActor)]",
            "TaskPoolMapOperator[MapBatches(JsonlWriterTask)]",
11 changes: 10 additions & 1 deletion tests/stages/text/embedders/test_base.py
@@ -217,6 +217,7 @@ def sample_data(self) -> DocumentBatch:
    def test_embedding_creator_stage_initialization_and_decomposition(self) -> None:
        """Test initialization, decomposition, and parameter passing to decomposed stages."""
        # Test with custom parameters including hf_token and unk_token
        # Note: use_sentence_transformer=False is required to test EmbeddingModelStage with custom pooling
        stage = EmbeddingCreatorStage(
            model_identifier="test-model",
            text_field="content",
@@ -228,6 +229,7 @@ def test_embedding_creator_stage_initialization_and_decomposition(self) -> None:
            model_inference_batch_size=128,
            sort_by_length=False,
            hf_token="test-token",  # noqa:S106
            use_sentence_transformer=False,
        )

        # Test decomposition and stage types
@@ -295,15 +297,22 @@ def test_embedding_creator_stage_process_integration(self) -> None:

    @pytest.mark.parametrize("pooling_strategy", ["mean_pooling", "last_token"])
    @pytest.mark.parametrize("autocast", [True, False])
    @pytest.mark.parametrize("use_sentence_transformer", [True, False])
    @pytest.mark.gpu
    def test_embedding_creator_stage_with_reference_embeddings(
        self, pooling_strategy: str, sample_data: DocumentBatch, autocast: bool
        self, pooling_strategy: str, sample_data: DocumentBatch, autocast: bool, use_sentence_transformer: bool
    ) -> None:
        """Test embeddings match reference implementation (requires GPU and model download)."""
        if use_sentence_transformer and pooling_strategy != "mean_pooling":
            pytest.skip(
                "Ignoring last_token strategy for sentence transformer as behavior for miniLM is mean pooling "
            )
        stage = EmbeddingCreatorStage(
            model_identifier="sentence-transformers/all-MiniLM-L6-v2",
            embedding_pooling=pooling_strategy,
            model_inference_batch_size=32,
            autocast=autocast,
            use_sentence_transformer=use_sentence_transformer,
        )

        # Decompose and setup stages