@@ -65,8 +65,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
65
65
Array<GenerationConfig> generation_cfg;
66
66
std::vector<RandomGenerator*> rngs;
67
67
std::vector<std::vector<SampleResult>> draft_output_tokens;
68
+ std::vector<int64_t > token_tree_parent_ptr;
68
69
request_internal_ids.reserve (num_rsentries);
69
70
all_tokens_to_verify.reserve (total_draft_length);
71
+ token_tree_parent_ptr.reserve (total_draft_length);
70
72
verify_request_mstates.reserve (num_rsentries);
71
73
rngs.reserve (num_rsentries);
72
74
generation_cfg.reserve (num_rsentries);
@@ -83,9 +85,12 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
83
85
// the last committed token + all the draft tokens but the last one.
84
86
all_tokens_to_verify.push_back (draft_mstate->committed_tokens .back ().GetTokenId ());
85
87
draft_token_slots_.push_back (0 ); // placeholder for the last committed token
88
+ token_tree_parent_ptr.push_back (-1 );
89
+
86
90
for (int j = 0 ; j < static_cast <int >(draft_mstate->draft_output_tokens .size ()); ++j) {
87
91
all_tokens_to_verify.push_back (draft_mstate->draft_output_tokens [j].GetTokenId ());
88
92
draft_token_slots_.push_back (draft_mstate->draft_token_slots [j]);
93
+ token_tree_parent_ptr.push_back (draft_mstate->draft_token_parent_idx [j] + 1 );
89
94
}
90
95
verify_request_mstates.push_back (verify_mstate);
91
96
generation_cfg.push_back (rsentries[i]->request ->generation_cfg );
@@ -111,16 +116,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
111
116
{IntTuple{all_tokens_to_verify.begin (), all_tokens_to_verify.end ()}});
112
117
RECORD_EVENT (trace_recorder_, request_ids, " finish verify embedding" );
113
118
114
- // Construct the token tree. Right now only chains are supported.
115
- std::vector<int64_t > token_tree_parent_ptr;
116
- token_tree_parent_ptr.reserve (cum_verify_lengths.back ());
117
- for (int i = 0 ; i < num_rsentries; ++i) {
118
- for (int pos = 0 ; pos < verify_lengths[i]; ++pos) {
119
- token_tree_parent_ptr.push_back (pos - 1 );
120
- }
121
- }
122
- ICHECK_EQ (token_tree_parent_ptr.size (), cum_verify_lengths.back ());
123
-
124
119
RECORD_EVENT (trace_recorder_, request_ids, " start verify" );
125
120
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden (
126
121
embeddings, request_internal_ids, verify_lengths, token_tree_parent_ptr);
@@ -143,7 +138,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
143
138
std::vector<std::vector<SampleResult>> sample_results_arr =
144
139
sampler_->BatchVerifyDraftTokensWithProbAfterTopP (
145
140
renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,
146
- draft_output_tokens, draft_probs_on_device);
141
+ draft_output_tokens, token_tree_parent_ptr, draft_probs_on_device);
147
142
ICHECK_EQ (sample_results_arr.size (), num_rsentries);
148
143
149
144
// We collect the requests whose drafts are fully accepted.
@@ -398,7 +393,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
398
393
&model_workspaces_[0 ].draft_hidden_states_storage );
399
394
}
400
395
for (int i = 0 ; i < static_cast <int >(mstates.size ()); ++i) {
401
- mstates[i]->AddDraftToken (sample_results[i], draft_token_slots_[i]);
396
+ int64_t parent_idx = static_cast <int64_t >(mstates[i]->draft_output_tokens .size ()) - 1 ;
397
+ mstates[i]->AddDraftToken (sample_results[i], draft_token_slots_[i], parent_idx);
402
398
}
403
399
}
404
400
/* !
0 commit comments