typedef llama::v3::Config Config;

class ChatHistoryEncoder : public BaseHistoryEncoder
{
public:
    void append_sys_prompt(std::vector<int> &ids) const override;
    void append_pair(int round_idx, const std::string &user, const std::string &ai, std::vector<int> &ids) const override;
    void do_append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
};

static ChatHistoryEncoder _chat_encoder;

class Tokenizer : public llama::v2::Tokenizer
{
public:
    Tokenizer(const Config &config)
        : llama::v2::Tokenizer(config, &_chat_encoder)
    {
        sys_prompt = " ";
        reserved_0_token_id = 3;
        reserved_1_token_id = 4;
    }
public:
    int reserved_0_token_id;
    int reserved_1_token_id;
};
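// Note: the reserved token ids (3 and 4) are hard-coded here; presumably they
// correspond to special tokens in the model's vocabulary that mark the
// boundaries of a user turn (see do_append_user below).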

class ConditionalGeneration : public llama::v3::ConditionalGeneration
{
public:
    ConditionalGeneration() = default;
    ConditionalGeneration(const Config &config)
        : llama::v3::ConditionalGeneration(config, ModelType::MODEL_TYPE_INDEX)
    {}
};
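// Sketch of the token layout the encoder below produces for a conversation
// (<r0>/<r1> denote reserved_0_token_id and reserved_1_token_id; the leading
// pad token is only emitted when a system prompt is set):
//
//   [pad] <system prompt> <r0> <user_0> <r1> <ai_0> ... <r0> <user_N> <r1>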

void ChatHistoryEncoder::append_sys_prompt(std::vector<int> &ids) const
{
    if (tokenizer->get_system_prompt().size() > 0)
    {
        // The system prompt is preceded by the pad token.
        ids.push_back(tokenizer->pad_token_id);
        tokenizer->encode(tokenizer->get_system_prompt(), ids);
    }
}

void ChatHistoryEncoder::append_pair(int round_idx, const std::string &user, const std::string &ai, std::vector<int> &ids) const
{
    Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
    // A completed round is the wrapped user turn followed by the model's reply.
    do_append_user(round_idx, user, ids);
    tok->encode(ai, ids);
}

void ChatHistoryEncoder::do_append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
{
    Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
    // Wrap the user input between the two reserved role-marker tokens.
    ids.push_back(tok->reserved_0_token_id);
    tok->encode(user, ids);
    ids.push_back(tok->reserved_1_token_id);
}
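
// Minimal usage sketch (hypothetical driver code; in the real project the
// tokenizer and model are constructed by the model factory, and the base
// llama::v2::Tokenizer constructor is assumed to wire _chat_encoder to the
// tokenizer instance):
//
//   Tokenizer tokenizer(config);
//   std::vector<int> ids;
//   _chat_encoder.do_append_user(0, "Hello", ids);
//   // ids now holds {3, /* BPE ids of "Hello" */, 4}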