Add vLLM and Sentence Transformers support for embedding generation #1346
base: main
Conversation
Signed-off-by: Praateek <[email protected]>
Greptile Summary

This PR adds support for vLLM-based embedding generation alongside the existing Sentence Transformers implementation, providing performance improvements for larger models. The implementation introduces a new VLLMEmbeddingModelStage.

Key Changes:
Benchmarking Results (from PR description):
Minor Issue:
Confidence Score: 5/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant VLLMStage as VLLMEmbeddingModelStage
participant HFHub as HuggingFace Hub
participant vLLM as vLLM Engine
participant Tokenizer
participant Device as GPU Device
User->>VLLMStage: Initialize stage
User->>VLLMStage: setup_on_node()
VLLMStage->>HFHub: snapshot_download(model)
HFHub-->>VLLMStage: Model cached locally
VLLMStage->>vLLM: Initialize LLM with pooling runner
vLLM->>Device: Load model weights
Device-->>vLLM: Model ready
vLLM-->>VLLMStage: LLM instance created
User->>VLLMStage: setup()
alt pretokenize enabled
VLLMStage->>Tokenizer: Load AutoTokenizer from model
Tokenizer-->>VLLMStage: Tokenizer initialized
end
User->>VLLMStage: process(batch)
VLLMStage->>VLLMStage: Extract text from DataFrame
alt pretokenize enabled
VLLMStage->>Tokenizer: batch_encode_plus(texts)
Tokenizer-->>VLLMStage: Token IDs list
VLLMStage->>VLLMStage: Wrap in TokensPrompt objects
VLLMStage->>vLLM: embed(token_prompts)
else pretokenize disabled
VLLMStage->>vLLM: embed(raw_text_strings)
end
vLLM->>Device: Generate embeddings
Device-->>vLLM: Embedding vectors
vLLM-->>VLLMStage: EmbeddingOutput list
VLLMStage->>VLLMStage: Extract embeddings into DataFrame
VLLMStage->>VLLMStage: Log performance metrics
VLLMStage-->>User: DocumentBatch with embeddings
```
Greptile Overview
Greptile Summary
Adds vLLM and SentenceTransformer support for embedding generation, providing performance optimizations for different model sizes. The implementation introduces VLLMEmbeddingModelStage as a new single-stage embedder and adds a use_sentence_transformer flag to EmbeddingCreatorStage for choosing between the HuggingFace AutoModel and SentenceTransformer backends. According to the PR description, vLLM with pretokenization performs best for large models and high task volumes, while the existing SentenceTransformer implementation remains optimal for smaller models.
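For orientation, a hypothetical sketch of the backend switch; the import path, model name, and any constructor arguments other than use_sentence_transformer are assumptions based on the file names in this review, not confirmed API:

```python
from nemo_curator.stages.text.embedders.base import EmbeddingCreatorStage  # path assumed

# use_sentence_transformer=True routes to the SentenceTransformer backend;
# False keeps the existing HuggingFace AutoModel + pooling path.
stage = EmbeddingCreatorStage(
    model_identifier="sentence-transformers/all-MiniLM-L6-v2",  # assumed example model
    use_sentence_transformer=True,
)
```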
Confidence Score: 3/5
- Safe to merge with minor issues addressed—code is well-tested but has unclear parameter usage
- The PR introduces substantial new functionality with comprehensive test coverage, but has several minor issues that should be addressed: unclear vLLM parameter usage in the embed call that may cause unexpected behavior, missing error handling for model embedding which could lead to cryptic crashes, typos in comments and warning messages, and lack of validation for model length during pretokenization. The integration tests and reference embedding validation provide good coverage, but the vLLM-specific parameter usage needs clarification or correction.
- nemo_curator/stages/text/embedders/vllm.py needs parameter validation and error handling; nemo_curator/stages/text/embedders/base.py has a minor typo in warning message
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| nemo_curator/stages/text/embedders/base.py | 3/5 | Adds SentenceTransformerEmbeddingModelStage class and use_sentence_transformer flag to EmbeddingCreatorStage; warning message has typo |
| nemo_curator/stages/text/embedders/vllm.py | 2/5 | New VLLMEmbeddingModelStage for vLLM-based embeddings; missing error handling for model.embed() and unclear truncate_prompt_tokens parameter |
| pyproject.toml | 3/5 | Adds vLLM and sentence-transformers dependencies with version constraints; moves vLLM version constraint from video_cuda12 to separate vllm extra |
Sequence Diagram
```mermaid
sequenceDiagram
participant User
participant EmbeddingCreatorStage
participant TokenizerStage
participant ModelStage
participant VLLMStage
Note over User,VLLMStage: Traditional Two-Stage Approach
User->>EmbeddingCreatorStage: process(DocumentBatch)
EmbeddingCreatorStage->>TokenizerStage: Decompose and setup
EmbeddingCreatorStage->>ModelStage: Decompose and setup
activate TokenizerStage
TokenizerStage->>TokenizerStage: Tokenize text to input IDs
TokenizerStage-->>ModelStage: TokenizedBatch
deactivate TokenizerStage
activate ModelStage
alt SentenceTransformer Path
ModelStage->>ModelStage: SentenceTransformer encode
else HuggingFace Path
ModelStage->>ModelStage: AutoModel with pooling
end
ModelStage-->>User: DocumentBatch with embeddings
deactivate ModelStage
Note over User,VLLMStage: New vLLM Single-Stage Approach
User->>VLLMStage: process(DocumentBatch)
activate VLLMStage
alt Pretokenize Enabled
VLLMStage->>VLLMStage: Tokenize with AutoTokenizer
VLLMStage->>VLLMStage: vLLM embed with tokens
else Pretokenize Disabled
VLLMStage->>VLLMStage: vLLM embed with raw text
end
VLLMStage-->>User: DocumentBatch with embeddings
deactivate VLLMStage
```
```python
model_class = SentenceTransformerEmbeddingModelStage if self.use_sentence_transformer else EmbeddingModelStage

if self.use_sentence_transformer:
    logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
```
The warning message is missing punctuation around "ignoring", so it reads as a run-on rather than a complete sentence. Consider:
```diff
- logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
+ logger.warning("Using SentenceTransformer for embedding model; ignoring embedding_pooling parameter")
```
```python
t0 = time.perf_counter()
vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
metrics["vllm_embedding_time"] = time.perf_counter() - t0

df[self.embedding_field] = [e.outputs.embedding for e in vllm_output]
```
Missing error handling for the model.embed() call. If the model fails to generate embeddings (e.g., due to malformed inputs, resource exhaustion, or vLLM errors), the stage will crash without a helpful error message. Consider wrapping this in a try-except block with informative error messaging.
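A minimal sketch of such a guard, reusing only the names that appear in the excerpt above:

```python
t0 = time.perf_counter()
try:
    vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
except Exception as err:
    # Surface the batch size so malformed-input or resource-exhaustion failures are easier to trace.
    msg = f"vLLM embedding failed for a batch of {len(input_data)} inputs: {err}"
    raise RuntimeError(msg) from err
metrics["vllm_embedding_time"] = time.perf_counter() - t0
```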
```python
if "runner" not in vllm_init_kwargs:
    vllm_init_kwargs["runner"] = "pooling"
if "model_impl" not in vllm_init_kwargs:
    # TODO: Once transformers is bumpted to 5.0 then we should also support transformers
```
Typo in comment: "bumpted" should be "bumped":
```diff
- # TODO: Once transformers is bumpted to 5.0 then we should also support transformers
+ # TODO: Once transformers is bumped to 5.0 then we should also support transformers
```
```python
metrics["tokenization_time"] = time.perf_counter() - t0

t0 = time.perf_counter()
vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
```
The parameter value -1 for truncate_prompt_tokens may cause unexpected behavior. According to vLLM documentation, this should be either a positive integer or None for no truncation. Consider using None if truncation should be disabled, or verify that -1 is the intended value for this vLLM version.
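If -1 turns out not to be valid for the pinned vLLM version and the intent is simply to disable explicit truncation, the call from the excerpt would become:

```python
vllm_output = self.model.embed(input_data, truncate_prompt_tokens=None, use_tqdm=self.verbose)
```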
```python
super().__init__(
    model_identifier=model_identifier,
    hf_token=hf_token,
    model_inference_batch_size=model_inference_batch_size,
    has_seq_order=has_seq_order,
    padding_side=padding_side,
    autocast=autocast,
)
```
The parent class EmbeddingModelStage.__init__() is called without passing a pooling parameter, which means the parent's default pooling value will be used. While this doesn't cause issues since SentenceTransformerEmbeddingModelStage doesn't use the pooling attribute, consider adding a comment to clarify this is intentional.
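One possible form of that clarifying comment, shown as a sketch over the excerpt above:

```python
# Note: `pooling` is intentionally left at the parent default; SentenceTransformer
# applies its own pooling, so this stage never reads that attribute.
super().__init__(
    model_identifier=model_identifier,
    hf_token=hf_token,
    model_inference_batch_size=model_inference_batch_size,
    has_seq_order=has_seq_order,
    padding_side=padding_side,
    autocast=autocast,
)
```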
```python
max_model_len = self.model.model_config.max_model_len
tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)
```
Verify that max_model_len is set to a reasonable value before using it for truncation. If the model config has an unusually large or unset max_model_len, the tokenization could produce unexpected results.
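A minimal sketch of the kind of guard the comment is asking for, reusing the variable names from the excerpt above:

```python
max_model_len = self.model.model_config.max_model_len
# Guard against an unset or non-positive value before using it for truncation.
if not max_model_len or max_model_len <= 0:
    msg = f"Unexpected max_model_len ({max_model_len}); cannot truncate prompts safely."
    raise ValueError(msg)
tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)
```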
Description
Tried out a few configurations:
- sentence_transformer: current implementation of TokenizerStage + ModelStage
- vllm_text: vLLM with text as input - single stage of vLLM only
- vllm_text_4cpus: same as above, except we dedicate 4 CPUs to the stage
- vllm_tokens: vLLM with tokens as input, tokenization done by TokenizerStage
- vllm_text_pretokenized: vLLM with text as input, tokenization done within the stage

The experiment ran on 5 GB of Common Crawl data, and the findings were:
- … vllm_text_pretokenized. This suggests that, given a large number of tasks and an amortized startup cost, we should see vllm_text_pretokenized come out ahead.
- vllm_text is slowest, likely due to tokenization happening sentence by sentence, leading to higher GPU idle time and not justifying the small model runtime on GPU. Increasing CPU allocation for the stage doesn't improve runtimes.
- vllm_text_pretokenized; …

Usage
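Since the usage snippet below is still a placeholder, here is a hypothetical sketch of how the new stage might be used; the import path, example model, and constructor arguments are assumptions inferred from the class, module, and parameter names in this PR, not a confirmed API:

```python
from nemo_curator.stages.text.embedders.vllm import VLLMEmbeddingModelStage  # path assumed

# Hypothetical: single-stage vLLM embedder with in-stage tokenization
# (the "vllm_text_pretokenized" configuration from the benchmarks above).
embedder = VLLMEmbeddingModelStage(
    model_identifier="intfloat/e5-large-v2",  # assumed example model
    pretokenize=True,   # tokenize with AutoTokenizer before calling vLLM embed()
    verbose=False,
)

# Expected to plug into a pipeline like other stages, roughly:
# embedder.setup_on_node(); embedder.setup()
# batch_with_embeddings = embedder.process(document_batch)
```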
```python
# Add snippet demonstrating usage
```

Checklist