Add vLLM and Sentence Transformers support for embedding generation #1346
base: main
Changes from all commits: c06c705, d77ef3c, 7a318d3, 2894bdb, f709307, da477e9, 03f9399
```diff
@@ -18,6 +18,8 @@
 import pandas as pd
 import torch
 import torch.nn.functional as F  # noqa: N812
+from loguru import logger
+from sentence_transformers import SentenceTransformer
 from transformers import AutoModel

 from nemo_curator.backends.base import WorkerMetadata
@@ -98,9 +100,49 @@ def _get_last_token(self, model_output: torch.Tensor, attention_mask: torch.Tens
         return F.normalize(last_token_embeddings, dim=1)


+class SentenceTransformerEmbeddingModelStage(EmbeddingModelStage):
+    def __init__(  # noqa: PLR0913
+        self,
+        model_identifier: str,
+        embedding_field: str = "embeddings",
+        hf_token: str | None = None,
+        model_inference_batch_size: int = 1024,
+        has_seq_order: bool = True,
+        padding_side: Literal["left", "right"] = "right",
+        autocast: bool = True,
+    ):
+        super().__init__(
+            model_identifier=model_identifier,
+            hf_token=hf_token,
+            model_inference_batch_size=model_inference_batch_size,
+            has_seq_order=has_seq_order,
+            padding_side=padding_side,
+            autocast=autocast,
+        )
+        # Override unpack_inference_batch to False as SentenceTransformer expects a dictionary input
+        self.unpack_inference_batch = False
+        self.embedding_field = embedding_field
+
+    def outputs(self) -> tuple[list[str], list[str]]:
+        return ["data"], [self.embedding_field]
+
+    def setup(self, _: WorkerMetadata | None = None) -> None:
+        """Load the model for inference."""
+        self.model = SentenceTransformer(self.model_identifier, local_files_only=True)
+        self.model.eval().to("cuda")
+
+    def process_model_output(
+        self,
+        outputs: torch.Tensor,
+        model_input_batch: dict[str, torch.Tensor] | None = None,  # noqa: ARG002
+    ) -> torch.Tensor:
+        return outputs["sentence_embedding"].cpu()
+
+
 @dataclass(kw_only=True)
 class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
     model_identifier: str = "sentence-transformers/all-MiniLM-L6-v2"
+    use_sentence_transformer: bool = True
     text_field: str = "text"
     embedding_field: str = "embeddings"
     max_chars: int | None = None
```
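For context on the `process_model_output()` override above: `SentenceTransformer.forward()` returns a feature dictionary whose `"sentence_embedding"` entry holds the pooled embeddings, which is why the stage indexes into the output. A minimal sketch of that behavior, assuming the standard sentence-transformers API and a CUDA device (model name is illustrative):

```python
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
model.eval().to("cuda")

# tokenize() returns the dictionary batch the stage feeds to the model
# (hence unpack_inference_batch = False above).
features = model.tokenize(["a first document", "a second document"])
features = {k: v.to("cuda") for k, v in features.items()}

with torch.no_grad():
    out = model.forward(features)

# The pooled embeddings live under "sentence_embedding".
embeddings = out["sentence_embedding"].cpu()
print(embeddings.shape)  # (2, 384) for all-MiniLM-L6-v2
```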
```diff
@@ -115,6 +157,16 @@ class EmbeddingCreatorStage(CompositeStage[DocumentBatch, DocumentBatch]):
     def __post_init__(self) -> None:
         super().__init__()

+        model_class = SentenceTransformerEmbeddingModelStage if self.use_sentence_transformer else EmbeddingModelStage
+
+        if self.use_sentence_transformer:
+            logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
```
Contributor: Typo in the warning message: the clause starting at "ignoring" should be a complete sentence. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
```diff
+            model_additional_kwargs = {}
+        else:
+            model_additional_kwargs = {
+                "pooling": self.embedding_pooling,
+            }
+
         self.stages = [
             TokenizerStage(
                 model_identifier=self.model_identifier,
@@ -125,15 +177,15 @@ def __post_init__(self) -> None:
                 padding_side=self.padding_side,
                 sort_by_length=self.sort_by_length,
             ),
-            EmbeddingModelStage(
+            model_class(
                 model_identifier=self.model_identifier,
                 embedding_field=self.embedding_field,
-                pooling=self.embedding_pooling,
                 hf_token=self.hf_token,
                 model_inference_batch_size=self.model_inference_batch_size,
                 has_seq_order=self.sort_by_length,
                 padding_side=self.padding_side,
                 autocast=self.autocast,
+                **model_additional_kwargs,
             ),
         ]
```
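A hedged usage sketch of the two backends the composite stage can now assemble; the import path is an assumption, and the field names follow the diff above:

```python
# Hypothetical import path -- adjust to wherever EmbeddingCreatorStage lives.
from nemo_curator.stages.text.models.model import EmbeddingCreatorStage

# SentenceTransformer backend: embedding_pooling is ignored (see warning above).
st_stage = EmbeddingCreatorStage(
    model_identifier="sentence-transformers/all-MiniLM-L6-v2",
    use_sentence_transformer=True,
)

# Plain HF backend: EmbeddingModelStage receives pooling=self.embedding_pooling.
hf_stage = EmbeddingCreatorStage(
    model_identifier="sentence-transformers/all-MiniLM-L6-v2",
    use_sentence_transformer=False,
)
```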
Contributor: I think the file shouldn't be called …
New file (138 added lines):

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import TYPE_CHECKING, Any

from huggingface_hub import snapshot_download
from vllm import LLM

from nemo_curator.backends.base import NodeInfo, WorkerMetadata
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.stages.text.models.utils import format_name_with_suffix
from nemo_curator.tasks import DocumentBatch

if TYPE_CHECKING:
    from transformers import AutoTokenizer


class VLLMEmbeddingModelStage(ProcessingStage[DocumentBatch, DocumentBatch]):
```
Contributor: I don't understand why we have a separate VLLMEmbeddingModelStage. From a user workflow perspective, should we not always use the same interface and let the user choose the backend?

Contributor: Likewise, when the eventual switch to vLLM as the default happens, this will make users/tutorials go and change their workflow.

Contributor: Also, should we not import this logic from ModelStage here?
```python
    def __init__(  # noqa: PLR0913
        self,
        model_identifier: str,
        vllm_init_kwargs: dict[str, Any] | None = None,
        text_field: str = "text",
        pretokenize: bool = False,
        embedding_field: str = "embeddings",
        cache_dir: str | None = None,
        hf_token: str | None = None,
        verbose: bool = False,
    ):
        self.model_identifier = model_identifier
        self.vllm_init_kwargs = vllm_init_kwargs or {}

        self.text_field = text_field
        self.pretokenize = pretokenize
        self.embedding_field = embedding_field

        self.cache_dir = cache_dir
        self.hf_token = hf_token

        self.verbose = verbose
        # after setup
        self.model: None | LLM = None
        self.tokenizer: None | AutoTokenizer = None
        # stage setup
        self.resources = Resources(
            cpus=1,
            gpus=1,
        )
        self.name = format_name_with_suffix(model_identifier, suffix="_vllm")

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.text_field]

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], [self.text_field, self.embedding_field]

    def _initialize_vllm(self) -> None:
        vllm_init_kwargs = self.vllm_init_kwargs.copy()
        # set defaults here
        if "enforce_eager" not in vllm_init_kwargs:
            vllm_init_kwargs["enforce_eager"] = False
        if "runner" not in vllm_init_kwargs:
            vllm_init_kwargs["runner"] = "pooling"
        if "model_impl" not in vllm_init_kwargs:
            # TODO: Once transformers is bumpted to 5.0 then we should also support transformers
```
Contributor: syntax: "bumpted" should be "bumped".

Contributor: Typo in comment: "bumpted" should be "bumped".
```python
            vllm_init_kwargs["model_impl"] = "vllm"

        # Reduce verbosity when not in verbose mode
        if not self.verbose and "disable_log_stats" not in vllm_init_kwargs:
            vllm_init_kwargs["disable_log_stats"] = True

        self.model = LLM(model=self.model_identifier, **vllm_init_kwargs)
```
Contributor: These are the parameters I was experimenting with: … I am wondering if we should include this logic too: …
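For reference, the `LLM(...)` call that `_initialize_vllm()` ends up making with no user overrides would look roughly like this; the model name is illustrative, and the `runner`/`model_impl`/`enforce_eager`/`disable_log_stats` values mirror the defaults in the diff:

```python
from vllm import LLM

# Sketch of the defaults _initialize_vllm() fills in when the user passes
# no vllm_init_kwargs; any user-supplied key takes precedence.
llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",  # illustrative embedding model
    runner="pooling",         # pooling runner for embeddings, not generation
    enforce_eager=False,      # allow CUDA graph capture
    model_impl="vllm",        # native vLLM implementation (see TODO above)
    disable_log_stats=True,   # set when verbose=False
)
```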
```python
    def setup_on_node(self, node_info: NodeInfo | None = None, worker_metadata: WorkerMetadata | None = None) -> None:  # noqa: ARG002
        if not self.verbose:
            from huggingface_hub.utils import disable_progress_bars

            disable_progress_bars()

        snapshot_download(self.model_identifier, cache_dir=self.cache_dir, token=self.hf_token, local_files_only=False)
        self._initialize_vllm()
```
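`setup_on_node()` warms the Hugging Face cache once per node so each worker's vLLM init only reads local files. A standalone sketch, assuming the public `huggingface_hub` API (model name and cache path are hypothetical):

```python
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars

disable_progress_bars()  # quiet download logs, mirroring verbose=False

# Download (or reuse) the full model snapshot into a shared cache
# so later LLM(...) construction does not hit the network.
local_path = snapshot_download(
    "Qwen/Qwen3-Embedding-0.6B",   # hypothetical model
    cache_dir="/shared/hf-cache",  # hypothetical node-local cache
    local_files_only=False,
)
print(local_path)
```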
```python
    def setup(self, worker_metadata: WorkerMetadata | None = None) -> None:  # noqa: ARG002
        if self.model is None:
            self._initialize_vllm()
        if self.pretokenize:
            from transformers import AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_identifier)

    def process(self, batch: DocumentBatch) -> DocumentBatch:
        df = batch.to_pandas()
        input_data = df[self.text_field].tolist()
```
```python
        metrics = {}

        if self.pretokenize:
            from vllm.inputs import TokensPrompt
```
Contributor (on lines +99 to +110): Move to top-level imports?
```python
            if self.tokenizer is None:
                msg = (
                    "Tokenizer is not initialized. Please call setup() before processing or set pretokenize to False."
                )
                raise ValueError(msg)

            t0 = time.perf_counter()
            max_model_len = self.model.model_config.max_model_len
            tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)
```
Contributor (on lines +119 to +120): Verify that …
```python
            input_data = [TokensPrompt(prompt_token_ids=ids) for ids in tokenized_data.input_ids]
            metrics["tokenization_time"] = time.perf_counter() - t0

        t0 = time.perf_counter()
        vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
```
Contributor: Is there an advantage to doing tokenization and then the model forward pass within the same stage? I was thinking that we should either break tokenization into a separate CPU-only stage or always let …

Contributor: The parameter value …
```python
        metrics["vllm_embedding_time"] = time.perf_counter() - t0

        df[self.embedding_field] = [e.outputs.embedding for e in vllm_output]
```
Contributor (on lines +124 to +128): Missing error handling for the …
```python
        self._log_metrics(metrics)

        return DocumentBatch(
            task_id=batch.task_id,
            dataset_name=batch.dataset_name,
            data=df,
            _metadata=batch._metadata,
            _stage_perf=batch._stage_perf,
        )
```
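Putting the pretokenize path together, a minimal end-to-end sketch under the same assumptions (model name illustrative; `truncate_prompt_tokens` omitted here given the open question above):

```python
import time

from transformers import AutoTokenizer
from vllm import LLM
from vllm.inputs import TokensPrompt

# Illustrative model; any vLLM-supported embedding model should work here.
llm = LLM(model="Qwen/Qwen3-Embedding-0.6B", runner="pooling")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")

texts = ["first document", "second document"]

# Tokenize on CPU, truncating to the model's context window (the stage reads
# this from the engine's model config; a fixed value keeps the sketch simple).
t0 = time.perf_counter()
enc = tokenizer.batch_encode_plus(texts, truncation=True, max_length=512)
prompts = [TokensPrompt(prompt_token_ids=ids) for ids in enc.input_ids]
print(f"tokenization: {time.perf_counter() - t0:.3f}s")

# embed() runs the pooling model and returns one output per prompt.
outputs = llm.embed(prompts, use_tqdm=False)
vectors = [o.outputs.embedding for o in outputs]
print(len(vectors), len(vectors[0]))  # n_docs, embedding_dim
```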
Contributor: The parent class EmbeddingModelStage.__init__() is called without passing a pooling parameter, which means the parent's default pooling value will be used. While this doesn't cause issues since SentenceTransformerEmbeddingModelStage doesn't use the pooling attribute, consider adding a comment to clarify this is intentional. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!