
Commit e665b57

Merge branch 'master' into gg/llama-kv-cache
ggml-ci
2 parents a0c500b + df984e0

6 files changed: +106 additions, -96 deletions

.devops/vulkan.Dockerfile

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-ARG UBUNTU_VERSION=24.04
+ARG UBUNTU_VERSION=22.04
 
 FROM ubuntu:$UBUNTU_VERSION AS build
 
@@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget
 
 # Install Vulkan SDK and cURL
 RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
-    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
+    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \
     apt update -y && \
     apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
 

.github/workflows/docker.yml

Lines changed: 3 additions & 1 deletion

@@ -32,10 +32,12 @@ jobs:
     env:
       COMMIT_SHA: ${{ github.sha }}
     strategy:
+      fail-fast: false
       matrix:
         config:
           # Multi-stage build
-          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: false}
+          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
+          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 6 additions & 2 deletions

@@ -64,7 +64,9 @@
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
+    }
 
+    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
@@ -99,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
-        [ctx->mtl_device release];
-        ctx->mtl_device = nil;
+        if (ctx->mtl_device) {
+            [ctx->mtl_device release];
+            ctx->mtl_device = nil;
+        }
     }
 }
 
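
Both hunks harden the same device-lifetime path: MTLCreateSystemDefaultDevice can return nil (for example on a machine without a usable GPU), so feature probing now only runs when a device actually exists, and the release on the last reference is skipped when creation failed. Below is a minimal stand-alone C++ sketch of that acquire/release pattern, using hypothetical names rather than the real ggml-metal symbols:

#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for a GPU device handle shared via ref-counting;
// none of these names come from ggml, they only illustrate the pattern.
struct device_context {
    void * dev           = nullptr; // may stay null if creation fails
    int    dev_ref_count = 0;
};

// Simulate a device factory that can fail (e.g. a headless machine).
static void * create_default_device(bool available) {
    return available ? std::malloc(1) : nullptr;
}

static void device_acquire(device_context * ctx, bool available) {
    ctx->dev_ref_count++;

    if (ctx->dev == nullptr) {
        ctx->dev = create_default_device(available);
    }

    // probe capabilities only when a device actually exists
    if (ctx->dev) {
        std::printf("device ready, probing features\n");
    }
}

static void device_release(device_context * ctx) {
    ctx->dev_ref_count--;

    if (ctx->dev_ref_count == 0) {
        // guard the release: creation may have failed earlier
        if (ctx->dev) {
            std::free(ctx->dev);
            ctx->dev = nullptr;
        }
    }
}

int main() {
    device_context ctx;
    device_acquire(&ctx, /*available =*/ false); // no device: nothing to probe
    device_release(&ctx);                        // and nothing to free
}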

src/llama-context.cpp

Lines changed: 68 additions & 4 deletions

@@ -7,6 +7,7 @@
 #include <cmath>
 #include <cstring>
 #include <stdexcept>
+#include <cinttypes>
 
 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
@@ -336,12 +337,55 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
 }
 
 struct llama_batch_manager : public llama_batch_manager_i {
-    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+    llama_batch_manager(llama_context & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+        const auto & model   = lctx.model;
+        const auto & cparams = lctx.cparams;
         const auto & hparams = lctx.model.hparams;
-        const auto & n_embd = hparams.n_embd;
 
         const auto & kv_self = lctx.kv_self;
 
+        const int64_t n_tokens_all = batch.n_tokens;
+        const int64_t n_embd       = hparams.n_embd;
+
+        GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+        if (batch.token) {
+            for (int64_t i = 0; i < n_tokens_all; ++i) {
+                if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                    LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                    throw std::runtime_error("invalid token");
+                }
+            }
+        }
+
+        GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+        GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+        if (lctx.t_compute_start_us == 0) {
+            lctx.t_compute_start_us = ggml_time_us();
+        }
+        lctx.n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        lctx.embd_seq.clear();
+
+        // count outputs
+        if (batch.logits && !embd_pooled) {
+            for (uint32_t i = 0; i < n_tokens_all; ++i) {
+                n_outputs_all += batch.logits[i] != 0;
+            }
+        } else if (lctx.logits_all || embd_pooled) {
+            n_outputs_all = n_tokens_all;
+        } else {
+            // keep last output only
+            n_outputs_all = 1;
+        }
+
+        const bool logits_all = n_outputs_all == n_tokens_all;
+
         lctx.sbatch.from_batch(batch, n_embd,
             /* simple_split */ !kv_self.recurrent,
             /* logits_all   */ logits_all);
@@ -379,9 +423,29 @@ struct llama_batch_manager : public llama_batch_manager_i {
     virtual bool prepare() override {
         const auto & cparams = lctx.cparams;
         const auto & hparams = lctx.model.hparams;
+        const auto & batch   = lctx.sbatch.batch;
+
+        const auto n_tokens_all = batch->n_tokens;
 
         auto & kv_self = lctx.kv_self;
 
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            lctx.n_outputs = n_outputs_new;
+        }
+
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
             lctx.kv_self_update();
@@ -459,8 +523,8 @@ struct llama_batch_manager : public llama_batch_manager_i {
     llama_kv_slot_restorer kv_slot_restorer;
 };
 
-std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
-    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
+std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch) {
+    return std::make_unique<llama_batch_manager>(*this, batch);
 }
 
 enum ggml_status llama_context::compute_graph(
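
The output-counting logic that moved into the llama_batch_manager constructor covers three cases: explicit per-token logits flags, "output everything" (logits_all or pooled embeddings), or the default of keeping only the last token's output. A small self-contained sketch of the same rule, using a hypothetical helper name that is not part of the llama.cpp API:

#include <cstdint>
#include <cstdio>

// Hypothetical helper: mirrors the three output-counting cases above.
// flags may be null; n_tokens is the total number of tokens in the batch.
static int64_t count_outputs(const int8_t * flags, int64_t n_tokens, bool logits_all, bool embd_pooled) {
    if (flags && !embd_pooled) {
        int64_t n = 0;
        for (int64_t i = 0; i < n_tokens; ++i) {
            n += flags[i] != 0;          // explicit per-token output flags
        }
        return n;
    }
    if (logits_all || embd_pooled) {
        return n_tokens;                 // every token produces an output
    }
    return 1;                            // default: keep last output only
}

int main() {
    const int8_t flags[4] = { 0, 0, 0, 1 };
    std::printf("%lld\n", (long long) count_outputs(flags, 4, false, false));  // prints 1
    std::printf("%lld\n", (long long) count_outputs(nullptr, 4, true, false)); // prints 4
}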

src/llama-context.h

Lines changed: 4 additions & 1 deletion

@@ -28,6 +28,9 @@ struct llama_batch_manager_i {
     virtual void restore() = 0;
     virtual void update() = 0;
     virtual void finalize() = 0;
+
+    // TODO: might be temporary
+    int64_t n_outputs_all = 0;
 };
 
 // TODO: make implementation details private
@@ -98,7 +101,7 @@ struct llama_context {
     void * abort_callback_data = nullptr;
 
     // TODO: do not pass logits_all explicitly
-    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch, bool logits_all);
+    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch);
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
     enum ggml_status compute_graph(

src/llama.cpp

Lines changed: 23 additions & 86 deletions

@@ -23,6 +23,7 @@
 #include <cstring>
 #include <ctime>
 #include <functional>
+#include <cinttypes>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -7751,7 +7752,7 @@ static struct ggml_cgraph * llama_build_graph(
 // (for non-recurrent models) or cleaned (for recurrent models)
 //
 // - lctx:      llama context
-// - batch:     batch to evaluate
+// - inp_batch: batch to evaluate
 //
 // return 0 on success
 // return positive int on warning
@@ -7774,98 +7775,34 @@ static int llama_decode_impl(
 
     const llama_batch & batch = batch_allocr.batch;
 
-    const uint32_t n_tokens_all = batch.n_tokens;
-
     const auto & model   = lctx.model;
     const auto & vocab   = model.vocab;
-    const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-    lctx.n_queued_tokens += n_tokens_all;
-
+    const int32_t n_vocab = vocab.n_tokens();
     const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+    // TODO: try catch
+    auto bman = lctx.prepare_batch(batch);
 
-    lctx.embd_seq.clear();
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
-        }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs = 1;
-    }
+    const auto n_outputs_all = bman->n_outputs_all;
 
     // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+    // TODO: move to batch manager?
+    if (llama_output_reserve(lctx, bman->n_outputs_all) < (size_t) n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
         return -2;
     };
 
-    const bool logits_all = n_outputs == n_tokens_all;
-
-    //auto & kv_self = lctx.kv_self;
-    //llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    //lctx.sbatch.from_batch(batch, n_embd,
-    //    /* simple_split */ !kv_self.recurrent,
-    //    /* logits_all   */ logits_all);
-
-    auto batch_manager = lctx.prepare_batch(batch, logits_all);
+    int64_t n_outputs_prev = 0;
 
     while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = batch_manager->next();
-
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
-            }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
-        }
+        llama_ubatch ubatch = bman->next();
 
-        if (!batch_manager->prepare()) {
+        if (!bman->prepare()) {
             LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-            batch_manager->restore();
+            bman->restore();
             return -3;
         }
 
@@ -7927,9 +7864,9 @@ static int llama_decode_impl(
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
 
-        const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
+        const auto compute_status = lctx.compute_graph(gf, ubatch.n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            batch_manager->restore();
+            bman->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -7941,7 +7878,7 @@ static int llama_decode_impl(
             }
         }
 
-        batch_manager->update();
+        bman->update();
 
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
@@ -7958,7 +7895,7 @@ static int llama_decode_impl(
                 const int32_t n_outputs_new = lctx.n_outputs;
 
                 if (n_outputs_new) {
-                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all);
                     GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
                     ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
                 }
@@ -7978,7 +7915,7 @@ static int llama_decode_impl(
                 const int32_t n_outputs_new = lctx.n_outputs;
 
                 if (n_outputs_new) {
-                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all);
                     GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
                     ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                 }
@@ -8027,9 +7964,9 @@ static int llama_decode_impl(
     {
         bool sorted_output = true;
 
-        GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
+        GGML_ASSERT(lctx.sbatch.out_ids.size() == (size_t) n_outputs_all);
 
-        for (size_t i = 0; i < n_outputs; ++i) {
+        for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
             size_t out_id = lctx.sbatch.out_ids[i];
             lctx.output_ids[out_id] = i;
             if (out_id != i) {
@@ -8043,12 +7980,12 @@ static int llama_decode_impl(
     }
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    lctx.n_outputs = n_outputs;
+    lctx.n_outputs = n_outputs_all;
 
     // wait for the computation to finish (automatically done when obtaining the model output)
    //llama_synchronize(&lctx);
 
-    batch_manager->finalize();
+    bman->finalize();
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
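
The net effect of this diff is that llama_decode_impl delegates batch validation, output counting, and KV-cache bookkeeping to the batch manager and keeps only the ubatch loop. A toy, self-contained C++ sketch of that call order, assuming a simplified interface that only mirrors the shape of llama_batch_manager_i (none of these definitions are the real ones):

#include <cstdint>
#include <cstdio>
#include <memory>

// Toy interface mirroring the shape of llama_batch_manager_i.
struct batch_manager_i {
    virtual ~batch_manager_i() = default;
    virtual bool prepare()  = 0;  // find KV-cache slots, set n_outputs for the ubatch
    virtual void restore()  = 0;  // roll back KV-cache changes on failure
    virtual void update()   = 0;  // advance cache state after a successful compute
    virtual void finalize() = 0;  // per-batch cleanup

    int64_t n_outputs_all = 0;    // decided once, when the batch is prepared
};

struct toy_batch_manager : batch_manager_i {
    int remaining;

    explicit toy_batch_manager(int n_ubatches) : remaining(n_ubatches) {
        n_outputs_all = n_ubatches;
    }

    bool next() { return remaining-- > 0; }       // stands in for sbatch splitting

    bool prepare()  override { return true; }
    void restore()  override { }
    void update()   override { std::printf("ubatch done\n"); }
    void finalize() override { std::printf("batch done, %lld outputs\n", (long long) n_outputs_all); }
};

int main() {
    auto bman = std::make_unique<toy_batch_manager>(3);

    while (bman->next()) {           // split the batch into ubatches
        if (!bman->prepare()) {      // on failure: restore the cache and bail out
            bman->restore();
            return 1;
        }
        // ... build and compute the graph for this ubatch ...
        bman->update();
    }

    bman->finalize();
}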
