Skip to content

Commit c7968a4

Browse files
committed
embeddings: fix extraction of CLS pooling results
1 parent 8ad7b3e commit c7968a4

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

src/llama-graph.cpp

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -191,10 +191,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
191191
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
192192
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
193193

194-
if (cparams.embeddings && (
195-
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
196-
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
197-
)) {
194+
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_RANK) {
198195
GGML_ASSERT(cls);
199196
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
200197

@@ -211,15 +208,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
211208
}
212209
}
213210

214-
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
211+
if (cparams.embeddings && (
212+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
213+
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
214+
)) {
215215
GGML_ASSERT(cls);
216216
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
217217

218218
uint32_t * data = (uint32_t *) cls->data;
219219
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
220220

221-
std::vector<int> last_pos(n_seqs_unq, -1);
222-
std::vector<int> last_row(n_seqs_unq, -1);
221+
std::vector<int> target_pos(n_seqs_unq, -1);
222+
std::vector<int> target_row(n_seqs_unq, -1);
223223

224224
for (int i = 0; i < n_tokens; ++i) {
225225
const llama_pos pos = ubatch->pos[i];
@@ -228,16 +228,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
228228
const llama_seq_id seq_id = ubatch->seq_id[i][s];
229229
const int32_t seq_idx = ubatch->seq_idx[seq_id];
230230

231-
if (pos >= last_pos[seq_idx]) {
232-
last_pos[seq_idx] = pos;
233-
last_row[seq_idx] = i;
231+
if (
232+
(target_pos[seq_idx] == -1) ||
233+
(cparams.pooling_type == LLAMA_POOLING_TYPE_CLS && pos < target_pos[seq_idx]) ||
234+
(cparams.pooling_type == LLAMA_POOLING_TYPE_LAST && pos >= target_pos[seq_idx])
235+
) {
236+
target_pos[seq_idx] = pos;
237+
target_row[seq_idx] = i;
234238
}
235239
}
236240
}
237241

238242
for (int s = 0; s < n_seqs_unq; ++s) {
239-
if (last_row[s] >= 0) {
240-
data[s] = last_row[s];
243+
if (target_row[s] >= 0) {
244+
data[s] = target_row[s];
241245
}
242246
}
243247
}

0 commit comments

Comments
 (0)