@@ -11,22 +11,18 @@ struct common_speculative {
 
     struct common_sampler * smpl;
 
-    std::vector<int> i_batch_tgt;
-
-    std::vector<llama_token> tokens;
+    llama_tokens prompt_last;
 };
 
 struct common_speculative * common_speculative_init(struct common_speculative_params params) {
     auto * result = new common_speculative {
         /* .params = */ params,
         /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
         /* .smpl = */ nullptr,
-        /* .i_batch_tgt = */ {},
-        /* .tokens = */ {},
     };
 
     // TODO: optimize or pass from outside?
-#if 0
+#if 1
     {
         common_sampler_params sparams;
         sparams.no_perf = false;
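The initializer above touches only part of `common_speculative_params`; the fields the rest of this change reads are `ctx_dft`, `n_draft`, `n_reuse`, `n_min` and `p_min`. A rough sketch of that struct, inferred from those usages rather than copied from the header (the values shown are illustrative placeholders, not the actual defaults):

struct common_speculative_params {
    int n_draft = 16;   // max tokens drafted per call (placeholder value)
    int n_reuse = 256;  // min length of a matching token run before the draft KV cache is reused (placeholder value)
    int n_min   = 5;    // drafts shorter than this are discarded (placeholder value)

    float p_min = 0.9f; // min draft-token probability required to keep drafting (placeholder value)

    struct llama_context * ctx_dft = nullptr; // context of the draft model
};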
@@ -70,30 +66,79 @@ void common_speculative_free(struct common_speculative * spec) {
     delete spec;
 }
 
-void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
-    llama_kv_cache_clear(spec->params.ctx_dft);
-
-    // TODO: error handling
-    llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
-}
-
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
+        const llama_tokens & prompt,
         llama_token id_last,
-        int n_past) {
-    spec->tokens.clear();
+        llama_token n_past_tgt) {
 
-    spec->i_batch_tgt.clear();
-    spec->i_batch_tgt.push_back(0);
+    int reuse_i = 0;
+    int reuse_n = 0;
 
-    common_sampler_reset(spec->smpl);
+    const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
+
+    const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
+
+    for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt.size() &&
+               i       + cur < (int) spec->prompt_last.size() &&
+               prompt[i_start + cur] == spec->prompt_last[i + cur]) {
+            cur++;
+        }
+
+        if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
+
+    if (reuse_n == 0) {
+        llama_kv_cache_clear(spec->params.ctx_dft);
+
+        spec->prompt_last.clear();
+    } else {
+        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
+        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
+        llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
+
+        spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
+        spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
+    }
+
+    common_batch_clear(spec->batch_dft);
+
+    for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
+        common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
+
+        spec->prompt_last.push_back(prompt[i]);
+    }
+
+    const llama_pos n_past = prompt.size() - i_start;
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    if (spec->batch_dft.n_tokens > 0) {
+        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
+
+        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    }
 
     common_batch_clear(spec->batch_dft);
     common_batch_add(spec->batch_dft, id_last, n_past, { 0 }, true);
 
+    spec->prompt_last.push_back(id_last);
+
+    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
+
     llama_decode(spec->params.ctx_dft, spec->batch_dft);
 
+    common_sampler_reset(spec->smpl);
+
     // sample n_draft tokens from the draft model
     for (int i = 0; i < spec->params.n_draft; ++i) {
         common_batch_clear(spec->batch_dft);
@@ -111,18 +156,13 @@ void common_speculative_add_draft(
         const llama_token id = cur_p->data[0].id;
 
         // only collect very high-confidence draft tokens
-        if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) {
+        if (cur_p->data[0].p < spec->params.p_min) {
             break;
         }
 
         common_sampler_accept(spec->smpl, id, true);
 
-        spec->tokens.push_back(id);
-
-        // add unique drafted tokens to the target batch
-        spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
-
-        common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
+        common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
 
         if (batch_tgt.n_tokens > spec->params.n_draft) {
             break;
@@ -132,23 +172,13 @@ void common_speculative_add_draft(
 
         // evaluate the drafted tokens on the draft model
         llama_decode(spec->params.ctx_dft, spec->batch_dft);
+
+        spec->prompt_last.push_back(id);
     }
 
     // don't waste time on small batches
     // TODO: do not evaluate the draft model for that many rounds
     if (batch_tgt.n_tokens < spec->params.n_min) {
         batch_tgt.n_tokens = 1;
-        spec->tokens.resize(0);
-        spec->i_batch_tgt.resize(1);
     }
-
-    // print current draft sequences
-    LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
-}
-
-std::vector<llama_token> common_speculative_sample(
-        struct common_speculative * spec,
-        struct common_sampler * smpl,
-        struct llama_context * ctx_tgt) {
-    return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
-}
 }
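The core of this change is the reuse step at the top of common_speculative_add_draft: the draft context now remembers the token sequence it last evaluated (prompt_last), and on the next call it searches for the longest run of those cached tokens that matches the tail window of the new target prompt, so the corresponding slice of the draft KV cache can be shifted into place with llama_kv_cache_seq_rm/llama_kv_cache_seq_add instead of being recomputed. A self-contained sketch of just that matching step, assuming llama_tokens is a std::vector<llama_token>; this is an illustration, not the llama.cpp source:

#include <cstdint>
#include <utility>
#include <vector>

using llama_token  = int32_t;
using llama_tokens = std::vector<llama_token>;

// Sketch: mirrors the reuse-matching loop in common_speculative_add_draft above.
// Returns {reuse_i, reuse_n}: the longest run of prompt_last (starting at reuse_i)
// that matches the new prompt starting at i_start. Only runs of at least n_reuse
// tokens qualify, unless the whole prompt already fits in the draft context.
static std::pair<int, int> find_reuse(
        const llama_tokens & prompt_last,
        const llama_tokens & prompt,
        int i_start, int n_ctx, int n_reuse) {
    int reuse_i = 0;
    int reuse_n = 0;

    for (int i = 0; i < (int) prompt_last.size(); ++i) {
        int cur = 0;
        while (i_start + cur < (int) prompt.size() &&
               i       + cur < (int) prompt_last.size() &&
               prompt[i_start + cur] == prompt_last[i + cur]) {
            cur++;
        }

        if ((cur >= n_reuse || (int) prompt.size() <= n_ctx) && cur > reuse_n) {
            reuse_i = i;
            reuse_n = cur;
        }
    }

    return {reuse_i, reuse_n};
}

In the common case of appending to an ongoing generation, reuse_i is 0 and reuse_n covers almost the entire cached prompt, so only the newly appended tokens are decoded on the draft model before drafting begins.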