@@ -191,10 +191,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
191191 const int64_t n_seq_tokens = ubatch->n_seq_tokens ;
192192 const int64_t n_seqs_unq = ubatch->n_seqs_unq ;
193193
194- if (cparams.embeddings && (
195- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
196- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
197- )) {
194+ if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_RANK) {
198195 GGML_ASSERT (cls);
199196 GGML_ASSERT (ggml_backend_buffer_is_host (cls->buffer ));
200197
@@ -211,15 +208,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
211208 }
212209 }
213210
214- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
211+ if (cparams.embeddings && (
212+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
213+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
214+ )) {
215215 GGML_ASSERT (cls);
216216 GGML_ASSERT (ggml_backend_buffer_is_host (cls->buffer ));
217217
218218 uint32_t * data = (uint32_t *) cls->data ;
219219 memset (cls->data , 0 , n_seqs_unq*ggml_element_size (cls));
220220
221- std::vector<int > last_pos (n_seqs_unq, -1 );
222- std::vector<int > last_row (n_seqs_unq, -1 );
221+ std::vector<int > target_pos (n_seqs_unq, -1 );
222+ std::vector<int > target_row (n_seqs_unq, -1 );
223223
224224 for (int i = 0 ; i < n_tokens; ++i) {
225225 const llama_pos pos = ubatch->pos [i];
@@ -228,16 +228,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
228228 const llama_seq_id seq_id = ubatch->seq_id [i][s];
229229 const int32_t seq_idx = ubatch->seq_idx [seq_id];
230230
231- if (pos >= last_pos[seq_idx]) {
232- last_pos[seq_idx] = pos;
233- last_row[seq_idx] = i;
231+ if (
232+ (target_pos[seq_idx] == -1 ) ||
233+ (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS && pos < target_pos[seq_idx]) ||
234+ (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST && pos >= target_pos[seq_idx])
235+ ) {
236+ target_pos[seq_idx] = pos;
237+ target_row[seq_idx] = i;
234238 }
235239 }
236240 }
237241
238242 for (int s = 0 ; s < n_seqs_unq; ++s) {
239- if (last_row [s] >= 0 ) {
240- data[s] = last_row [s];
243+ if (target_row [s] >= 0 ) {
244+ data[s] = target_row [s];
241245 }
242246 }
243247 }
0 commit comments