Commit 2403543

add qwen3 rag models
1 parent d3e46b5

File tree: 10 files changed (+232 / -20 lines)


README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [ggml](https://github.com/ggerganov/ggml)
 
 **What's New:**
 
+* 2025-06-06: Qwen-3 Embedding & Reranker
 * 2025-06-03: Kimi-VL
 * 2025-05-28: Gemma3 fully supported
 * 2025-05-23: [I can see](./docs/multimodal.md): Fuyu
```

convert.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -201,6 +201,8 @@ class ModelType(Enum):
     OrpheusTTS = 0x10000106
     OuteTTSLlaMA = 0x10000107
     OuteTTSQwen3 = 0x10000108
+    QWen3_Embedding = 0x10000109
+    QWen3_ReRanker = 0x1000010A
 
     LlaMAMulti = 0x20000001
 
@@ -4489,6 +4491,7 @@ class QWen3Converter(BaseConverter):
     MODEL_TYPE = ModelType.QWen3
 
     layer_is_sparse = []
+    has_lm_head = True
 
     @staticmethod
     def dump_config(f, config, ggml_type):
@@ -4591,13 +4594,33 @@ def get_weight_names(config):
             "model.norm.weight"
         ]
 
-        if not config.tie_word_embeddings:
+        if QWen3Converter.has_lm_head and (not config.tie_word_embeddings):
             weight_names += [
                 "lm_head.weight"
             ]
 
         return weight_names
 
+class QWen3EmbConverter(BaseConverter):
+    MODEL_TYPE = ModelType.QWen3_Embedding
+
+    @classmethod
+    def state_dict_pp(cls, config, state_dict):
+        r = {}
+        for name in state_dict:
+            r['model.' + name] = state_dict[name]
+
+        return r
+
+    @staticmethod
+    def dump_config(f, config, ggml_type):
+        QWen3Converter.dump_config(f, config, ggml_type)
+
+    @staticmethod
+    def get_weight_names(config):
+        QWen3Converter.has_lm_head = False
+        return QWen3Converter.get_weight_names(config)
+
 def permute2(weights: torch.Tensor, n_head: int, partial_rotary_factor: float) -> torch.Tensor:
     hidden_size = weights.shape[0]
     head_dim = hidden_size // n_head
@@ -7126,6 +7149,11 @@ def main():
     elif arch == 'deepseek-r1-distill-qwen3':
         QWen3Converter.MODEL_TYPE = ModelType.DeepSeek_R1_Distill_QWen3
         QWen3Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch == 'qwen3-embedding':
+        QWen3EmbConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch == 'qwen3-reranker':
+        QWen3Converter.MODEL_TYPE = ModelType.QWen3_ReRanker
+        QWen3Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
    elif arch == 'reka-flash-3':
         assert config.rope_scaling is None, 'config.rope_scaling must be null'
         assert not config.tie_word_embeddings, 'config.tie_word_embeddings must be false'
```
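The converter side is intentionally thin: `QWen3EmbConverter` reuses `QWen3Converter` wholesale, prepending `model.` to every tensor name (the embedding checkpoints evidently store their weights without that prefix) and forcing `has_lm_head = False` so no `lm_head.weight` is expected. A minimal sketch of the renaming step; the tensor names below are illustrative, not taken from a real checkpoint:

```python
# Sketch of QWen3EmbConverter.state_dict_pp: prepend "model." to every
# tensor name so the checkpoint matches the layout that
# QWen3Converter.get_weight_names() expects.
state_dict = {
    "embed_tokens.weight": ...,
    "layers.0.self_attn.q_proj.weight": ...,
    "norm.weight": ...,
}
renamed = {"model." + name: tensor for name, tensor in state_dict.items()}
assert "model.layers.0.self_attn.q_proj.weight" in renamed
```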

docs/models.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -313,10 +313,13 @@ Please use `--format completion` for these models.
 
 Note: Only dense embedding is implemented.
 
+* Qwen-3 Embedding: [0.6B](https://huggingface.co/Qwen/Qwen3-Embedding-0.6B/tree/b22da495047858cce924d27d76261e96be6febc0), [4B](https://huggingface.co/Qwen/Qwen3-Embedding-4B/tree/636cd9bf47d976946cdbb2b0c3ca0cb2f8eea5ff), [8B](https://huggingface.co/Qwen/Qwen3-Embedding-8B/commit/4e423935c619ae4df87b646a3ce949610c66241c)
+
 * QA Ranking (`XLMRobertaForSequenceClassification`)
   * [x] [BCE-ReRanker](https://huggingface.co/maidalun1020/bce-reranker-base_v1)
   * [x] [BGE-ReRanker-M3](https://huggingface.co/BAAI/bge-reranker-v2-m3) (`-a BGE-Reranker-M3`)
   * [x] [MiniCPM-Reranker-Light](https://huggingface.co/openbmb/MiniCPM-Reranker-Light)
+  * [x] Qwen-3 Reranker: [0.6B](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/tree/ad4c588e592307dad69ff0fabc1b3ca5ea8e9f76), [4B](https://huggingface.co/Qwen/Qwen3-Reranker-4B/tree/57906229d41697e4494d50ca5859598cf86154a1), [8B](https://huggingface.co/Qwen/Qwen3-Reranker-8B/tree/d678ef8b29dd0eb9d784473da5d5169b21ec948a)
 
 ## LoRA Models
 
```
models/qwen.cpp

Lines changed: 154 additions & 7 deletions
```diff
@@ -793,23 +793,23 @@ namespace v3
     class ConditionalGeneration : public BaseModelForConditionalGeneration
     {
     public:
-        typedef BaseModelForConditionalGeneration Base;
         typedef HeterogeneousModel ModelClass;
     public:
-        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_QWEN3)
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_QWEN3, const bool skip_lm_head = false, int extra_tensors = 0)
             : BaseModelForConditionalGeneration(type, config, runtime_config, 4096 * 4),
               config(config)
         {
             const size_t tensor_ovhd = ggml_tensor_overhead();
             const int sparse_layers = get_sparse_layer_num();
             const size_t num_tensors = 3 + (config.tie_word_embeddings ? -1 : 0)
                                          + (config.num_hidden_layers - sparse_layers) * 14
-                                         + sparse_layers * (14 + 1);
+                                         + sparse_layers * (14 + 1)
+                                         + extra_tensors;
             const size_t ctx_size = num_tensors * tensor_ovhd;
             w_ctx_.gctx = GGMLContext({.mem_size = ctx_size, .mem_buffer = nullptr, .no_alloc = true});
             w_ctx_.dtype = config.dtype;
 
-            if (config.tie_word_embeddings)
+            if (skip_lm_head || config.tie_word_embeddings)
             {
                 transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
                                              create_embedding<Embedding>(&w_ctx_, config),
@@ -837,18 +837,18 @@ namespace v3
             {
                 if (config.layer_is_sparse[i])
                 {
-                    auto layer = (QWen3MoEBlock128_8 *)Base::get_typed_transformer<ModelClass>()->get_layer(i);
+                    auto layer = (QWen3MoEBlock128_8 *)get_typed_transformer<ModelClass>()->get_layer(i);
                     layer->attention.freq_base = config.rope_theta;
                     layer->mlp.norm_topk_prob = config.norm_topk_prob != 0;
                 }
                 else
                 {
-                    auto layer = (QWen3Block *)Base::get_typed_transformer<ModelClass>()->get_layer(i);
+                    auto layer = (QWen3Block *)get_typed_transformer<ModelClass>()->get_layer(i);
                     layer->attention.freq_base = config.rope_theta;
                 }
             }
 
-            CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
+            CHATLLM_CHECK(w_ctx_.get_used_mem() + extra_tensors * ggml_tensor_overhead() == w_ctx_.get_mem_size())
                 << "corrupted model weights: " << w_ctx_.get_used_mem() / ggml_tensor_overhead() << " != " << w_ctx_.get_mem_size() / ggml_tensor_overhead();
         }
 
```
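The new `extra_tensors` parameter inflates the ggml weight-context budget so that a derived class can create tensors of its own after this constructor returns (the reranker below requests one, for `yes_no_ids`). Since the `CHATLLM_CHECK` runs before any such tensor exists, the used-memory side of the check is padded by `extra_tensors * ggml_tensor_overhead()`. A worked example of the budget arithmetic; the config values are hypothetical, chosen only to make the numbers concrete:

```python
# Tensor-budget arithmetic from the constructor above, for a hypothetical
# dense 28-layer config with tied embeddings and one extra tensor.
num_hidden_layers, sparse_layers = 28, 0
tie_word_embeddings, extra_tensors = True, 1

num_tensors = (3 + (-1 if tie_word_embeddings else 0)
               + (num_hidden_layers - sparse_layers) * 14  # dense blocks
               + sparse_layers * (14 + 1)                  # sparse (MoE) blocks carry one extra tensor
               + extra_tensors)
print(num_tensors)  # 395
```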
```diff
@@ -913,4 +913,151 @@ namespace ds_r1_distill_v3
     };
 
     typedef v3::ConditionalGeneration ConditionalGeneration;
+}
+
+
+namespace v3_emb
+{
+    typedef v3::Config Config;
+
+    class Tokenizer : public v3::Tokenizer
+    {
+    public:
+        Tokenizer(const BaseConfig &config)
+            : v3::Tokenizer(config)
+        {
+            task = "Given a web search query, retrieve relevant passages that answer the query";
+        }
+
+        void encode_embedding(const std::string &text, std::vector<int> &ids, EmbeddingPurpose purpose) const override;
+
+    public:
+        std::string task;
+    };
+
+    void Tokenizer::encode_embedding(const std::string &text, std::vector<int> &ids, EmbeddingPurpose purpose) const
+    {
+        std::ostringstream oss;
+        switch (purpose)
+        {
+        case EmbeddingPurpose::Query:
+            oss << "Instruct: " << task << "\nQuery:" << text;
+            BaseTokenizer::encode(oss.str(), ids);
+            break;
+
+        default:
+            BaseTokenizer::encode(text, ids);
+            break;
+        }
+        ids.push_back(eos_token_id);
+    }
+
+
+    class ConditionalGeneration : public v3::ConditionalGeneration
+    {
+    public:
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_QWEN3_Embedding, const bool skip_lm_head = true, int extra_tensors = 0)
+            : v3::ConditionalGeneration(config, runtime_config, type, skip_lm_head, extra_tensors)
+        {
+            dynamic_cast<HeterogeneousModel *>(transformer)->set_final_steps(std::make_unique<EmbeddingLastTokenFinalSteps>());
+        }
+
+        void set_additional_args(const std::map<std::string, std::string> &args) override
+        {
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+            auto it = args.find("task");
+            if (it != args.end())
+            {
+                tok->task = it->second;
+            }
+        }
+    };
+}
+
+namespace v3_ranker
+{
+    typedef v3::Config Config;
+
+    class Tokenizer : public v3_emb::Tokenizer
+    {
+    public:
+        Tokenizer(const BaseConfig &config)
+            : v3_emb::Tokenizer(config)
+        {
+        }
+
+        size_t load(tokenizer::DataReader *buffer, int n_vocab) override;
+
+        void encode_qa(const std::string &q, const std::string &a, std::vector<int> &ids) const override;
+    public:
+        int yes_token_id;
+        int no_token_id;
+    };
+
+    size_t Tokenizer::load(tokenizer::DataReader *buffer, int n_vocab)
+    {
+        size_t size = v3_emb::Tokenizer::load(buffer, n_vocab);
+
+        yes_token_id = tp->PieceToId("yes");
+        no_token_id = tp->PieceToId("no");
+
+        return size;
+    }
+
+    void Tokenizer::encode_qa(const std::string &q, const std::string &a, std::vector<int> &ids) const
+    {
+        std::ostringstream oss;
+        oss << "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n";
+        oss << "<Instruct>: " << task << "\n<Query>: " << q << "\n<Document>: " << a;
+        oss << "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+
+        BaseTokenizer::encode(oss.str(), ids);
+    }
+
+    class FinalSteps : public LMFinalSteps
+    {
+    public:
+        ggml::tensor *forward(HeterogeneousModel *model, ComputeContext *ctx, ggml::tensor *input_ids, ggml::tensor *hidden_states) override;
+    public:
+        ggml::tensor *yes_no_ids;
+    };
+
+    ggml::tensor *FinalSteps::forward(HeterogeneousModel *model, ComputeContext *ctx, ggml::tensor *input_ids, ggml::tensor *hidden_states)
+    {
+        ggml::tensor *logits = LMFinalSteps::forward(model, ctx, input_ids, hidden_states);
+        logits = ggml::reshape_2d(ctx, logits, 1, ggml::get_dim(logits, 0));
+        logits = ggml::get_rows(ctx, logits, yes_no_ids);
+        logits = ggml::reshape_1d(ctx, logits, 2);
+        logits = ggml::soft_max(ctx, logits);
+        logits = ggml::view_1d(ctx, logits, 1, 0);
+        return logits;
+    }
+
+    class ConditionalGeneration : public v3_emb::ConditionalGeneration
+    {
+    public:
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config)
+            : v3_emb::ConditionalGeneration(config, runtime_config, MODEL_TYPE_QWEN3_ReRanker, false, 1)
+        {
+            dynamic_cast<HeterogeneousModel *>(transformer)->set_final_steps(std::make_unique<FinalSteps>());
+
+            FinalSteps *steps = dynamic_cast<FinalSteps *>(dynamic_cast<HeterogeneousModel *>(transformer)->get_final_steps());
+            steps->yes_no_ids = ggml::new_tensor_1d(&w_ctx_, ggml::type::GGML_TYPE_I32, 2);
+            w_ctx_.get_allocator()->alloc(steps->yes_no_ids);
+            yes_no_ids = steps->yes_no_ids;
+        }
+
+        void set_tokenizer(BaseTokenizer *tokenizer) override
+        {
+            v3::ConditionalGeneration::set_tokenizer(tokenizer);
+
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+            int ids[2];
+            ids[0] = tok->yes_token_id;
+            ids[1] = tok->no_token_id;
+            Backend::write_tensor_data(yes_no_ids, ids, 0, sizeof(ids));
+        }
+    protected:
+        ggml::tensor *yes_no_ids = nullptr;
+    };
 }
```
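The reranker runs the full decoder and then collapses the last-position vocabulary logits into a single relevance score: `encode_qa` wraps the (query, document) pair in a fixed chat template with an empty `<think>` block, and `FinalSteps::forward` picks out the logits of the "yes" and "no" tokens and returns the softmax probability of "yes". A plain-Python sketch of both steps; the logit values at the end are invented:

```python
import math

# Prompt layout mirroring v3_ranker::Tokenizer::encode_qa above.
TASK = "Given a web search query, retrieve relevant passages that answer the query"

def reranker_prompt(q: str, a: str, task: str = TASK) -> str:
    return ("<|im_start|>system\nJudge whether the Document meets the requirements "
            "based on the Query and the Instruct provided. Note that the answer can "
            "only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            f"<Instruct>: {task}\n<Query>: {q}\n<Document>: {a}"
            "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n")

# Scoring mirroring v3_ranker::FinalSteps::forward: a 2-way softmax over
# the "yes"/"no" logits, keeping P("yes").
def rank_score(logit_yes: float, logit_no: float) -> float:
    e_yes, e_no = math.exp(logit_yes), math.exp(logit_no)
    return e_yes / (e_yes + e_no)

print(rank_score(2.3, -1.1))  # ~0.97
```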

src/chat.cpp

Lines changed: 7 additions & 2 deletions
```diff
@@ -613,6 +613,11 @@ namespace chatllm
         tp->Encode(input, &ids);
     }
 
+    void BaseTokenizer::encode_embedding(const std::string &text, std::vector<int> &ids, EmbeddingPurpose purpose) const
+    {
+        encode(text, ids);
+    }
+
     std::vector<int> BaseTokenizer::encode(const std::string &text) const
     {
         std::vector<int> ids;
@@ -1851,11 +1856,11 @@ namespace chatllm
         tokenizer->encode(input, result);
     }
 
-    void Pipeline::text_embedding(const std::string &input, const GenerationConfig &gen_config, std::vector<float> &result)
+    void Pipeline::text_embedding(const std::string &input, const GenerationConfig &gen_config, std::vector<float> &result, BaseTokenizer::EmbeddingPurpose purpose)
     {
         if (!modelobj.loaded) return;
         std::vector<int> input_ids;
-        tokenizer->encode(input, input_ids);
+        tokenizer->encode_embedding(input, input_ids, purpose);
         model->text_embedding(gen_config, input_ids, result);
     }
 
```
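`BaseTokenizer::encode_embedding` defaults to plain `encode`, so existing embedding models ignore the new `purpose` flag and behave exactly as before; only tokenizers that override it (here, `v3_emb::Tokenizer` in models/qwen.cpp) ever see it. A sketch of the strings that override produces:

```python
# Prompt layout of v3_emb::Tokenizer::encode_embedding: queries get the
# instruction template (note: no space after "Query:"), documents pass
# through unchanged; an EOS token is appended after tokenization in both
# cases.
TASK = "Given a web search query, retrieve relevant passages that answer the query"

def embedding_prompt(text: str, is_query: bool, task: str = TASK) -> str:
    return f"Instruct: {task}\nQuery:{text}" if is_query else text

print(embedding_prompt("what is RAG?", is_query=True))
```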
src/chat.h

Lines changed: 9 additions & 1 deletion
```diff
@@ -250,6 +250,13 @@ namespace chatllm
             int emb_vec_number;
             std::vector<float> data;
         };
+
+        enum EmbeddingPurpose
+        {
+            Document,
+            Query,
+        };
+
     public:
         BaseTokenizer(const BaseConfig &config,
                       BaseHistoryEncoder *chat_encoder,
@@ -268,6 +275,7 @@
         virtual void encode_external_text_completion(const std::string &text, std::vector<int> &ids) const;
 
         virtual void encode_qa(const std::string &q, const std::string &a, std::vector<int> &ids) const;
+        virtual void encode_embedding(const std::string &text, std::vector<int> &ids, EmbeddingPurpose purpose) const;
 
         virtual std::string decode(const std::vector<int> &ids) const;
 
@@ -1320,7 +1328,7 @@ namespace chatllm
         void set_extending_method(ExtendingMethod method);
         virtual void set_additional_args(const std::map<std::string, std::string> &args);
 
-        void text_embedding(const std::string &input, const GenerationConfig &gen_config, std::vector<float> &result);
+        void text_embedding(const std::string &input, const GenerationConfig &gen_config, std::vector<float> &result, BaseTokenizer::EmbeddingPurpose purpose = BaseTokenizer::EmbeddingPurpose::Document);
         void text_tokenize(const std::string &input, const GenerationConfig &gen_config, std::vector<int> &result);
         float qa_rank(const std::string &q, const std::string &a, const GenerationConfig &gen_config);
 
```
src/layers.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -125,7 +125,7 @@ namespace chatllm
         ggml::tensor *norm_inplace(ComputeContext *ctx, ggml::tensor *a, float eps);
         ggml::tensor *rms_norm_inplace(ComputeContext *ctx, ggml::tensor *a, float eps);
         ggml::tensor *rms_norm(ComputeContext *ctx, ggml::tensor *a, float eps);
-        ggml::tensor *simple_norm(ComputeContext *ctx, ggml::tensor *a, float eps);
+        ggml::tensor *simple_norm(ComputeContext *ctx, ggml::tensor *a, float eps); // p=2 normalization
 
         ggml::tensor *rope(ComputeContext *ctx, ggml::tensor *a, ggml::tensor *b, int n_dims, int mode);
         ggml::tensor *rope_ext(ComputeContext *ctx, ggml::tensor *a, ggml::tensor *b, ggml::tensor *c,
```
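The new comment pins down `simple_norm`'s semantics: scale a vector to unit L2 ("p=2") norm, presumably the final step applied to embedding vectors. A generic sketch; the exact `eps` placement here is illustrative and may differ from ggml's formula:

```python
# L2 ("p=2") normalization: divide a vector by its Euclidean norm so the
# result has unit length; eps guards against division by zero.
def l2_normalize(v: list[float], eps: float = 1e-6) -> list[float]:
    norm = max(sum(x * x for x in v) ** 0.5, eps)
    return [x / norm for x in v]

print(l2_normalize([3.0, 4.0]))  # [0.6, 0.8]
```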

src/main.cpp

Lines changed: 6 additions & 3 deletions
```diff
@@ -811,17 +811,18 @@ static void run_tts(Args &args, chatllm::Pipeline &pipeline, TextStreamer &streamer, const chatllm::GenerationConfig &gen_config)
 static void run_text_embedding(Args &args, chatllm::Pipeline &pipeline, TextStreamer &streamer, const chatllm::GenerationConfig &gen_config)
 {
     std::vector<float> result;
+    chatllm::BaseTokenizer::EmbeddingPurpose purpose = chatllm::BaseTokenizer::EmbeddingPurpose::Document;
 
     if (!args.interactive)
     {
-        pipeline.text_embedding(args.prompt, gen_config, result);
+        pipeline.text_embedding(args.prompt, gen_config, result, purpose);
         print_embedding(result, streamer.cout);
         return;
     }
 
     while (1)
     {
-        streamer.cout << "Input > " << std::flush;
+        streamer.cout << "Input " << (purpose == chatllm::BaseTokenizer::EmbeddingPurpose::Document ? "Doc" : "Query") << " > " << std::flush;
         std::string input;
         if (!get_utf8_line(input, args.multi_line))
         {
@@ -831,11 +832,13 @@ static void run_text_embedding(Args &args, chatllm::Pipeline &pipeline, TextStreamer &streamer, const chatllm::GenerationConfig &gen_config)
         if (input.empty()) continue;
 
         result.clear();
-        pipeline.text_embedding(input, gen_config, result);
+        pipeline.text_embedding(input, gen_config, result, purpose);
         streamer.cout << " > ";
 
         print_embedding(result, streamer.cout);
 
+        purpose = purpose == chatllm::BaseTokenizer::EmbeddingPurpose::Document ?
+                    chatllm::BaseTokenizer::EmbeddingPurpose::Query : chatllm::BaseTokenizer::EmbeddingPurpose::Document;
     }
     streamer.cout << "Bye\n";
 }
```
