
Commit 37953b0

nyo16 and jonatanklosko authored
Add Qwen3 model support (#423)
Co-authored-by: Jonatan Kłosko <[email protected]>
1 parent 55ec9ac · commit 37953b0

File tree

9 files changed (+1097, -4 lines)

lib/bumblebee.ex

Lines changed: 4 additions & 0 deletions
@@ -178,6 +178,9 @@ defmodule Bumblebee do
     "Phi3ForCausalLM" => {Bumblebee.Text.Phi3, :for_causal_language_modeling},
     "Phi3ForSequenceClassification" => {Bumblebee.Text.Phi3, :for_sequence_classification},
     "Phi3ForTokenClassification" => {Bumblebee.Text.Phi3, :for_token_classification},
+    "Qwen3Model" => {Bumblebee.Text.Qwen3, :base},
+    "Qwen3ForCausalLM" => {Bumblebee.Text.Qwen3, :for_causal_language_modeling},
+    "Qwen3ForSequenceClassification" => {Bumblebee.Text.Qwen3, :for_sequence_classification},
     "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
     "ResNetModel" => {Bumblebee.Vision.ResNet, :base},
     "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -258,6 +261,7 @@ defmodule Bumblebee do
     "mbart" => :mbart,
     "phi" => :code_gen,
     "phi3" => :llama,
+    "qwen3" => :qwen2,
     "roberta" => :roberta,
     "smollm3" => :smollm3,
     "t5" => :t5,

lib/bumblebee/layers/transformer.ex

Lines changed: 38 additions & 3 deletions
@@ -53,7 +53,9 @@ defmodule Bumblebee.Layers.Transformer do
     :layer_norm,
     :block_type,
     :attention_window_size,
-    :scale_attention_weights
+    :scale_attention_weights,
+    :query_norm,
+    :key_norm
   ]

   opts =
@@ -330,7 +332,9 @@
         layer_norm: [],
         attention_window_size: nil,
         scale_attention_weights: true,
-        rotary_embedding: nil
+        rotary_embedding: nil,
+        query_norm: nil,
+        key_norm: nil
       ])

     name = opts[:name]
@@ -360,6 +364,8 @@
     attention_window_size = opts[:attention_window_size]
     scale_attention_weights = opts[:scale_attention_weights]
     rotary_embedding = opts[:rotary_embedding]
+    query_norm = opts[:query_norm]
+    key_norm = opts[:key_norm]

     ffn_fun =
       case ffn do
@@ -418,6 +424,8 @@
         attention_window_size: attention_window_size,
         scale_attention_weights: scale_attention_weights,
         rotary_embedding: rotary_embedding,
+        query_norm: query_norm,
+        key_norm: key_norm,
         name: join(name, "self_attention")
       )

@@ -703,6 +711,14 @@

     * `:max_positions` - the maximum number of distinct positions

+    * `:query_norm` - a function that applies normalization to the query
+      projection before rotary embedding. The function should accept two
+      arguments: the input and a name for the layer. Defaults to `nil`
+
+    * `:key_norm` - a function that applies normalization to the key
+      projection before rotary embedding. The function should accept two
+      arguments: the input and a name for the layer. Defaults to `nil`
+
     * `:name` - the prefix for layer names

   ## References
@@ -734,7 +750,9 @@
        key_use_bias: true,
        value_use_bias: true,
        output_use_bias: true,
-       rotary_embedding: nil
+       rotary_embedding: nil,
+       query_norm: nil,
+       key_norm: nil
     ])

    attention_mask = opts[:attention_mask]
@@ -752,6 +770,8 @@
    scale_attention_weights = opts[:scale_attention_weights]
    dropout_rate = opts[:dropout_rate]
    rotary_embedding = opts[:rotary_embedding]
+    query_norm = opts[:query_norm]
+    key_norm = opts[:key_norm]

    query_use_bias = opts[:query_use_bias]
    key_use_bias = opts[:key_use_bias]
@@ -791,6 +811,21 @@
      )
      |> Layers.split_heads(num_key_value_heads)

+    # Apply query and key normalization if configured (before rotary embedding)
+    query =
+      if query_norm do
+        query_norm.(query, join(name, "query_norm"))
+      else
+        query
+      end
+
+    key =
+      if key_norm do
+        key_norm.(key, join(name, "key_norm"))
+      else
+        key
+      end
+
    {query, key} =
      case rotary_embedding do
        opts when is_list(opts) ->
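
The new `:query_norm` and `:key_norm` options are plain two-argument functions
(the input and a layer name), so the caller chooses the normalization. A minimal
sketch of how a Qwen3-style per-projection RMS norm could be wired in, assuming
`Bumblebee.Layers.rms_norm/2` accepts `:epsilon` and `:name` options as it does
for other models:

    # Sketch: a norm function matching the two-argument contract above.
    # The rms_norm options are an assumption based on its use elsewhere.
    query_norm = fn hidden_state, name ->
      Bumblebee.Layers.rms_norm(hidden_state, epsilon: 1.0e-6, name: name)
    end

    # The same shape of function works for :key_norm; per the hunk above,
    # both run before the rotary embedding is applied.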

lib/bumblebee/text.ex

Lines changed: 79 additions & 0 deletions
@@ -385,6 +385,9 @@ defmodule Bumblebee.Text do
       Note that we currently assume that the CLS token is the first token
       in the sequence

+    * `:last_token_pooling` - takes the embedding for the last non-padding
+      token in each sequence
+
     By default no pooling is applied

   * `:embedding_processor` - a post-processing step to apply to the
@@ -444,6 +447,82 @@
   defdelegate text_embedding(model_info, tokenizer, opts \\ []),
     to: Bumblebee.Text.TextEmbedding

+  @type text_reranking_qwen3_input :: {String.t(), String.t()} | [{String.t(), String.t()}]
+  @type text_reranking_qwen3_output :: %{
+          scores: text_reranking_qwen3_score() | list(text_reranking_qwen3_score())
+        }
+  @type text_reranking_qwen3_score :: %{score: number(), query: String.t(), document: String.t()}
+
+  @doc """
+  Builds a serving for text reranking with Qwen3 reranker models.
+
+  The serving expects input in one of the following formats:
+
+    * `{query, document}` - a tuple with query and document text
+    * `[{query1, doc1}, {query2, doc2}, ...]` - a list of query-document pairs
+
+  ## Options
+
+    * `:yes_token` - the token ID corresponding to "yes" for relevance scoring.
+      If not provided, will be inferred from the tokenizer
+
+    * `:no_token` - the token ID corresponding to "no" for relevance scoring.
+      If not provided, will be inferred from the tokenizer
+
+    * `:instruction_prefix` - the instruction prefix to use. Defaults to the
+      Qwen3 reranker format
+
+    * `:instruction_suffix` - the instruction suffix to use. Defaults to the
+      Qwen3 reranker format
+
+    * `:task_description` - the task description to include in prompts. Defaults
+      to "Given a web search query, retrieve relevant passages that answer the query"
+
+    * `:compile` - compiles all computations for predefined input shapes
+      during serving initialization. Should be a keyword list with the
+      following keys:
+
+        * `:batch_size` - the maximum batch size of the input. Inputs
+          are optionally padded to always match this batch size
+
+        * `:sequence_length` - the maximum input sequence length. Input
+          sequences are always padded/truncated to match that length
+
+      It is advised to set this option in production and also configure
+      a defn compiler using `:defn_options` to maximally reduce inference
+      time
+
+    * `:defn_options` - the options for JIT compilation. Defaults to `[]`
+
+    * `:preallocate_params` - when `true`, explicitly allocates params
+      on the device configured in `:defn_options`. You may want to set
+      this option when using partitioned models on the GPU. Defaults to `false`
+
+  ## Examples
+
+      {:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"})
+      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Reranker-0.6B"})
+
+      serving = Bumblebee.Text.text_reranking_qwen3(model_info, tokenizer)
+
+      query = "What is the capital of France?"
+
+      documents = [
+        "Paris is the capital of France.",
+        "Berlin is the capital of Germany."
+      ]
+
+      pairs = Enum.map(documents, &{query, &1})
+
+      Nx.Serving.run(serving, pairs)
+
+  """
+  @spec text_reranking_qwen3(
+          Bumblebee.model_info(),
+          Bumblebee.Tokenizer.t(),
+          keyword()
+        ) :: Nx.Serving.t()
+  defdelegate text_reranking_qwen3(model_info, tokenizer, opts \\ []),
+    to: Bumblebee.Text.TextRerankingQwen3
+
   @type fill_mask_input :: String.t()
   @type fill_mask_output :: %{predictions: list(fill_mask_prediction())}
   @type fill_mask_prediction :: %{score: number(), token: String.t()}
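
Per the `text_reranking_qwen3_output` type above, running the serving on a list
of pairs returns a map with a `scores` list, each entry carrying `:score`,
`:query`, and `:document`. A sketch of ranking documents by relevance,
continuing the example from the docs:

    # Sketch: sort documents by the relevance score returned by the serving.
    # Field names follow the @type definitions in this diff.
    %{scores: scores} = Nx.Serving.run(serving, pairs)

    scores
    |> Enum.sort_by(& &1.score, :desc)
    |> Enum.map(&{&1.document, &1.score})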

lib/bumblebee/text/pre_trained_tokenizer.ex

Lines changed: 7 additions & 0 deletions
@@ -200,6 +200,13 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
        },
        default_template_options: [language_token: "eng_Latn"]
      },
+     qwen2: %{
+       special_tokens: %{
+         unk: "<|endoftext|>",
+         eos: "<|endoftext|>",
+         pad: "<|endoftext|>"
+       }
+     },
      roberta: %{
        special_tokens: %{
          bos: "<s>",
