Skip to content

Commit 2cd8a90

Browse files
committed
context : make output functions members
ggml-ci
1 parent d1d8d53 commit 2cd8a90

File tree

2 files changed

+122
-124
lines changed

2 files changed

+122
-124
lines changed

src/llama-context.cpp

Lines changed: 114 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -9,121 +9,6 @@
99
#include <stdexcept>
1010
#include <cinttypes>
1111

12-
// llama output (TMP)
13-
14-
// Make sure enough space is available for outputs.
15-
// Returns max number of outputs for which space was reserved.
16-
static size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
17-
const auto & cparams = lctx.cparams;
18-
const auto & hparams = lctx.model.hparams;
19-
const auto & vocab = lctx.model.vocab;
20-
21-
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
22-
23-
const auto n_batch = cparams.n_batch;
24-
const auto n_vocab = vocab.n_tokens();
25-
const auto n_embd = hparams.n_embd;
26-
27-
// TODO: use a per-batch flag for logits presence instead
28-
const bool has_logits = !cparams.embeddings;
29-
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
30-
31-
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
32-
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
33-
34-
if (lctx.output_ids.empty()) {
35-
// init, never resized afterwards
36-
lctx.output_ids.resize(n_batch);
37-
}
38-
39-
const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
40-
const size_t new_size = (logits_size + embd_size) * sizeof(float);
41-
42-
// alloc only when more than the current capacity is required
43-
// TODO: also consider shrinking the buffer
44-
if (!lctx.buf_output || prev_size < new_size) {
45-
if (lctx.buf_output) {
46-
#ifndef NDEBUG
47-
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
48-
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
49-
#endif
50-
lctx.buf_output = nullptr;
51-
lctx.logits = nullptr;
52-
lctx.embd = nullptr;
53-
}
54-
55-
auto * buft = ggml_backend_cpu_buffer_type();
56-
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
57-
auto * output_dev = lctx.model.dev_output();
58-
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
59-
if (output_dev_host_buft) {
60-
buft = output_dev_host_buft;
61-
}
62-
lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
63-
if (lctx.buf_output == nullptr) {
64-
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
65-
return 0;
66-
}
67-
}
68-
69-
float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get());
70-
71-
lctx.logits = has_logits ? output_base : nullptr;
72-
lctx.embd = has_embd ? output_base + logits_size : nullptr;
73-
74-
lctx.output_size = n_outputs_max;
75-
lctx.logits_size = logits_size;
76-
lctx.embd_size = embd_size;
77-
78-
// set all ids as invalid (negative)
79-
std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
80-
81-
ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
82-
83-
lctx.n_outputs = 0;
84-
85-
return n_outputs_max;
86-
}
87-
88-
// make the outputs have the same order they had in the user-provided batch
89-
static void llama_output_reorder(struct llama_context & ctx) {
90-
std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
91-
if (!out_ids.empty()) {
92-
const uint32_t n_vocab = ctx.model.vocab.n_tokens();
93-
const uint32_t n_embd = ctx.model.hparams.n_embd;
94-
95-
const int32_t n_outputs = ctx.n_outputs;
96-
GGML_ASSERT((size_t) n_outputs == out_ids.size());
97-
98-
// TODO: is there something more efficient which also minimizes swaps?
99-
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
100-
for (int32_t i = 0; i < n_outputs - 1; ++i) {
101-
int32_t j_min = i;
102-
for (int32_t j = i + 1; j < n_outputs; ++j) {
103-
if (out_ids[j] < out_ids[j_min]) {
104-
j_min = j;
105-
}
106-
}
107-
if (j_min == i) { continue; }
108-
std::swap(out_ids[i], out_ids[j_min]);
109-
if (ctx.logits_size > 0) {
110-
for (uint32_t k = 0; k < n_vocab; k++) {
111-
std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]);
112-
}
113-
}
114-
if (ctx.embd_size > 0) {
115-
for (uint32_t k = 0; k < n_embd; k++) {
116-
std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]);
117-
}
118-
}
119-
}
120-
std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
121-
for (int32_t i = 0; i < n_outputs; ++i) {
122-
ctx.output_ids[out_ids[i]] = i;
123-
}
124-
out_ids.clear();
125-
}
126-
}
12712
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
12813
// TODO move to hparams if a T5 variant appears that uses a different value
12914
const int64_t max_distance = 128;
@@ -334,7 +219,7 @@ llama_context::llama_context(
334219
// graph outputs buffer
335220
{
336221
// resized during inference when a batch uses more outputs
337-
if (llama_output_reserve(*this, params.n_seq_max) < params.n_seq_max) {
222+
if (reserve_outputs(params.n_seq_max) < params.n_seq_max) {
338223
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
339224
throw std::runtime_error("failed to reserve initial output buffer");
340225
}
@@ -716,7 +601,7 @@ int llama_context::decode(llama_batch & inp_batch) {
716601

717602
// reserve output buffer
718603
// TODO: move to batch manager?
719-
if (llama_output_reserve(*this, bman->n_outputs_all) < (size_t) n_outputs_all) {
604+
if (reserve_outputs(bman->n_outputs_all) < (size_t) n_outputs_all) {
720605
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
721606
return -2;
722607
};
@@ -940,7 +825,7 @@ int llama_context::encode(llama_batch & inp_batch) {
940825
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
941826

942827
// reserve output buffer
943-
if (llama_output_reserve(*this, n_tokens) < n_tokens) {
828+
if (reserve_outputs(n_tokens) < n_tokens) {
944829
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
945830
return -2;
946831
};
@@ -1555,6 +1440,113 @@ void llama_context::set_inputs(const llama_ubatch & ubatch) {
15551440
}
15561441
}
15571442

1443+
void llama_context::reorder_outputs() {
1444+
std::vector<size_t> & out_ids = sbatch.out_ids;
1445+
if (!out_ids.empty()) {
1446+
const uint32_t n_vocab = model.vocab.n_tokens();
1447+
const uint32_t n_embd = model.hparams.n_embd;
1448+
1449+
GGML_ASSERT((size_t) n_outputs == out_ids.size());
1450+
1451+
// TODO: is there something more efficient which also minimizes swaps?
1452+
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1453+
for (int32_t i = 0; i < n_outputs - 1; ++i) {
1454+
int32_t j_min = i;
1455+
for (int32_t j = i + 1; j < n_outputs; ++j) {
1456+
if (out_ids[j] < out_ids[j_min]) {
1457+
j_min = j;
1458+
}
1459+
}
1460+
if (j_min == i) { continue; }
1461+
std::swap(out_ids[i], out_ids[j_min]);
1462+
if (logits_size > 0) {
1463+
for (uint32_t k = 0; k < n_vocab; k++) {
1464+
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
1465+
}
1466+
}
1467+
if (embd_size > 0) {
1468+
for (uint32_t k = 0; k < n_embd; k++) {
1469+
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
1470+
}
1471+
}
1472+
}
1473+
std::fill(output_ids.begin(), output_ids.end(), -1);
1474+
for (int32_t i = 0; i < n_outputs; ++i) {
1475+
output_ids[out_ids[i]] = i;
1476+
}
1477+
out_ids.clear();
1478+
}
1479+
}
1480+
1481+
size_t llama_context::reserve_outputs(size_t n_outputs) {
1482+
const auto & hparams = model.hparams;
1483+
const auto & vocab = model.vocab;
1484+
1485+
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
1486+
1487+
const auto n_batch = cparams.n_batch;
1488+
const auto n_vocab = vocab.n_tokens();
1489+
const auto n_embd = hparams.n_embd;
1490+
1491+
// TODO: use a per-batch flag for logits presence instead
1492+
const bool has_logits = !cparams.embeddings;
1493+
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1494+
1495+
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
1496+
embd_size = has_embd ? n_embd*n_outputs_max : 0;
1497+
1498+
if (output_ids.empty()) {
1499+
// init, never resized afterwards
1500+
output_ids.resize(n_batch);
1501+
}
1502+
1503+
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
1504+
const size_t new_size = (logits_size + embd_size) * sizeof(float);
1505+
1506+
// alloc only when more than the current capacity is required
1507+
// TODO: also consider shrinking the buffer
1508+
if (!buf_output || prev_size < new_size) {
1509+
if (buf_output) {
1510+
#ifndef NDEBUG
1511+
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
1512+
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1513+
#endif
1514+
buf_output = nullptr;
1515+
logits = nullptr;
1516+
embd = nullptr;
1517+
}
1518+
1519+
auto * buft = ggml_backend_cpu_buffer_type();
1520+
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
1521+
auto * output_dev = model.dev_output();
1522+
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
1523+
if (output_dev_host_buft) {
1524+
buft = output_dev_host_buft;
1525+
}
1526+
buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
1527+
if (buf_output == nullptr) {
1528+
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
1529+
return 0;
1530+
}
1531+
}
1532+
1533+
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
1534+
1535+
logits = has_logits ? output_base : nullptr;
1536+
embd = has_embd ? output_base + logits_size : nullptr;
1537+
1538+
output_size = n_outputs_max;
1539+
1540+
// set all ids as invalid (negative)
1541+
std::fill(output_ids.begin(), output_ids.end(), -1);
1542+
1543+
ggml_backend_buffer_clear(buf_output.get(), 0);
1544+
1545+
n_outputs = 0;
1546+
1547+
return n_outputs_max;
1548+
}
1549+
15581550
// do mat_mul, while optionally apply lora
15591551
ggml_tensor * llama_context::build_lora_mm(
15601552
ggml_context * ctx0,
@@ -2827,8 +2819,7 @@ float * llama_get_logits(struct llama_context * ctx) {
28272819
llama_synchronize(ctx);
28282820

28292821
// reorder logits for backward compatibility
2830-
// TODO: maybe deprecate this
2831-
llama_output_reorder(*ctx);
2822+
ctx->reorder_outputs();
28322823

28332824
return ctx->logits;
28342825
}
@@ -2877,8 +2868,7 @@ float * llama_get_embeddings(struct llama_context * ctx) {
28772868
llama_synchronize(ctx);
28782869

28792870
// reorder embeddings for backward compatibility
2880-
// TODO: maybe deprecate this
2881-
llama_output_reorder(*ctx);
2871+
ctx->reorder_outputs();
28822872

28832873
return ctx->embd;
28842874
}
@@ -3187,7 +3177,7 @@ struct llama_data_write {
31873177
//}
31883178

31893179
void write_output_ids(struct llama_context * ctx) {
3190-
llama_output_reorder(*ctx);
3180+
ctx->reorder_outputs();
31913181

31923182
const uint32_t n_outputs = ctx->n_outputs;
31933183

@@ -3281,7 +3271,7 @@ struct llama_data_read {
32813271
uint32_t n_outputs;
32823272
read_to(&n_outputs, sizeof(n_outputs));
32833273

3284-
if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
3274+
if (n_outputs > ctx->reserve_outputs(n_outputs)) {
32853275
throw std::runtime_error("could not reserve outputs");
32863276
}
32873277

src/llama-context.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ struct llama_context {
114114

115115
void set_inputs(const llama_ubatch & ubatch);
116116

117+
// make the outputs have the same order they had in the user-provided batch
118+
// TODO: maybe deprecate this
119+
void reorder_outputs();
120+
121+
// Make sure enough space is available for outputs.
122+
// Returns max number of outputs for which space was reserved.
123+
size_t reserve_outputs(size_t n_outputs);
124+
117125
ggml_tensor * build_lora_mm(
118126
ggml_context * ctx0,
119127
ggml_tensor * w,

0 commit comments

Comments
 (0)