Add vLLM and Sentence Transformers support for embedding generation #1346
@@ -18,6 +18,8 @@
import pandas as pd
import torch
import torch.nn.functional as F  # noqa: N812
from loguru import logger
from sentence_transformers import SentenceTransformer
Contributor

logic: the unconditional import breaks users who install without the optional dependency; wrap it in try/except or use a conditional import, then add validation before the model is loaded:

if SentenceTransformer is None:
    raise ImportError("sentence-transformers required. Install with: pip install nemo-curator[text_cpu]")
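A minimal sketch of the guarded-import pattern this comment describes; where the check lives and the exact error text follow the reviewer's suggestion and are assumptions, not the PR as written:

# Sketch only: guarded import so the module stays importable without the optional extra.
try:
    from sentence_transformers import SentenceTransformer
except ImportError:  # sentence-transformers not installed
    SentenceTransformer = None


def _require_sentence_transformers() -> None:
    # Validation deferred to model-load time, per the reviewer's suggestion.
    if SentenceTransformer is None:
        msg = "sentence-transformers required. Install with: pip install nemo-curator[text_cpu]"
        raise ImportError(msg)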
from transformers import AutoModel

from nemo_curator.backends.base import WorkerMetadata
@@ -98,9 +100,49 @@ def _get_last_token(self, model_output: torch.Tensor, attention_mask: torch.Tens
        return F.normalize(last_token_embeddings, dim=1)


class SentenceTransformerEmbeddingModelStage(EmbeddingModelStage):
    def __init__(  # noqa: PLR0913
        self,
        model_identifier: str,
        embedding_field: str = "embeddings",
        hf_token: str | None = None,
        model_inference_batch_size: int = 1024,
        has_seq_order: bool = True,
        padding_side: Literal["left", "right"] = "right",
        autocast: bool = True,
    ):
        super().__init__(
            model_identifier=model_identifier,
            hf_token=hf_token,
            model_inference_batch_size=model_inference_batch_size,
            has_seq_order=has_seq_order,
            padding_side=padding_side,
            autocast=autocast,
        )
Comment on lines 117 to 125

Contributor

The parent class
Comment on lines 117 to 125

Contributor

The super().__init__() call should include
        # Override unpack_inference_batch to False as SentenceTransformer expects a dictionary input
        self.unpack_inference_batch = False
        self.embedding_field = embedding_field
Comment on lines 117 to 128

Contributor

the parent … While this doesn't cause runtime errors, it's confusing for maintainability. Consider either:
Comment on lines 117 to 128

Contributor

the initialization pattern here is fragile.

consider either:
    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.embedding_field]

    def setup(self, _: WorkerMetadata | None = None) -> None:
        """Load the model for inference."""
        self.model = SentenceTransformer(self.model_identifier, local_files_only=True)
Suggested change:

-        self.model = SentenceTransformer(self.model_identifier, local_files_only=True)
+        self.model = SentenceTransformer(
+            self.model_identifier,
+            cache_folder=self.cache_dir if hasattr(self, 'cache_dir') and self.cache_dir else None,
+            local_files_only=True
+        )

and update __init__ to accept and store the cache_dir parameter.
SentenceTransformerEmbeddingModelStage has the same issue as EmbeddingModelStage: it overrides setup() but doesn't implement the _setup(local_files_only) method expected by the parent ModelStage class. This will cause issues with model downloading in distributed settings, where setup_on_node() is called first to download models to each node.

The parent ModelStage's setup_on_node() method (model.py lines 85-102) expects subclasses to implement _setup so it can be called with local_files_only=False during download and local_files_only=True during worker setup.
[P2] SentenceTransformerEmbeddingModelStage.setup() also overrides the parent without implementing the _setup() method. Same distributed setup issue as EmbeddingModelStage: the model won't load when setup_on_node() is called in distributed environments.
The SentenceTransformerEmbeddingModelStage class overrides setup() but doesn't override setup_on_node(). The parent ModelStage.setup_on_node() (defined in nemo_curator/stages/text/models/model.py at lines 85-102) will attempt to call _setup(local_files_only=False), which doesn't exist in this class, resulting in the warning "Subclass SentenceTransformerEmbeddingModelStage does not implement _setup method" at line 98 of model.py.

To properly support node-level setup (which downloads the model), this class should override setup_on_node() to handle SentenceTransformer model downloads, similar to how the parent class handles it for AutoModel. This should use snapshot_download to download the model before calling the existing setup() method.
Typo in warning message: "ignoring" should be part of a complete sentence. Consider:

-logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
+logger.warning("Using SentenceTransformer for embedding model; ignoring embedding_pooling parameter")
When use_sentence_transformer=True, the code logs a warning that embedding_pooling is ignored, but there's no validation to check whether the user explicitly set a non-default pooling value. If a user explicitly sets embedding_pooling="last_token" with use_sentence_transformer=True, they'll only get a warning, which they might miss. Consider validating that embedding_pooling has its default value when use_sentence_transformer=True, or raising an error if it has been explicitly changed.
The warning message says "ignoring embedding_pooling" but should be clearer about the expected behavior. Consider rephrasing to: "SentenceTransformer uses its own internal pooling configuration; the embedding_pooling parameter will be ignored".
The warning message "ignoring embedding_pooling" uses incorrect grammar; it should be "ignoring the embedding_pooling parameter".
Contributor

I think the file shouldn't be called
Contributor
Author

Do you have thoughts on what it should be? Given the import needs to be
@@ -0,0 +1,138 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import TYPE_CHECKING, Any

from huggingface_hub import snapshot_download
from vllm import LLM

from nemo_curator.backends.base import NodeInfo, WorkerMetadata
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.stages.text.models.utils import format_name_with_suffix
from nemo_curator.tasks import DocumentBatch

if TYPE_CHECKING:
    from transformers import AutoTokenizer


class VLLMEmbeddingModelStage(ProcessingStage[DocumentBatch, DocumentBatch]):
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dont understand why we have a separate VLLMEmbeddingModelStage . From a user workflow perspective, should we not always use the same interface and let user choose the backend ?
Contributor

Like, when the eventual switch to vLLM as the default happens, this will make users/tutorials go and change their workflow.
Contributor
Author

Good question. I am very reluctant to have a monolithic stage where everything is seemingly controlled by

This bloat however will show up in
Contributor

I understand the point that vLLM / sentence-transformers / nn_module have materially different parameter surfaces (pooling semantics, vLLM init kwargs, pretokenization options, etc.). Where I disagree is the conclusion that this forces multiple public stages or forces users to manually re-compose workflows. I'd prefer we optimize for user effort by keeping a single default "Embedding" stage as the public entry point, but avoid the "giant backend-specific kwargs soup" by structuring it as a facade + typed backend configs:

This gives us:

It also reduces migration burden in workflows: new backends become new config types rather than new stage classes + new tutorial branches. Users with advanced needs can still instantiate the backend-specific stage/config directly, but the default path stays unified.
Contributor

We synced offline; the goal of making users use the best/fastest framework for 90% of cases will be met by a fast follow to this PR, where the text embedding/semdedup workflows move onto the vLLM backend.
Contributor

Also, should we not import this logic for ModelStage here?
Contributor
Author

Our

However, if the intent is that we should unify the model interfaces across our repos: 100%, we're not doing great on that and we should. That warrants a bigger overhaul, though, something that also considers image / audio / video modalities.
Contributor

We discussed this; the goal indeed is to enforce unification (same signature and some functional contracts) between all the backend stages. We are aligned as long as we achieve that. My opinion is to have it in this PR for the text stages; other stages can do a fast follow/unification.
||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( # noqa: PLR0913 | ||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| model_identifier: str, | ||||||||||||||||||||||||||||||||||||||||||||||
| vllm_init_kwargs: dict[str, Any] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| text_field: str = "text", | ||||||||||||||||||||||||||||||||||||||||||||||
| pretokenize: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||
| embedding_field: str = "embeddings", | ||||||||||||||||||||||||||||||||||||||||||||||
| cache_dir: str | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| hf_token: str | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| verbose: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||
| self.model_identifier = model_identifier | ||||||||||||||||||||||||||||||||||||||||||||||
| self.vllm_init_kwargs = vllm_init_kwargs or {} | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.text_field = text_field | ||||||||||||||||||||||||||||||||||||||||||||||
| self.pretokenize = pretokenize | ||||||||||||||||||||||||||||||||||||||||||||||
| self.embedding_field = embedding_field | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.cache_dir = cache_dir | ||||||||||||||||||||||||||||||||||||||||||||||
| self.hf_token = hf_token | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.verbose = verbose | ||||||||||||||||||||||||||||||||||||||||||||||
| # after setup | ||||||||||||||||||||||||||||||||||||||||||||||
| self.model: None | LLM = None | ||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer: None | AutoTokenizer = None | ||||||||||||||||||||||||||||||||||||||||||||||
| # stage setup | ||||||||||||||||||||||||||||||||||||||||||||||
| self.resources = Resources( | ||||||||||||||||||||||||||||||||||||||||||||||
| cpus=1, | ||||||||||||||||||||||||||||||||||||||||||||||
| gpus=1, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.name = format_name_with_suffix(model_identifier, suffix="_vllm") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def inputs(self) -> tuple[list[str], list[str]]: | ||||||||||||||||||||||||||||||||||||||||||||||
| return ["data"], [self.text_field] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def outputs(self) -> tuple[list[str], list[str]]: | ||||||||||||||||||||||||||||||||||||||||||||||
| return ["data"], [self.text_field, self.embedding_field] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _initialize_vllm(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| vllm_init_kwargs = self.vllm_init_kwargs.copy() | ||||||||||||||||||||||||||||||||||||||||||||||
| # set defaults here | ||||||||||||||||||||||||||||||||||||||||||||||
| if "enforce_eager" not in vllm_init_kwargs: | ||||||||||||||||||||||||||||||||||||||||||||||
| vllm_init_kwargs["enforce_eager"] = False | ||||||||||||||||||||||||||||||||||||||||||||||
| if "runner" not in vllm_init_kwargs: | ||||||||||||||||||||||||||||||||||||||||||||||
| vllm_init_kwargs["runner"] = "pooling" | ||||||||||||||||||||||||||||||||||||||||||||||
| if "model_impl" not in vllm_init_kwargs: | ||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: Once transformers is bumpted to 5.0 then we should also support transformers | ||||||||||||||||||||||||||||||||||||||||||||||
Suggested change:

-            # TODO: Once transformers is bumpted to 5.0 then we should also support transformers
+            # TODO: Once transformers is bumped to 5.0 then we should also support transformers
These are the parameters I was experimenting with:

self.model = vllm.LLM(
    model=self.model_identifier,
    max_num_seqs=self.model_inference_batch_size,
    max_num_batched_tokens=self.max_num_batched_tokens,
    gpu_memory_utilization=self.gpu_memory_utilization,
    # task="embed",
    tensor_parallel_size=1,
    enforce_eager=False,
    disable_log_stats=True,
    # runner="pooling",
)

I am wondering if we should include this logic too:

if "max_num_seqs" not in vllm_init_kwargs:
    vllm_init_kwargs["max_num_seqs"] = self.model_inference_batch_size
I haven't benchmarked the perf for these kwargs, so I feel reluctant to add them; I'd prefer users pass them through model_kwargs.
For most kwargs I agree, yes. I guess it depends on how closely we want to stay aligned with the ModelStage/EmbeddingModelStage init parameters, like model_inference_batch_size.
The cache_dir parameter from __init__ is not passed to the LLM initialization. vLLM's LLM class accepts a download_dir parameter that should be used here. Currently, vLLM will use its default cache location, which may differ from where snapshot_download cached the model in setup_on_node(), potentially causing duplicate downloads or cache misses.

Suggested change:

-        self.model = LLM(model=self.model_identifier, **vllm_init_kwargs)
+        self.model = LLM(model=self.model_identifier, download_dir=self.cache_dir, **vllm_init_kwargs)
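Taken together, the max_num_seqs and download_dir suggestions could fold into the existing defaulting logic roughly like this. It is a sketch under the assumption that the stage stores cache_dir and a model_inference_batch_size-style attribute (only cache_dir exists in the PR as written), not a final implementation:

    def _initialize_vllm(self) -> None:
        vllm_init_kwargs = self.vllm_init_kwargs.copy()
        # Defaults already present in the PR.
        vllm_init_kwargs.setdefault("enforce_eager", False)
        vllm_init_kwargs.setdefault("runner", "pooling")
        # Reviewer suggestions (assumed attributes): align batching and cache location.
        if getattr(self, "model_inference_batch_size", None) is not None:
            vllm_init_kwargs.setdefault("max_num_seqs", self.model_inference_batch_size)
        if self.cache_dir is not None:
            vllm_init_kwargs.setdefault("download_dir", self.cache_dir)
        self.model = LLM(model=self.model_identifier, **vllm_init_kwargs)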
[P2] The tokenizer loading should respect the cache directory and authentication settings passed to the stage constructor. Currently only model_identifier is used, which may cause inconsistent behavior with model downloads.
The tokenizer loading doesn't pass cache_dir, which means it will look in the default HuggingFace cache directory instead of using the self.cache_dir specified in the constructor. This will fail when a custom cache directory is used, since setup_on_node downloads to self.cache_dir but the tokenizer loads from the default location.

Suggested change:

-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_identifier)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_identifier, cache_dir=self.cache_dir)
The tokenizer is loaded without the cache directory, authentication parameter, or local-only flag that are used elsewhere in the codebase (see line 95 for snapshot_download). If the tokenizer isn't already cached or requires authentication, this will fail or make unexpected network calls during processing. The tokenizer loading should include cache_dir=self.cache_dir and local_files_only=True for consistency.
The tokenizer is loaded without proper parameters for distributed execution. It's missing local_files_only=True, which could cause failures if model files aren't already cached locally. It should match the pattern used in setup_on_node() at line 95, where snapshot_download is called with local_files_only=False to download, and then this setup should use local_files_only=True.
Move to top-level imports?
I'm not a fan of top-level imports, especially for imports that don't need to get resolved. For instance, in the case of pretokenize=False there is no need for the user to resolve "from transformers import AutoTokenizer", which pulls in its own plethora of imports.
The code accesses self.model.model_config.max_model_len without checking if self.model is initialized. If process() is called before setup(), this will raise AttributeError: 'NoneType' object has no attribute 'model_config'.

Suggested change:

-        max_model_len = self.model.model_config.max_model_len
+        if self.model is None:
+            msg = "Model is not initialized. Please call setup() before processing."
+            raise ValueError(msg)
+        max_model_len = self.model.model_config.max_model_len
[P1] Accessing model_config when model is None will cause AttributeError. This can happen if process() is called before setup() completes. Add a check before line 123 to ensure the model is initialized.
Accessing model.model_config without checking if model is None. When pretokenize is True and process is called before setup, this line will crash with AttributeError. Should validate that self.model is initialized, like the tokenizer check above.
Accessing self.model.model_config before checking if self.model is None could cause AttributeError if setup() wasn't called. Add a check:

-        max_model_len = self.model.model_config.max_model_len
+        if self.model is None:
+            msg = "Model is not initialized. Please call setup() before processing."
+            raise ValueError(msg)
+        max_model_len = self.model.model_config.max_model_len
logic: validate model initialization before accessing the config property, similar to the tokenizer validation at line 123.
logic: accessing model.model_config without checking if the model is initialized. If process() is called before setup(), this raises AttributeError: 'NoneType' object has no attribute 'model_config'.

Suggested change:

-        max_model_len = self.model.model_config.max_model_len
+        if self.model is None:
+            raise ValueError("Model is not initialized. Please call setup() before processing.")
+        max_model_len = self.model.model_config.max_model_len
Verify that max_model_len is set to a reasonable value before using it for truncation. If the model config has an unusually large or unset max_model_len, the tokenization could produce unexpected results.
The code checks if self.tokenizer is None at line 116 but doesn't check if self.model is None before accessing self.model.model_config.max_model_len at line 123. If setup() was not called or failed to initialize the model, this will raise an AttributeError.

While setup() should normally be called before process(), defensive programming suggests adding a check for self.model as well, similar to the tokenizer check.

Suggested change:

        if self.pretokenize:
            from vllm.inputs import TokensPrompt

            if self.tokenizer is None:
                msg = (
                    "Tokenizer is not initialized. Please call setup() before processing or set pretokenize to False."
                )
                raise ValueError(msg)
            if self.model is None:
                msg = "Model is not initialized. Please call setup() before processing."
                raise ValueError(msg)
            t0 = time.perf_counter()
            max_model_len = self.model.model_config.max_model_len
            tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)
            input_data = [TokensPrompt(prompt_token_ids=ids) for ids in tokenized_data.input_ids]
            metrics["tokenization_time"] = time.perf_counter() - t0
Is there an advantage to doing tokenization and then the model forward pass within the same stage? I was thinking that we should either break tokenization into a separate CPU-only stage or always let self.model.embed(...) handle it, like:

-            vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
+            vllm_output = self.model.embed(input_data, truncate_prompt_tokens=self.max_seq_length, use_tqdm=self.verbose)
Is there an advantage to doing tokenization then the model forward pass within the same stage

Yup, if you see the description comparing vllm_tokens, vllm_text, and vllm_pretokenize: vllm_pretokenize comes in most ahead, followed by vllm_text. vllm_tokens is also close, but we have no reason to add an extra stage, which incurs serialization overhead and increases dev maintenance.
I don't really see a strong reason for allowing multiple options unless we expect advantages when switching models (i.e., model A performs better with pretokenize=True but model B performs better with pretokenize=False).

The precedent from before was to delegate tokenization to a CPU-only stage, with serialization overhead not being a huge issue (due to max_chars and max_seq_len). In the current implementation here, tokenization is done during a GPU stage no matter what.
The parameter value -1 for truncate_prompt_tokens may cause unexpected behavior. According to the vLLM documentation, this should be either a positive integer or None for no truncation. Consider using None if truncation should be disabled, or verify that -1 is the intended value for this vLLM version.
[P1] When self.model is None at line 129, calling self.model.embed() will cause AttributeError. Add a model initialization check before this line to fail fast with a clear error message.
The self.model.embed() call lacks error handling. If the embedding operation fails (GPU issues, invalid input, model errors), it will crash the entire processing pipeline. Wrap this in a try-except block to provide better error messages and allow for graceful failure recovery.
Missing error handling for the model.embed() call. If the model fails to generate embeddings (e.g., due to malformed inputs, resource exhaustion, or vLLM errors), the stage will crash without a helpful error message. Consider wrapping this in a try-except block with informative error messaging.
@@ -68,11 +68,13 @@ dependencies = [

[project.optional-dependencies]
cuda12 = ["gpustat", "nvidia-ml-py"]
vllm = ["vllm>=0.13; (platform_machine == 'x86_64' and platform_system != 'Darwin')"]
Comment on lines 70 to +73

Contributor

The vLLM version constraint change from

Before merging, please verify that:

If video curation requires vLLM 0.11.1 specifically, consider creating a separate vLLM version constraint for video vs. text embedding use cases.
# Installs CPU + GPU text curation modules
deduplication_cuda12 = [
    "cudf-cu12==25.10.*",
    "cuml-cu12==25.10.*",
    "scikit-learn<1.8.0",  # cuml 25.10.0 is incompatible with scikit-learn 1.8.0
    "pylibcugraph-cu12==25.10.*",
    "pylibraft-cu12==25.10.*",
    "raft-dask-cu12==25.10.*",

@@ -124,6 +126,8 @@ text_cuda12 = [
    "nemo_curator[cuda12]",
    "nemo_curator[deduplication_cuda12]",
    "nemo_curator[text_cpu]",
    "nemo_curator[vllm]",
    "sentence-transformers"
]

# Video Curation Dependencies

@@ -138,13 +142,13 @@ video_cpu = [
video_cuda12 = [
    "nemo_curator[video_cpu]",
    "nemo_curator[cuda12]",
    "nemo_curator[vllm]",
    "cvcuda_cu12",
    "flash-attn<=2.8.3; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
    "pycuda",
    "PyNvVideoCodec==2.0.2; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
    "torch<=2.9.0",
    "torch<=2.9.1",
    "torchaudio",
    "vllm==0.11.1; (platform_machine == 'x86_64' and platform_system != 'Darwin')",
]

# All dependencies

@@ -156,7 +160,7 @@ all = [
]

[dependency-groups]
build = ["setuptools", "torch<=2.9.0"]
build = ["setuptools", "torch<=2.9.1"]
dev = ["jupyter"]
linting = ["pre-commit", "ruff==0.11.4"]
test = [

@@ -166,7 +170,7 @@ test = [
    "pytest-asyncio",
    "pytest-cov",
    "pytest-loguru",
    "scikit-learn",
    "scikit-learn<1.8.0",  # cuml 25.10.0 is incompatible with scikit-learn 1.8.0
    "s3fs",  # added for testing cloud fs
]
[P0] This unconditional import will cause ImportError for users who install without the text_cuda12 extra. The import should be conditional or guarded since sentence-transformers is in the optional dependencies. Consider importing inside the class methods or using a try/except block.