@@ -64,10 +64,11 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
6464
6565 int32_t n_vocab = llama_vocab_n_tokens (llama_model_get_vocab (model));
6666
67- std::vector<llama_token_data> candidates;
68- candidates. reserve (n_vocab);
67+ std::vector<llama_token_data> candidates (n_vocab) ;
68+
6969 std::vector<llama_token_data> conf_candidates;
7070 conf_candidates.reserve (max_length);
71+
7172 std::vector<int32_t > mask_positions;
7273 mask_positions.reserve (max_length);
7374
@@ -124,18 +125,9 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
124125 return nullptr ;
125126 }
126127
127- std::vector<float > shifted_logits (max_length * n_vocab);
128-
129- // Move logits to left, because decode path will shift logits to right
130- // Position 0 keeps its own logits: shifted_logits[0] = raw_logits[0]
131- std::copy (raw_logits, raw_logits + n_vocab, shifted_logits.data ());
132-
133- // Positions 1+ get logits from previous position: shifted_logits[i] = raw_logits[i-1]
134- for (int32_t i = 1 ; i < max_length; i++) {
135- std::copy (raw_logits + (i - 1 ) * n_vocab, raw_logits + i * n_vocab, shifted_logits.data () + i * n_vocab);
136- }
137-
138- float * logits = shifted_logits.data ();
128+ auto get_logits_for_pos = [&](int32_t pos) -> const float * {
129+ return pos == 0 ? raw_logits : raw_logits + (pos - 1 ) * n_vocab;
130+ };
139131
140132 mask_positions.clear ();
141133 for (int32_t i = 0 ; i < max_length; i++) {
@@ -156,14 +148,16 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
156148
157149 for (int32_t pos : mask_positions) {
158150 if (std::uniform_real_distribution<float >(0 .0f , 1 .0f )(rng) < p_transfer) {
159- candidates. clear ( );
151+ const float * pos_logits = get_logits_for_pos (pos );
160152 for (int32_t token_id = 0 ; token_id < n_vocab; token_id++) {
161- candidates.emplace_back (llama_token_data{ token_id, logits[pos * n_vocab + token_id], 0 .0f });
153+ candidates[token_id].id = token_id;
154+ candidates[token_id].logit = pos_logits[token_id];
155+ candidates[token_id].p = 0 .0f ;
162156 }
163157
164158 llama_token_data_array cur_p = {
165159 /* .data = */ candidates.data (),
166- /* .size = */ candidates. size (),
160+ /* .size = */ ( size_t )n_vocab, // Reset size to full vocab
167161 /* .selected = */ -1 ,
168162 /* .sorted = */ false ,
169163 };
@@ -173,19 +167,17 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
173167 }
174168 }
175169 } else {
176- candidates.clear ();
177- candidates.shrink_to_fit ();
178-
179170 std::vector<std::pair<float , int32_t >> confidences;
180171 std::vector<llama_token> sampled_tokens (mask_positions.size ());
181172
182173 for (size_t i = 0 ; i < mask_positions.size (); i++) {
183174 int32_t pos = mask_positions[i];
184- float * pos_logits = logits + pos * n_vocab ;
175+ const float * pos_logits = get_logits_for_pos ( pos) ;
185176
186- candidates.clear ();
187177 for (int32_t token_id = 0 ; token_id < n_vocab; token_id++) {
188- candidates.emplace_back (llama_token_data{ token_id, pos_logits[token_id], 0 .0f });
178+ candidates[token_id].logit = pos_logits[token_id];
179+ candidates[token_id].p = 0 .0f ;
180+ candidates[token_id].id = token_id;
189181 }
190182
191183 llama_token_data_array cur_p = {
@@ -207,15 +199,14 @@ llama_token * diffusion_generate(llama_context * ctx, const llama_token * input_
207199 confidence += prob * logf (prob + epsilon);
208200 }
209201 } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
210- std::partial_sort (cur_p.data , cur_p.data + 2 , cur_p.data + cur_p.size ,
211- [](const llama_token_data & a, const llama_token_data & b) { return a.p > b.p ; });
212202 confidence = cur_p.data [0 ].p - cur_p.data [1 ].p ;
213203 } else {
214204 confidence = cur_p.data [cur_p.selected ].p ;
215205 }
216206
217207 sampled_tokens[i] = sampled_token;
218208 confidences.emplace_back (confidence, i);
209+
219210 }
220211
221212 int32_t num_transfer =
0 commit comments