
Conversation

@tonywngzon

Implement custom NXDI modeling for Qwen3 embedding models

Details

  • optimum==1.24.0
  • optimum-neuron==0.4.0
  • libneuronxla==2.2.12677.0+470fa032
  • neuronx-cc==2.21.18209.0+043b1bf7
  • neuronx-distributed==0.15.22404+1f27bddf
  • neuronx-distributed-inference==0.6.10598+a59fdc00
  • torch-neuronx==2.8.0.2.10.13553+1e4dd6ca
  • sentence-transformers==5.1.1
  • transformers==4.55.4
  • machine: trn2

When targeting neuron for compilation, any model repository containing a config_sentence_transformers.json file is routed to the sentence_transformers library and traced with no support for NXDI or tensor parallelism (TP). This happens because of the exporter logic in exporters/tasks.py:1994:

elif ( 
    any(file_path.startswith("sentence_") for file_path in all_files) 
    or "config_sentence_transformers.json" in all_files 
): 
    inferred_library_name = "sentence_transformers"

Original Behavior

Running a Qwen3-Embedding model with the default path:

optimum-cli export neuron --model Qwen/Qwen3-Embedding-0.6B --batch_size 1 --sequence_length 1024 --auto_cast matmul --instance_type trn2 --tensor_parallel_size 4 qwen3-embedding-0.6b-neuron/

Results in the model being traced via sentence_transformers, without NXDI or tensor parallelism support.

Solution

I've implemented the Qwen3 embedding model on top of the decoder modeling code, which yields up to a 4x throughput improvement for the same model and sequence length with tensor parallelism enabled: https://github.com/tonywngzon/optimum-neuron

Changes Made

  1. Added optional embedding_model parameter to NxDNeuronConfig (optimum-neuron/optimum/neuron/models/inference/backend/config.py):

    • Set to True in the Qwen3 embedding model
    • Skips creating the token generation model when embedding_model=True:
    token_generation_model = (
        NxDModelForCausalLM.create_token_generation_wrapper(
            model_cls,
            config,
            neuron_config,
            **model_init_kwargs,
        )
        if not neuron_config.embedding_model
        else None
    )
  2. Changed export order (optimum-neuron/optimum/exporters/neuron/__main__.py):

    • Now tries to export a neuron model before attempting a sentence_transformers model
    • Moves the neuron model export attempt before the library-specific exports
  3. Registered the embedding model (optimum-neuron/optimum/neuron/models/inference/auto_models.py:109):

    @register_neuron_model_for_inference("qwen3", "feature-extraction")
    class Qwen3NeuronModelForCausalLMEmbedding(Qwen3NxDModelForCausalLMEmbedding):
        ...  # class body elided
  4. Created a new modeling_qwen3_embedding.py file that overrides the original decoder forward function to support embeddings (a condensed sketch follows this list):

    • Extends NxDDecoderModel with a custom NxDQwen3EmbeddingModel that overrides the forward pass to return hidden states instead of logits
    • Implements Qwen3NxDModelForCausalLMEmbedding that wraps the model with embedding-specific methods
    • Key modifications:
      • forward() method returns hidden states directly
      • Disables continuous_batching and on_device_sampling for embeddings
      • Adds encode() method for getting embeddings with proper position_ids handling
      • Maintains the same computation graph as the decoder model for compilation compatibility
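
To make step 4 concrete, here is a condensed, self-contained sketch of the forward override and the pooling an encode() method would perform. ToyDecoderModel stands in for the real NxDDecoderModel, and the last-token pooling with L2 normalization is an assumption based on how Qwen3 embedding models are typically pooled, not the PR's exact code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDecoderModel(nn.Module):
    """Stand-in for NxDDecoderModel: token embedding, decoder layers, lm_head."""

    def __init__(self, vocab_size: int = 1000, hidden_size: int = 64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size)])
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        h = self.embed_tokens(input_ids)
        for layer in self.layers:
            h = layer(h)
        return self.lm_head(h)  # decoder path: logits for next-token prediction


class ToyEmbeddingModel(ToyDecoderModel):
    """Mirrors the override: same graph up to the head, but forward returns
    hidden states instead of logits, keeping compilation compatibility."""

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        h = self.embed_tokens(input_ids)
        for layer in self.layers:
            h = layer(h)
        return h  # [batch, seq_len, hidden]; no lm_head projection


def encode(model: ToyEmbeddingModel, input_ids: torch.Tensor,
           attention_mask: torch.Tensor) -> torch.Tensor:
    """Pools per-token hidden states into one vector per sequence using
    last-token pooling + L2 normalization (assumed, not the PR's exact code)."""
    h = model(input_ids)
    last = attention_mask.sum(dim=1) - 1       # index of the final real token
    pooled = h[torch.arange(h.size(0)), last]  # [batch, hidden]
    return F.normalize(pooled, p=2, dim=-1)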

Result

With these changes:

  • The task is autodetected as feature-extraction due to the presence of the sentence_transformers library (optimum/exporters/tasks.py:1834)
  • The model is exported as a Qwen3NxDModelForCausalLMEmbedding with full NXDI modeling support
  • No need to manually remove modules.json or config_sentence_transformers.json - the exporter now correctly handles these models
  • The model can now be exported with the same command as before, with full NXDI modeling and TP support:
optimum-cli export neuron --model Qwen/Qwen3-Embedding-0.6B --batch_size 1 --sequence_length 1024 --auto_cast matmul --instance_type trn2 --tensor_parallel_size 4 qwen3-embedding-0.6b-neuron/

This solves the issue where models with config_sentence_transformers.json were forced down the sentence_transformers path, preventing the use of optimum-neuron's more performant NXDI implementation with tensor parallelism support.
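
Once exported, the embeddings are consumed in the usual way. Continuing the toy sketch from the Changes Made section (dummy inputs stand in for tokenized text; a real run would load the compiled Neuron artifact instead):

model = ToyEmbeddingModel()
input_ids = torch.randint(0, 1000, (2, 8))
attention_mask = torch.ones(2, 8, dtype=torch.long)
embeddings = encode(model, input_ids, attention_mask)  # [2, 64], unit-norm
similarity = embeddings @ embeddings.T                 # [2, 2] cosine matrix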

Testing

Tested locally by:

  1. Simply running the export command - the task and model are autodetected correctly
  2. Verifying tensor parallelism works correctly
  3. Benchmarking shows a significant throughput improvement (a minimal timing sketch follows)
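
For reference, a minimal throughput-measurement sketch of the kind point 3 implies; this is an illustrative harness, not the benchmark actually used here:

import time

def sequences_per_second(encode_fn, input_ids, attention_mask, iters: int = 50) -> float:
    """Times repeated encode calls and returns sequences processed per second."""
    encode_fn(input_ids, attention_mask)  # warm-up (lazy initialization, caches)
    start = time.perf_counter()
    for _ in range(iters):
        encode_fn(input_ids, attention_mask)
    return iters * input_ids.shape[0] / (time.perf_counter() - start)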


@dacorvo
Collaborator

dacorvo commented Oct 22, 2025

Thanks for this pull request. This goes in the right direction; however, I would like to avoid modifying the classes under backend/modules/decoder and instead write new classes under backend/modules/encoder. The amount of code to copy should be reduced after the latest refactoring on main.

