
Conversation

@praateekmahajan (Contributor) commented Dec 31, 2025

Description

Tried out a few models:

  1. sentence_transformer: the current implementation (TokenizerStage + ModelStage)
  2. vllm_text: vLLM with text as input (a single vLLM stage only)
    1. vllm_text_4cpus: same as above, except 4 CPUs are dedicated to the stage
  3. vllm_tokens: vLLM with tokens as input, tokenization done by a TokenizerStage
  4. vllm_text_pretokenized: vLLM with text as input, tokenization done within the stage

The experiment ran on 5 GB of Common Crawl data, and the findings were:

  1. Results depend on model size
    1. Larger model (EmbeddingGemma): vLLM wins
    2. Smaller model (MiniLM): the current SentenceTransformer implementation wins
  2. Measuring average time per task/partition (tokenization plus modeling combined), the fastest variant irrespective of model size is vllm_text_pretokenized. This suggests that with a large number of tasks, once startup cost is amortized, vllm_text_pretokenized should come out ahead.
  3. For the large model
    1. Among the vLLM variations, total time taken is similar, but that includes startup time, so we should measure average time spent per task instead.
  4. For the small model
    1. Among the vLLM variations, vllm_text is the slowest, likely because tokenization happens sentence by sentence, leading to higher GPU idle time that the small model's GPU runtime cannot justify. Increasing the CPU allocation for the stage does not improve runtimes.
  5. vllm_tokens is also fast for both small and large models, but compared to vllm_text_pretokenized it adds the following costs for negligible benefit:
    1. the maintenance overhead of a separate tokenization stage
    2. the serialization overhead of transferring data from one actor to another
(two benchmark result images attached)

Usage

# Add snippet demonstrating usage
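
A minimal sketch of how the new stages might be wired, based only on the class, flag, and file names that appear in this PR (VLLMEmbeddingModelStage, pretokenize, EmbeddingCreatorStage, use_sentence_transformer); the import paths, constructor signatures, and model ids below are assumptions, not verbatim from the diff.

# Sketch only: import paths and constructor signatures are assumptions
# inferred from the file and class names in this PR, not the actual API.
from nemo_curator.stages.text.embedders.base import EmbeddingCreatorStage
from nemo_curator.stages.text.embedders.vllm import VLLMEmbeddingModelStage

# Single-stage vLLM embedder; pretokenize=True tokenizes inside the stage before
# calling vLLM (the "vllm_text_pretokenized" variant from the benchmark above).
vllm_stage = VLLMEmbeddingModelStage(
    model_identifier="google/embeddinggemma-300m",  # hypothetical model id
    pretokenize=True,
)

# Existing two-stage path, switched to the SentenceTransformer backend.
st_stage = EmbeddingCreatorStage(
    model_identifier="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical model id
    use_sentence_transformer=True,
)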

Checklist

  • I am familiar with the Contributing Guide.
  • New or Existing tests cover these changes.
  • The documentation is up to date with these changes.

Signed-off-by: Praateek <[email protected]>
greptile-apps bot (Contributor) commented Dec 31, 2025

Greptile Summary

This PR adds support for vLLM-based embedding generation alongside the existing Sentence Transformers implementation, providing performance improvements for larger models. The implementation introduces VLLMEmbeddingModelStage with optional pretokenization support and integrates cleanly with the existing embedding pipeline.

Key Changes:

  • New VLLMEmbeddingModelStage in vllm.py with optional pretokenization for optimized throughput
  • Added SentenceTransformerEmbeddingModelStage to base.py for proper SentenceTransformer integration
  • EmbeddingCreatorStage now supports both implementations via use_sentence_transformer flag
  • Dependencies updated: vLLM >=0.13, sentence-transformers, a torch <=2.9.1 pin, and a scikit-learn <1.8.0 constraint
  • Comprehensive test coverage for both vLLM and SentenceTransformer paths with reference validation
  • Ray 2.53.0+ compatibility fix in integration tests

Benchmarking Results (from PR description):

  • For large models (e.g., embedding gemma): vLLM offers better performance
  • For small models (e.g., miniLM): Current SentenceTransformer implementation is faster
  • vllm_text_pretokenized shows fastest average time per task when startup costs are amortized

Minor Issue:

  • Typo in comment: 'bumpted' → 'bumped' at vllm.py:78

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The code is well-structured with comprehensive test coverage, proper error handling, and clean integration with existing systems. The only issue is a minor typo in a comment. The implementation follows established patterns in the codebase, includes proper type hints, and has been benchmarked with real data.
  • No files require special attention beyond fixing the typo in vllm.py:78

Important Files Changed

Filename | Overview
nemo_curator/stages/text/embedders/vllm.py | New vLLM embedding stage with pretokenization support and proper metrics logging
nemo_curator/stages/text/embedders/base.py | Added SentenceTransformerEmbeddingModelStage and use_sentence_transformer flag to EmbeddingCreatorStage
pyproject.toml | Added vLLM and sentence-transformers dependencies with proper version constraints

Sequence Diagram

sequenceDiagram
    participant User
    participant VLLMStage as VLLMEmbeddingModelStage
    participant HFHub as HuggingFace Hub
    participant vLLM as vLLM Engine
    participant Tokenizer
    participant Device as GPU Device

    User->>VLLMStage: Initialize stage
    User->>VLLMStage: setup_on_node()
    VLLMStage->>HFHub: snapshot_download(model)
    HFHub-->>VLLMStage: Model cached locally
    VLLMStage->>vLLM: Initialize LLM with pooling runner
    vLLM->>Device: Load model weights
    Device-->>vLLM: Model ready
    vLLM-->>VLLMStage: LLM instance created
    
    User->>VLLMStage: setup()
    alt pretokenize enabled
        VLLMStage->>Tokenizer: Load AutoTokenizer from model
        Tokenizer-->>VLLMStage: Tokenizer initialized
    end
    
    User->>VLLMStage: process(batch)
    VLLMStage->>VLLMStage: Extract text from DataFrame
    
    alt pretokenize enabled
        VLLMStage->>Tokenizer: batch_encode_plus(texts)
        Tokenizer-->>VLLMStage: Token IDs list
        VLLMStage->>VLLMStage: Wrap in TokensPrompt objects
        VLLMStage->>vLLM: embed(token_prompts)
    else pretokenize disabled
        VLLMStage->>vLLM: embed(raw_text_strings)
    end
    
    vLLM->>Device: Generate embeddings
    Device-->>vLLM: Embedding vectors
    vLLM-->>VLLMStage: EmbeddingOutput list
    VLLMStage->>VLLMStage: Extract embeddings into DataFrame
    VLLMStage->>VLLMStage: Log performance metrics
    VLLMStage-->>User: DocumentBatch with embeddings

greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

Adds vLLM and SentenceTransformer support for embedding generation, providing performance optimizations for different model sizes. The implementation introduces VLLMEmbeddingModelStage as a new single-stage embedder and adds use_sentence_transformer flag to EmbeddingCreatorStage for choosing between HuggingFace AutoModel and SentenceTransformer backends. According to the PR description, vLLM with pretokenization performs best for large models and high task volumes, while the existing SentenceTransformer implementation remains optimal for smaller models.

Confidence Score: 3/5

  • Safe to merge with minor issues addressed—code is well-tested but has unclear parameter usage
  • The PR introduces substantial new functionality with comprehensive test coverage, but has several minor issues that should be addressed: unclear vLLM parameter usage in the embed call that may cause unexpected behavior, missing error handling for model embedding which could lead to cryptic crashes, typos in comments and warning messages, and lack of validation for model length during pretokenization. The integration tests and reference embedding validation provide good coverage, but the vLLM-specific parameter usage needs clarification or correction.
  • nemo_curator/stages/text/embedders/vllm.py needs parameter validation and error handling; nemo_curator/stages/text/embedders/base.py has a minor typo in warning message

Important Files Changed

File Analysis

Filename | Score | Overview
nemo_curator/stages/text/embedders/base.py | 3/5 | Adds SentenceTransformerEmbeddingModelStage class and use_sentence_transformer flag to EmbeddingCreatorStage; warning message has typo
nemo_curator/stages/text/embedders/vllm.py | 2/5 | New VLLMEmbeddingModelStage for vLLM-based embeddings; missing error handling for model.embed() and unclear truncate_prompt_tokens parameter
pyproject.toml | 3/5 | Adds vLLM and sentence-transformers dependencies with version constraints; moves vLLM version constraint from video_cuda12 to separate vllm extra

Sequence Diagram

sequenceDiagram
    participant User
    participant EmbeddingCreatorStage
    participant TokenizerStage
    participant ModelStage
    participant VLLMStage

    Note over User,VLLMStage: Traditional Two-Stage Approach
    User->>EmbeddingCreatorStage: process(DocumentBatch)
    EmbeddingCreatorStage->>TokenizerStage: Decompose and setup
    EmbeddingCreatorStage->>ModelStage: Decompose and setup
    activate TokenizerStage
    TokenizerStage->>TokenizerStage: Tokenize text to input IDs
    TokenizerStage-->>ModelStage: TokenizedBatch
    deactivate TokenizerStage
    activate ModelStage
    alt SentenceTransformer Path
        ModelStage->>ModelStage: SentenceTransformer encode
    else HuggingFace Path
        ModelStage->>ModelStage: AutoModel with pooling
    end
    ModelStage-->>User: DocumentBatch with embeddings
    deactivate ModelStage

    Note over User,VLLMStage: New vLLM Single-Stage Approach
    User->>VLLMStage: process(DocumentBatch)
    activate VLLMStage
    alt Pretokenize Enabled
        VLLMStage->>VLLMStage: Tokenize with AutoTokenizer
        VLLMStage->>VLLMStage: vLLM embed with tokens
    else Pretokenize Disabled
        VLLMStage->>VLLMStage: vLLM embed with raw text
    end
    VLLMStage-->>User: DocumentBatch with embeddings
    deactivate VLLMStage

model_class = SentenceTransformerEmbeddingModelStage if self.use_sentence_transformer else EmbeddingModelStage

if self.use_sentence_transformer:
    logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")

The warning message reads as a run-on; consider making it a complete sentence:

Suggested change
logger.warning("Using SentenceTransformer for embedding model ignoring embedding_pooling")
logger.warning("Using SentenceTransformer for embedding model; ignoring embedding_pooling parameter")

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +124 to +128
t0 = time.perf_counter()
vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
metrics["vllm_embedding_time"] = time.perf_counter() - t0

df[self.embedding_field] = [e.outputs.embedding for e in vllm_output]

Missing error handling for the model.embed() call. If the model fails to generate embeddings (e.g., due to malformed inputs, resource exhaustion, or vLLM errors), the stage will crash without a helpful error message. Consider wrapping this in a try-except block with informative error messaging.
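
A minimal sketch of the suggested wrapping (the exception type and message are illustrative, not from the diff):

t0 = time.perf_counter()
try:
    vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)
except Exception as e:
    # Surface the batch size and underlying error instead of letting vLLM's exception propagate bare.
    msg = f"vLLM embedding failed for a batch of {len(input_data)} inputs: {e}"
    raise RuntimeError(msg) from e
metrics["vllm_embedding_time"] = time.perf_counter() - t0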

if "runner" not in vllm_init_kwargs:
vllm_init_kwargs["runner"] = "pooling"
if "model_impl" not in vllm_init_kwargs:
# TODO: Once transformers is bumpted to 5.0 then we should also support transformers

Typo in comment: "bumpted" should be "bumped":

Suggested change
# TODO: Once transformers is bumpted to 5.0 then we should also support transformers
# TODO: Once transformers is bumped to 5.0 then we should also support transformers

metrics["tokenization_time"] = time.perf_counter() - t0

t0 = time.perf_counter()
vllm_output = self.model.embed(input_data, truncate_prompt_tokens=-1, use_tqdm=self.verbose)

The parameter value -1 for truncate_prompt_tokens may cause unexpected behavior. According to vLLM documentation, this should be either a positive integer or None for no truncation. Consider using None if truncation should be disabled, or verify that -1 is the intended value for this vLLM version.
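
A one-line sketch of the alternative described here, assuming truncation is meant to be disabled (verify against the pinned vLLM version before adopting):

# Sketch of the reviewer's suggestion: pass None rather than -1 to disable truncation.
vllm_output = self.model.embed(input_data, truncate_prompt_tokens=None, use_tqdm=self.verbose)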

Comment on lines +114 to +121
super().__init__(
    model_identifier=model_identifier,
    hf_token=hf_token,
    model_inference_batch_size=model_inference_batch_size,
    has_seq_order=has_seq_order,
    padding_side=padding_side,
    autocast=autocast,
)

The parent class EmbeddingModelStage.__init__() is called without passing a pooling parameter, which means the parent's default pooling value will be used. While this doesn't cause issues since SentenceTransformerEmbeddingModelStage doesn't use the pooling attribute, consider adding a comment to clarify this is intentional.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +119 to +120
max_model_len = self.model.model_config.max_model_len
tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)

Verify that max_model_len is set to a reasonable value before using it for truncation. If the model config has an unusually large or unset max_model_len, the tokenization could produce unexpected results.
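
A minimal sketch of such a guard (the fallback value and the validity check are illustrative assumptions, not part of the diff):

max_model_len = self.model.model_config.max_model_len
if not max_model_len or max_model_len <= 0:
    # Fall back to the tokenizer's own limit if the model config does not expose a usable value.
    max_model_len = getattr(self.tokenizer, "model_max_length", 512)
tokenized_data = self.tokenizer.batch_encode_plus(input_data, truncation=True, max_length=max_model_len)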
