@@ -793,23 +793,23 @@ namespace v3
     class ConditionalGeneration : public BaseModelForConditionalGeneration
     {
     public:
-        typedef BaseModelForConditionalGeneration Base;
         typedef HeterogeneousModel ModelClass;
     public:
-        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_QWEN3)
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_QWEN3, const bool skip_lm_head = false, int extra_tensors = 0)
             : BaseModelForConditionalGeneration(type, config, runtime_config, 4096 * 4),
               config(config)
         {
             const size_t tensor_ovhd = ggml_tensor_overhead();
             const int sparse_layers = get_sparse_layer_num();
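+            // Tensor budget: 3 top-level tensors (token embedding, final norm and,
+            // unless embeddings are tied, lm_head), 14 weights per dense layer,
+            // 15 per sparse (MoE) layer, plus any extras requested by a subclass.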
             const size_t num_tensors = 3 + (config.tie_word_embeddings ? -1 : 0)
                                          + (config.num_hidden_layers - sparse_layers) * 14
-                                         + sparse_layers * (14 + 1);
+                                         + sparse_layers * (14 + 1)
+                                         + extra_tensors;
             const size_t ctx_size = num_tensors * tensor_ovhd;
             w_ctx_.gctx = GGMLContext({.mem_size = ctx_size, .mem_buffer = nullptr, .no_alloc = true});
             w_ctx_.dtype = config.dtype;

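+            // Embedding/reranker subclasses pass skip_lm_head = true: they reuse
+            // the lm_head-free construction path of tied embeddings.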
-            if (config.tie_word_embeddings)
+            if (skip_lm_head || config.tie_word_embeddings)
             {
                 transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
                                              create_embedding<Embedding>(&w_ctx_, config),
@@ -837,18 +837,18 @@ namespace v3
             {
                 if (config.layer_is_sparse[i])
                 {
-                    auto layer = (QWen3MoEBlock128_8 *)Base::get_typed_transformer<ModelClass>()->get_layer(i);
+                    auto layer = (QWen3MoEBlock128_8 *)get_typed_transformer<ModelClass>()->get_layer(i);
                     layer->attention.freq_base = config.rope_theta;
                     layer->mlp.norm_topk_prob = config.norm_topk_prob != 0;
                 }
                 else
                 {
-                    auto layer = (QWen3Block *)Base::get_typed_transformer<ModelClass>()->get_layer(i);
+                    auto layer = (QWen3Block *)get_typed_transformer<ModelClass>()->get_layer(i);
                     layer->attention.freq_base = config.rope_theta;
                 }
             }

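+            // extra tensors are allocated by subclasses after this constructor
+            // returns, so add their overhead to the used count for the check.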
-            CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
+            CHATLLM_CHECK(w_ctx_.get_used_mem() + extra_tensors * ggml_tensor_overhead() == w_ctx_.get_mem_size())
                 << "corrupted model weights: " << w_ctx_.get_used_mem() / ggml_tensor_overhead() << " != " << w_ctx_.get_mem_size() / ggml_tensor_overhead();
         }

@@ -913,4 +913,151 @@ namespace ds_r1_distill_v3
     };

     typedef v3::ConditionalGeneration ConditionalGeneration;
+}
+
+
+namespace v3_emb
+{
+    typedef v3::Config Config;
+
+    class Tokenizer : public v3::Tokenizer
+    {
+    public:
+        Tokenizer(const BaseConfig &config)
+            : v3::Tokenizer(config)
+        {
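+            // Default retrieval instruction; callers can override it at runtime
+            // through the "task" additional arg (see set_additional_args below).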
+            task = "Given a web search query, retrieve relevant passages that answer the query";
+        }
+
+        void encode_embedding(const std::string &text, std::vector<int> &ids, EmbeddingPurpose purpose) const override;
+
+    public:
+        std::string task;
+    };
+
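+    // Qwen3-Embedding prompt format: queries are wrapped as
+    // "Instruct: {task}\nQuery:{text}", documents are encoded verbatim;
+    // both end with EOS.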
+    void Tokenizer::encode_embedding(const std::string &text, std::vector<int> &ids, EmbeddingPurpose purpose) const
+    {
+        std::ostringstream oss;
+        switch (purpose)
+        {
+        case EmbeddingPurpose::Query:
+            oss << "Instruct: " << task << "\nQuery:" << text;
+            BaseTokenizer::encode(oss.str(), ids);
+            break;
+
+        default:
+            BaseTokenizer::encode(text, ids);
+            break;
+        }
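+        // EOS marks the last token, the position pooled by the final steps below.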
+        ids.push_back(eos_token_id);
+    }
+
+
+    class ConditionalGeneration : public v3::ConditionalGeneration
+    {
+    public:
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_QWEN3_Embedding, const bool skip_lm_head = true, int extra_tensors = 0)
+            : v3::ConditionalGeneration(config, runtime_config, type, skip_lm_head, extra_tensors)
+        {
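+            // Swap the LM-head final steps for last-token pooling: the forward
+            // pass now yields an embedding vector instead of vocab logits.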
+            dynamic_cast<HeterogeneousModel *>(transformer)->set_final_steps(std::make_unique<EmbeddingLastTokenFinalSteps>());
+        }
+
+        void set_additional_args(const std::map<std::string, std::string> &args) override
+        {
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+            auto it = args.find("task");
+            if (it != args.end())
+            {
+                tok->task = it->second;
+            }
+        }
+    };
+}
+
+namespace v3_ranker
+{
+    typedef v3::Config Config;
+
+    class Tokenizer : public v3_emb::Tokenizer
+    {
+    public:
+        Tokenizer(const BaseConfig &config)
+            : v3_emb::Tokenizer(config)
+        {
+        }
+
+        size_t load(tokenizer::DataReader *buffer, int n_vocab) override;
+
+        void encode_qa(const std::string &q, const std::string &a, std::vector<int> &ids) const override;
+    public:
+        int yes_token_id;
+        int no_token_id;
+    };
+
+    size_t Tokenizer::load(tokenizer::DataReader *buffer, int n_vocab)
+    {
+        size_t size = v3_emb::Tokenizer::load(buffer, n_vocab);
+
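+        // Cache the ids of the "yes"/"no" pieces; the reranker scores a pair by
+        // comparing the logits of exactly these two tokens (see FinalSteps).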
+        yes_token_id = tp->PieceToId("yes");
+        no_token_id = tp->PieceToId("no");
+
+        return size;
+    }
+
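+    // Qwen3-Reranker chat template: a system instruction, then the
+    // <Instruct>/<Query>/<Document> triple, closed with an empty think block so
+    // that the next generated token is the yes/no judgement.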
+    void Tokenizer::encode_qa(const std::string &q, const std::string &a, std::vector<int> &ids) const
+    {
+        std::ostringstream oss;
+        oss << "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n";
+        oss << "<Instruct>: " << task << "\n<Query>: " << q << "\n<Document>: " << a;
+        oss << "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+
+        BaseTokenizer::encode(oss.str(), ids);
+    }
+
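+    // Final-steps override that reduces LM logits to a single relevance score.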
+    class FinalSteps : public LMFinalSteps
+    {
+    public:
+        ggml::tensor *forward(HeterogeneousModel *model, ComputeContext *ctx, ggml::tensor *input_ids, ggml::tensor *hidden_states) override;
+    public:
+        ggml::tensor *yes_no_ids;
+    };
+
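+    // Run the normal LM head, reshape the vocab logits into one scalar row per
+    // token id, gather the rows for "yes" and "no", softmax the pair, and
+    // return P("yes") as a 1-element tensor: the relevance score.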
+    ggml::tensor *FinalSteps::forward(HeterogeneousModel *model, ComputeContext *ctx, ggml::tensor *input_ids, ggml::tensor *hidden_states)
+    {
+        ggml::tensor *logits = LMFinalSteps::forward(model, ctx, input_ids, hidden_states);
+        logits = ggml::reshape_2d(ctx, logits, 1, ggml::get_dim(logits, 0));
+        logits = ggml::get_rows(ctx, logits, yes_no_ids);
+        logits = ggml::reshape_1d(ctx, logits, 2);
+        logits = ggml::soft_max(ctx, logits);
+        logits = ggml::view_1d(ctx, logits, 1, 0);
+        return logits;
+    }
+
+    class ConditionalGeneration : public v3_emb::ConditionalGeneration
+    {
+    public:
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config)
+            : v3_emb::ConditionalGeneration(config, runtime_config, MODEL_TYPE_QWEN3_ReRanker, false, 1)
+        {
+            dynamic_cast<HeterogeneousModel *>(transformer)->set_final_steps(std::make_unique<FinalSteps>());
+
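+            // This is the one "extra tensor" reserved via extra_tensors above:
+            // a 2-element I32 tensor that will hold the yes/no token ids.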
+            FinalSteps *steps = dynamic_cast<FinalSteps *>(dynamic_cast<HeterogeneousModel *>(transformer)->get_final_steps());
+            steps->yes_no_ids = ggml::new_tensor_1d(&w_ctx_, ggml::type::GGML_TYPE_I32, 2);
+            w_ctx_.get_allocator()->alloc(steps->yes_no_ids);
+            yes_no_ids = steps->yes_no_ids;
+        }
+
+        void set_tokenizer(BaseTokenizer *tokenizer) override
+        {
+            v3::ConditionalGeneration::set_tokenizer(tokenizer);
+
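+            // The yes/no ids are only known once the tokenizer is loaded, so
+            // copy them into the backend tensor here.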
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+            int ids[2];
+            ids[0] = tok->yes_token_id;
+            ids[1] = tok->no_token_id;
+            Backend::write_tensor_data(yes_no_ids, ids, 0, sizeof(ids));
+        }
+    protected:
+        ggml::tensor *yes_no_ids = nullptr;
+    };
 }