76 changes: 38 additions & 38 deletions examples/embedding/embedding.cpp
@@ -260,27 +260,27 @@ int main(int argc, char ** argv) {
     batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
 
     if (params.embd_out.empty()) {
-        LOG("\n");
+        printf("\n");
 
         if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
             for (int j = 0; j < n_embd_count; j++) {
-                LOG("embedding %d: ", j);
+                printf("embedding %d: ", j);
                 for (int i = 0; i < std::min(3, n_embd); i++) {
                     if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        printf("%6.0f ", emb[j * n_embd + i]);
                     } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        printf("%9.6f ", emb[j * n_embd + i]);
                     }
                 }
-                LOG(" ... ");
+                printf(" ... ");
                 for (int i = n_embd - 3; i < n_embd; i++) {
                     if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        printf("%6.0f ", emb[j * n_embd + i]);
                     } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        printf("%9.6f ", emb[j * n_embd + i]);
                     }
                 }
-                LOG("\n");
+                printf("\n");
             }
         } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
             const uint32_t n_cls_out = llama_model_n_cls_out(model);
@@ -296,41 +296,41 @@ int main(int argc, char ** argv) {
                 for (uint32_t i = 0; i < n_cls_out; i++) {
                     // NOTE: if you change this log - update the tests in ci/run.sh
                     if (n_cls_out == 1) {
-                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                        printf("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
                     } else {
-                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
+                        printf("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
                     }
                 }
            }
        } else {
            // print the first part of the embeddings or for a single prompt, the full embedding
            for (int j = 0; j < n_prompts; j++) {
-                LOG("embedding %d: ", j);
+                printf("embedding %d: ", j);
                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
                    if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        printf("%6.0f ", emb[j * n_embd + i]);
                    } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        printf("%9.6f ", emb[j * n_embd + i]);
                    }
                }
-                LOG("\n");
+                printf("\n");
            }
 
            // print cosine similarity matrix
            if (n_prompts > 1) {
-                LOG("\n");
-                LOG("cosine similarity matrix:\n\n");
+                printf("\n");
+                printf("cosine similarity matrix:\n\n");
                for (int i = 0; i < n_prompts; i++) {
-                    LOG("%6.6s ", prompts[i].c_str());
+                    printf("%6.6s ", prompts[i].c_str());
                }
-                LOG("\n");
+                printf("\n");
                for (int i = 0; i < n_prompts; i++) {
                    for (int j = 0; j < n_prompts; j++) {
                        float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
-                        LOG("%6.2f ", sim);
+                        printf("%6.2f ", sim);
                    }
-                    LOG("%1.10s", prompts[i].c_str());
-                    LOG("\n");
+                    printf("%1.10s", prompts[i].c_str());
+                    printf("\n");
                }
            }
        }
@@ -339,42 +339,42 @@ int main(int argc, char ** argv) {
     if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
         const bool notArray = params.embd_out != "array";
 
-        LOG(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
+        printf(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
         for (int j = 0;;) { // at least one iteration (one prompt)
-            if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
-            LOG("[");
+            if (notArray) printf(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
+            printf("[");
             for (int i = 0;;) { // at least one iteration (n_embd > 0)
-                LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
+                printf(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
                 i++;
-                if (i < n_embd) LOG(","); else break;
+                if (i < n_embd) printf(","); else break;
             }
-            LOG(notArray ? "]\n }" : "]");
+            printf(notArray ? "]\n }" : "]");
             j++;
-            if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break;
+            if (j < n_embd_count) printf(notArray ? ",\n" : ","); else break;
         }
-        LOG(notArray ? "\n ]" : "]\n");
+        printf(notArray ? "\n ]" : "]\n");
 
         if (params.embd_out == "json+" && n_prompts > 1) {
-            LOG(",\n \"cosineSimilarity\": [\n");
+            printf(",\n \"cosineSimilarity\": [\n");
             for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
-                LOG(" [");
+                printf(" [");
                 for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
                     float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
-                    LOG("%6.2f", sim);
+                    printf("%6.2f", sim);
                     j++;
-                    if (j < n_embd_count) LOG(", "); else break;
+                    if (j < n_embd_count) printf(", "); else break;
                 }
-                LOG(" ]");
+                printf(" ]");
                 i++;
-                if (i < n_embd_count) LOG(",\n"); else break;
+                if (i < n_embd_count) printf(",\n"); else break;
             }
-            LOG("\n ]");
+            printf("\n ]");
         }
 
-        if (notArray) LOG("\n}\n");
+        if (notArray) printf("\n}\n");
     }
 
-    LOG("\n");
+    printf("\n");
     llama_perf_context_print(ctx);
 
     // clean up
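Why swap LOG for printf wholesale? The diff itself does not say, so the following rationale is an assumption: LOG goes through llama.cpp's common logging layer, whose output can be decorated, redirected, or silenced together with the rest of the diagnostics, whereas the embeddings, rerank scores, and JSON emitted here are the program's actual output and should reach stdout verbatim. A minimal sketch of that stdout/stderr separation, using only standard C I/O (the array and messages are invented for illustration):

#include <cstdio>

int main() {
    // stand-in for a computed embedding (hypothetical values)
    const float emb[4] = {0.123456f, -0.654321f, 0.111111f, -0.222222f};

    // diagnostics go to stderr, where a logging layer would typically write
    fprintf(stderr, "computing 1 embedding of size 4\n");

    // raw results go straight to stdout so they can be piped or redirected
    printf("embedding 0: ");
    for (float v : emb) {
        printf("%9.6f ", v);
    }
    printf("\n");

    return 0;
}

Redirecting stdout (for example, ./sketch > embeddings.txt) captures only the data line, while the stderr diagnostic still appears on the terminal.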