Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
GGML_ASSERT(false && "unknown mirostat version");
}
} else {
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
if (params.n_probs > 0) {
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
}

Expand Down
5 changes: 4 additions & 1 deletion examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ struct seq_draft {
int main(int argc, char ** argv) {
gpt_params params;

// needed to get candidate probs even for temp <= 0.0
params.sparams.n_probs = 128;

if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
return 1;
}
Expand All @@ -49,7 +52,7 @@ int main(int argc, char ** argv) {
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split;

std::default_random_engine rng(params.sparams.seed);
std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
std::uniform_real_distribution<> u_dist;

// init llama.cpp
Expand Down
1 change: 1 addition & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,7 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);

/// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);

/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
Expand Down
7 changes: 4 additions & 3 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
#include "llama-vocab.h"
#include "llama-grammar.h"

#include <cassert>
#include <algorithm>
#include <cstring>
#include <ctime>
#include <cassert>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <numeric>
#include <random>
#include <unordered_map>
Expand Down
42 changes: 41 additions & 1 deletion tests/test-sampling.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "ggml.h"
#include "llama.h"
#include "llama-sampling.h"

#ifdef NDEBUG
#undef NDEBUG
Expand Down Expand Up @@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
}

// Times repeated llama_sampler_apply() calls and prints the average cost per iteration.
//
//   cnstr      - sampler under test; it is released with llama_sampler_free() before returning
//   cnstr_name - label printed in front of the measurement
//   data       - candidate tokens; copied into a scratch buffer before every apply
//   n_iter     - number of timed iterations (also the divisor of the reported average)
static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
    std::vector<llama_token_data> scratch(data.size());

    // one pass: refill the scratch buffer from `data`, apply the sampler, then reset it
    const auto run_once = [&]() {
        std::copy(data.begin(), data.end(), scratch.begin());
        llama_token_data_array candidates = { scratch.data(), scratch.size(), -1, false };
        llama_sampler_apply(cnstr, &candidates);
        llama_sampler_reset(cnstr);
    };

    // single untimed pass before the clock starts
    run_once();

    // timed region: note that the refill copy is part of what gets measured
    const int64_t t_begin = ggml_time_us();
    for (int it = 0; it < n_iter; ++it) {
        run_once();
    }
    const int64_t t_finish = ggml_time_us();

    llama_sampler_free(cnstr);

    printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_finish - t_begin) / (float)n_iter);
}

// Convenience wrapper around bench(): the sampler expression is passed through
// as the sampler argument and also stringified to serve as the printed label.
#define BENCH(sampler_expr, tokens, iters) \
    bench((sampler_expr), #sampler_expr, (tokens), (iters))

static void test_perf() {
const int n_vocab = 1 << 17;

std::vector<llama_token_data> data;

data.reserve(n_vocab);
for (int i = 0; i < n_vocab; i++) {
const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);
data.emplace_back(llama_token_data{i, logit, 0.0f});
}

BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
BENCH(llama_sampler_init_softmax (), data, 32);
}

int main(void) {
ggml_time_init();

Expand Down Expand Up @@ -316,5 +354,7 @@ int main(void) {

printf("OK\n");

test_perf();

return 0;
}
Loading