Add support for Qwen3 Embedding #1957

@BBC-Esq

Description

Now that the PR for Qwen3 has been merged, it'd be awesome to support the Qwen3 embedding models. I'm going to simply paste the research (AI-assisted) that I've done so far to get the conversation going, but these models are the best open source ones out there with reasonable compute requirements and/or speed:

https://huggingface.co/collections/Qwen/qwen3-embedding

MY RESEARCH WITH CLAUDE 4.5 OPUS

Now I have a complete picture. Looking at the code, I can see:

  1. CTranslate2 already has EncoderReplica (in language_model.cc) that returns last_hidden_state + optional pooler_output — this is exactly what embedding models need.

  2. But it's designed for encoder models (like BERT), not decoder models.

  3. The Qwen3 Embedding model uses Qwen3ForCausalLM architecture but functions as an embedding model by extracting hidden states instead of generating logits.
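For context, here is roughly how the hidden-state extraction works on the Hugging Face side today. This is a minimal sketch based on the usage shown on the Qwen3-Embedding model card; the real example formats instruction prompts and handles padding more carefully, so treat the details as illustrative:

# Sketch: how Qwen3-Embedding produces embeddings from a causal LM in transformers.
# This is what CTranslate2 would need to replicate: run the decoder, keep the
# last hidden state, and pool the position of the final (non-padding) token.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B", padding_side="left")
model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")

batch = tokenizer(["What is the capital of China?"], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(**batch)

# With left padding, the last position is always the last real token.
embeddings = outputs.last_hidden_state[:, -1]
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)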

The Core Problem

Looking at language_model.cc:

  • DecoderReplica::forward() returns logits (goes through lm_head)
  • EncoderReplica::forward_impl() returns last_hidden_state + pooler_output

For Qwen3 Embedding, you need a decoder architecture that returns hidden states (like an encoder does).
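Concretely, the two Python entry points that exist today expose different outputs. The sketch below uses placeholder model paths and tokens, but the forward_batch() methods and their return values are the existing API:

import ctranslate2

# Decoder models are loaded through Generator: forward_batch() returns logits
# (the output of lm_head), not hidden states.
generator = ctranslate2.Generator("qwen3-ct2/")            # placeholder path
logits = generator.forward_batch([["<s>", "Hello"]])       # placeholder tokens

# Encoder models (e.g. BERT) are loaded through Encoder: forward_batch() returns
# last_hidden_state and pooler_output, which is what an embedding model needs.
encoder = ctranslate2.Encoder("bert-ct2/")                 # placeholder path
output = encoder.forward_batch([["[CLS]", "hello", "[SEP]"]])
hidden = output.last_hidden_state
pooled = output.pooler_output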


Implementation Options

Option 1: Create a New DecoderEmbeddingReplica Class (Cleanest)

This requires C++ changes but is the proper solution.

Option 2: Hack via Encoder Spec (Might work, less code)

Since TransformerEncoderModelSpec already returns hidden states, you could potentially:

  1. Treat the Qwen3 decoder layers as "encoder" layers for CTranslate2's purposes
  2. Map the weights appropriately

But this is fragile because decoder attention has different masking behavior.
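To make the masking concern concrete, here is a plain PyTorch illustration (not CTranslate2 code) of the two attention patterns. Qwen3-Embedding was trained with the causal one, so running its layers under an encoder's bidirectional mask would silently change the hidden states:

import torch

seq_len = 4

# Encoder-style mask: every position attends to every position.
bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)

# Decoder-style causal mask: position i attends only to positions <= i.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(bidirectional.int())
print(causal.int())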

Option 3: Modify the Converter to Skip lm_head and Return Hidden States

This still requires C++ changes to DecoderReplica.


Recommended Implementation: Option 1

Here's what you need to create:

Step 1: Python Spec (New File or Add to transformer_spec.py)

class TransformerDecoderEmbeddingModelConfig(model_spec.LanguageModelConfig):
    """Configuration for Transformer decoder embedding models."""

    def __init__(
        self,
        layer_norm_epsilon: Optional[float] = None,
        pooling_strategy: str = "last_token",
        **kwargs
    ):
        super().__init__(layer_norm_epsilon=layer_norm_epsilon, **kwargs)
        self.pooling_strategy = pooling_strategy


class TransformerDecoderEmbeddingModelSpec(model_spec.LanguageModelSpec):
    """Describes a Transformer decoder model used for embeddings (e.g. Qwen3-Embedding)."""

    def __init__(
        self,
        decoder: TransformerDecoderSpec,
        pooling_strategy: str = "last_token",
    ):
        if not isinstance(decoder, TransformerDecoderSpec):
            raise TypeError("decoder argument must be a TransformerDecoderSpec")

        super().__init__()
        self.decoder = decoder

        # Remove the projection layer (lm_head) - we don't need it for embeddings.
        # Keep everything else from the decoder.
        delattr(self.decoder, "projection")

        self._pooling_strategy = pooling_strategy
        for key, value in self.decoder.config.items():
            self._config.add_attribute(key, value)
        self._config.add_attribute("pooling_strategy", pooling_strategy)

    @classmethod
    def from_config(cls, pooling_strategy: str = "last_token", **kwargs):
        """Creates a Transformer decoder embedding model specification.

        Args:
            pooling_strategy: How to pool hidden states ("last_token", "mean", "first_token").
            **kwargs: Same arguments as TransformerDecoderModelSpec.from_config().
        """
        # Create a decoder-only spec (the projection is removed in __init__).
        decoder = TransformerDecoderSpec(
            with_encoder_attention=False,
            **kwargs,
        )

        return cls(decoder, pooling_strategy=pooling_strategy)

    @property
    def name(self):
        return "TransformerDecoderEmbeddingSpec"

    @property
    def revision(self):
        return 1

    def get_default_config(self):
        return TransformerDecoderEmbeddingModelConfig(
            pooling_strategy=self._pooling_strategy
        )

    def get_vocabulary_size(self):
        return self.decoder.embeddings.weight.shape[0]
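If the spec above were added, it could be exercised roughly like this. This is hypothetical since the class does not exist yet, and the argument values are illustrative rather than Qwen3-Embedding's real configuration:

from ctranslate2.specs import transformer_spec

# Hypothetical: use the proposed spec the same way the existing
# TransformerDecoderModelSpec.from_config() is used, minus the lm_head.
spec = transformer_spec.TransformerDecoderEmbeddingModelSpec.from_config(
    num_layers=28,
    num_heads=16,
    rms_norm=True,
    ffn_glu=True,
    num_heads_kv=8,
    pooling_strategy="last_token",
)
assert spec.name == "TransformerDecoderEmbeddingSpec"
assert not hasattr(spec.decoder, "projection")  # lm_head was removed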

Step 2: Python Converter (Modify your Qwen3Loader)

@register_loader("Qwen3EmbeddingConfig")  # Or detect from model metadata
class Qwen3EmbeddingLoader(Qwen3Loader):
    @property
    def architecture_name(self):
        return "Qwen3ForCausalLM"  # Same architecture, different usage

    def get_model_spec(self, model):
        num_layers = model.config.num_hidden_layers
        num_heads = model.config.num_attention_heads
        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
        head_dim = getattr(
            model.config, "head_dim", model.config.hidden_size // num_heads
        )

        if num_heads_kv == num_heads:
            num_heads_kv = None

        # Handle RoPE scaling the same way as the parent loader.
        rope_scaling = getattr(model.config, "rope_scaling", None)
        if rope_scaling:
            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
            rotary_scaling_factor = rope_scaling["factor"]
            if rotary_scaling_type is None:
                raise NotImplementedError(
                    "RoPE scaling type '%s' is not yet implemented." % rope_type
                )
        else:
            rotary_scaling_type = None
            rotary_scaling_factor = 1

        # Use the NEW embedding spec instead of the decoder model spec.
        spec = transformer_spec.TransformerDecoderEmbeddingModelSpec.from_config(
            num_layers=num_layers,
            num_heads=num_heads,
            activation=common_spec.Activation.SWISH,
            pre_norm=True,
            ffn_glu=True,
            rms_norm=True,
            rotary_dim=head_dim,
            rotary_interleave=False,
            rotary_scaling_type=rotary_scaling_type,
            rotary_scaling_factor=rotary_scaling_factor,
            rotary_base=getattr(model.config, "rope_theta", 10000),
            num_heads_kv=num_heads_kv,
            head_dim=head_dim,
            qk_norm=True,
            pooling_strategy="last_token",  # Qwen3 Embedding uses last-token pooling
        )

        self.set_decoder(spec.decoder, model.model)
        # No lm_head projection - that's the key difference!
        return spec

    def set_decoder(self, spec, module):
        """Same as the parent loader but without setting the projection."""
        spec.scale_embeddings = False
        self.set_embeddings(spec.embeddings, module.embed_tokens)
        self.set_layer_norm(spec.layer_norm, module.norm)

        for layer_spec, layer in zip(spec.layer, module.layers):
            self.set_layer_norm(
                layer_spec.self_attention.layer_norm, layer.input_layernorm
            )
            self.set_layer_norm(
                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
            )

            self.set_layer_norm(
                layer_spec.self_attention.q_norm, layer.self_attn.q_norm
            )
            self.set_layer_norm(
                layer_spec.self_attention.k_norm, layer.self_attn.k_norm
            )

            split_layers = [common_spec.LinearSpec() for _ in range(3)]
            self.set_linear(split_layers[0], layer.self_attn.q_proj)
            self.set_linear(split_layers[1], layer.self_attn.k_proj)
            self.set_linear(split_layers[2], layer.self_attn.v_proj)
            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)

            self.set_linear(
                layer_spec.self_attention.linear[1],
                layer.self_attn.o_proj,
            )

            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

            delattr(layer, "self_attn")
            delattr(layer, "mlp")
Step 3: C++ Changes Required

This is where the significant work is. You need to modify language_model.cc and language_model.h:

In include/ctranslate2/models/language_model.h:

// Add a new output struct
struct DecoderEmbeddingOutput {
  StorageView last_hidden_state;
  StorageView pooled_output;  // After the pooling strategy is applied
};

// Add a new replica class
class DecoderEmbeddingReplica : public SequenceEncoderReplica {
public:
  DecoderEmbeddingReplica(const std::shared_ptr<const LanguageModel>& model,
                          std::unique_ptr<layers::Decoder> decoder);

  EncoderForwardOutput forward_impl(const StorageView& ids,
                                    const StorageView& lengths,
                                    const StorageView& token_type_ids) override;

private:
  std::shared_ptr<const LanguageModel> _model;
  std::unique_ptr<layers::Decoder> _decoder;
  std::string _pooling_strategy;

  StorageView pool_hidden_states(const StorageView& hidden_states,
                                 const StorageView& lengths);
};

In src/models/language_model.cc:

DecoderEmbeddingReplica::DecoderEmbeddingReplica(
    const std::shared_ptr<const LanguageModel>& model,
    std::unique_ptr<layers::Decoder> decoder)
  : SequenceEncoderReplica(model)
  , _model(model)
  , _decoder(std::move(decoder))
  , _pooling_strategy(model->config.value("pooling_strategy", "last_token"))
{
}

EncoderForwardOutput
DecoderEmbeddingReplica::forward_impl(const StorageView& ids,
                                      const StorageView& lengths,
                                      const StorageView& token_type_ids) {
  if (ids.rank() != 2)
    throw std::invalid_argument("Expected input ids to have 2 dimensions");

  auto& decoder = *_decoder;
  const Device device = _model->device();
  const DataType dtype = decoder.output_type();

  // Run the decoder WITHOUT the output projection (lm_head).
  auto state = decoder.initial_state(/*iterative_decoding=*/false);

  StorageView hidden_states(dtype, device);
  // The decoder needs to be extended to support returning hidden states.
  decoder.forward_without_projection(ids, lengths, state, hidden_states);

  EncoderForwardOutput output;
  output.last_hidden_state = std::move(hidden_states);

  // Apply the pooling strategy.
  output.pooler_output = pool_hidden_states(output.last_hidden_state, lengths);

  return output;
}

StorageView
DecoderEmbeddingReplica::pool_hidden_states(const StorageView& hidden_states,
                                            const StorageView& lengths) {
  const Device device = hidden_states.device();
  const DataType dtype = hidden_states.dtype();
  const dim_t batch_size = hidden_states.dim(0);
  const dim_t hidden_dim = hidden_states.dim(2);

  StorageView pooled({batch_size, hidden_dim}, dtype, device);

  if (_pooling_strategy == "last_token") {
    // Extract the hidden state at the last token position of each batch entry.
    for (dim_t i = 0; i < batch_size; ++i) {
      dim_t last_pos = lengths.at<int32_t>(i) - 1;
      // Copy hidden_states[i, last_pos, :] to pooled[i, :]
      ops::Gather(/*axis=*/1, /*batch_dims=*/1)(
          hidden_states.index({i}),
          StorageView({1}, last_pos, device),
          pooled.index({i}));
    }
  } else if (_pooling_strategy == "mean") {
    // Mean pool over the sequence length (respecting lengths).
    ops::Mean(/*axis=*/1)(hidden_states, lengths, pooled);
  } else if (_pooling_strategy == "first_token") {
    StorageView first_index({batch_size}, int32_t(0), device);
    ops::Gather(/*axis=*/1, /*batch_dims=*/1)(hidden_states, first_index, pooled);
  }

  return pooled;
}
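For reference, the pooling the C++ above is aiming for is equivalent to this small PyTorch sketch (illustration only, not part of the proposed implementation):

import torch

def pool_hidden_states(hidden_states: torch.Tensor,
                       lengths: torch.Tensor,
                       strategy: str = "last_token") -> torch.Tensor:
    """hidden_states: [batch, time, hidden]; lengths: [batch] true sequence lengths."""
    batch_size = hidden_states.shape[0]
    if strategy == "last_token":
        # Take the hidden state of the last real token in each sequence.
        last_positions = lengths - 1
        return hidden_states[torch.arange(batch_size), last_positions]
    if strategy == "mean":
        # Average over valid positions only.
        mask = torch.arange(hidden_states.shape[1])[None, :] < lengths[:, None]
        summed = (hidden_states * mask.unsqueeze(-1)).sum(dim=1)
        return summed / lengths.unsqueeze(-1)
    if strategy == "first_token":
        return hidden_states[:, 0]
    raise ValueError(f"Unknown pooling strategy: {strategy}")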

Step 4: Model Factory Registration

In src/models/model_factory.cc, add:

case "TransformerDecoderEmbeddingSpec":
  return std::make_unique<DecoderEmbeddingModel>(...);

Step 5: Python Bindings

In python/cpp/module.cc, or in a new file, expose a DecoderEmbedder class similarly to how Encoder is exposed.
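For reference, the end goal on the Python side would be an API roughly like the following. This is entirely hypothetical: the class name DecoderEmbedder and its outputs are placeholders modeled on the existing ctranslate2.Encoder:

import ctranslate2

# Hypothetical API: none of this exists yet. DecoderEmbedder would mirror
# ctranslate2.Encoder but wrap the proposed DecoderEmbeddingReplica.
embedder = ctranslate2.DecoderEmbedder("qwen3-embedding-0.6b-ct2", device="cpu")

tokens = [["What", "is", "CTranslate2", "?"]]   # placeholder tokenization
output = embedder.forward_batch(tokens)

hidden = output.last_hidden_state   # [batch, time, hidden]
pooled = output.pooler_output       # pooled according to pooling_strategy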


Summary of Files to Modify/Create

  • python/ctranslate2/specs/transformer_spec.py: add TransformerDecoderEmbeddingModelSpec
  • python/ctranslate2/converters/transformers.py: add the Qwen3EmbeddingLoader class
  • include/ctranslate2/models/language_model.h: add the DecoderEmbeddingReplica class
  • src/models/language_model.cc: implement DecoderEmbeddingReplica
  • src/models/model_factory.cc: register the new model type
  • python/cpp/encoder.cc (or a new file): Python bindings for the embedding model
  • python/ctranslate2/__init__.py: export the new classes

I'm also attaching some basic AI-assisted research files since I'm not a programmer by trade, so don't hold it against me!

Complete Implementation Guide_ Adding Embedding Support for Decoder-Only Models in CTranslate2.pdf

Adding Qwen3 Embedding Support to CTranslate2_ Complete Implementation Guide.pdf

NOTE: This one was written before the Qwen3 architecture was supported, so it's partly outdated; the later portion covering Qwen3 embedding model support might still be useful, though.
