Skip to content

Commit b045eac

Browse files
committed
restore old example
1 parent 829b762 commit b045eac

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

examples/speculative/speculative.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,49 @@ int main(int argc, char ** argv) {
9999
const bool vocab_type_dft = llama_vocab_type(vocab_dft);
100100
LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
101101

102+
if (vocab_type_tgt != vocab_type_dft) {
103+
LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
104+
LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
105+
return 1;
106+
}
107+
108+
if (
109+
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
110+
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
111+
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
112+
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
113+
) {
114+
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
115+
return 1;
116+
}
117+
118+
{
119+
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
120+
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
121+
const int vocab_diff = n_vocab_tgt > n_vocab_dft
122+
? n_vocab_tgt - n_vocab_dft
123+
: n_vocab_dft - n_vocab_tgt;
124+
125+
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
126+
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
127+
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
128+
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
129+
return 1;
130+
}
131+
132+
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
133+
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
134+
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
135+
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
136+
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
137+
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
138+
common_token_to_piece(ctx_tgt, i).c_str(),
139+
common_token_to_piece(ctx_dft, i).c_str());
140+
return 1;
141+
}
142+
}
143+
}
144+
102145
auto * mem_tgt = llama_get_memory(ctx_tgt);
103146
auto * mem_dft = llama_get_memory(ctx_dft);
104147

0 commit comments

Comments (0)