Skip to content

Commit 4f80618

Browse files
committed
sampling : add adaptive temperature sampler
1 parent 6687503 commit 4f80618

File tree

7 files changed

+93
-11
lines changed

7 files changed

+93
-11
lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
10721072
params.sparams.dynatemp_exponent = std::stof(value);
10731073
}
10741074
).set_sparam());
1075+
add_opt(common_arg(
1076+
{"--temp-adaptive"},
1077+
"ignore arguments for temp and dynatemp, and automatically set temperature based on entropy",
1078+
[](common_params & params) {
1079+
params.sparams.temp_adaptive = true;
1080+
}
1081+
).set_sparam());
10751082
add_opt(common_arg(
10761083
{"--mirostat"}, "N",
10771084
string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ struct common_sampler_params {
132132
bool penalize_nl = false; // consider newlines as a repeatable token
133133
bool ignore_eos = false;
134134
bool no_perf = false; // disable performance metrics
135+
bool temp_adaptive = false; // enables automatic adaptive setting of temperature
135136

136137
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
137138

common/sampling.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,11 @@ std::string common_sampler_params::print() const {
131131
snprintf(result, sizeof(result),
132132
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
133133
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
134-
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
134+
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f, temp_adaptive = %d\n"
135135
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
136136
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
137137
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
138-
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
138+
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, temp_adaptive,
139139
mirostat, mirostat_eta, mirostat_tau);
140140

141141
return std::string(result);
@@ -188,28 +188,34 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
188188
}
189189
break;
190190
case COMMON_SAMPLER_TYPE_TOP_K:
191-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
191+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
192192
break;
193193
case COMMON_SAMPLER_TYPE_TOP_P:
194-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
194+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
195195
break;
196196
case COMMON_SAMPLER_TYPE_MIN_P:
197-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
197+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
198198
break;
199199
case COMMON_SAMPLER_TYPE_XTC:
200-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
200+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
201201
break;
202202
case COMMON_SAMPLER_TYPE_TFS_Z:
203-
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
203+
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free (params.tfs_z, params.min_keep));
204204
break;
205205
case COMMON_SAMPLER_TYPE_TYPICAL_P:
206-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
206+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
207207
break;
208208
case COMMON_SAMPLER_TYPE_TEMPERATURE:
209-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
209+
{
210+
if (!params.temp_adaptive) {
211+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
212+
} else {
213+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_adaptive());
214+
}
215+
}
210216
break;
211217
case COMMON_SAMPLER_TYPE_INFILL:
212-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
218+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
213219
break;
214220
default:
215221
GGML_ASSERT(false && "unknown sampler type");

examples/server/server.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,7 @@ struct server_context {
812812
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
813813
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
814814
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
815+
slot.sparams.temp_adaptive = json_value(data, "temp_adaptive", default_sparams.temp_adaptive);
815816
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
816817
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
817818
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
@@ -1142,6 +1143,7 @@ struct server_context {
11421143
{"seed", slot.sparams.seed},
11431144
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
11441145
{"temperature", slot.sparams.temp},
1146+
{"temp_adaptive", slot.sparams.temp_adaptive},
11451147
{"dynatemp_range", slot.sparams.dynatemp_range},
11461148
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
11471149
{"top_k", slot.sparams.top_k},

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,9 @@ extern "C" {
10991099
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
11001100
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
11011101

1102+
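/// @details Adaptive temperature implementation described in the paper https://arxiv.org/abs/2410.01104. Takes no parameters: the temperature is derived from the entropy of the candidate distribution and is never raised above 1.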
1103+
LLAMA_API struct llama_sampler * llama_sampler_init_temp_adaptive (void);
1104+
11021105
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
11031106
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
11041107

src/llama-sampling.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,55 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
10821082
};
10831083
}
10841084

1085+
// temp-adaptive
1086+
1087+
// Name callback of the adaptive-temperature sampler.
// The sampler argument is unused; the returned string is a constant literal.
static const char * llama_sampler_temp_adaptive_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "temp-adaptive";
    return sampler_name;
}
1090+
1091+
static void llama_sampler_temp_adaptive_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
1092+
llama_sampler_softmax_impl(cur_p);
1093+
1094+
// calculate entropy
1095+
float entropy = 0.0f;
1096+
for (size_t i = 0; i < cur_p->size; ++i) {
1097+
entropy += -cur_p->data[i].p * logf(cur_p->data[i].p + 1e-9);
1098+
}
1099+
1100+
// calculate beta
1101+
float beta = 0.0f;
1102+
if (entropy > 0.5) { // don't overcorrect low-entropy heads
1103+
beta = -0.037 * powf(entropy, 4)
1104+
+ 0.481 * powf(entropy, 3)
1105+
+ -2.3 * powf(entropy, 2)
1106+
+ 4.917 * entropy
1107+
+ -1.791;
1108+
// never increase entropy
1109+
beta = (beta < 1.0) ? 1.0 : beta;
1110+
} else {
1111+
beta = 1.0;
1112+
}
1113+
1114+
// beta = 1 / temp
1115+
llama_sampler_temp_impl(cur_p, 1.0f / beta);
1116+
}
1117+
1118+
// Callback table of the adaptive-temperature sampler.
// Only the name and apply hooks are provided; all remaining slots are null.
// NOTE(review): presumably fine because the sampler carries no per-instance state
// (its init function passes a null ctx) - confirm callers tolerate null hooks.
static struct llama_sampler_i llama_sampler_temp_adaptive_i = {
    llama_sampler_temp_adaptive_name,  // .name
    nullptr,                           // .accept
    llama_sampler_temp_adaptive_apply, // .apply
    nullptr,                           // .reset
    nullptr,                           // .clone
    nullptr,                           // .free
};
1126+
1127+
struct llama_sampler * llama_sampler_init_temp_adaptive() {
1128+
return new llama_sampler {
1129+
/* .iface = */ &llama_sampler_temp_adaptive_i,
1130+
/* .ctx = */ nullptr,
1131+
};
1132+
}
1133+
10851134
// xtc
10861135

10871136
struct llama_sampler_xtc {

tests/test-sampling.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,17 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
7272
tester.check();
7373
}
7474

75+
// Runs the adaptive-temperature sampler followed by the dist sampler (seed 0)
// over `probs` and checks the outcome against `probs_expected`.
static void test_temp_adaptive(const std::vector<float> & probs, const std::vector<float> & probs_expected) {
    sampler_tester tester(probs, probs_expected);

    // state before sampling
    DUMP(&tester.cur_p);

    auto * smpl_adaptive = llama_sampler_init_temp_adaptive();
    tester.apply(smpl_adaptive);

    auto * smpl_dist = llama_sampler_init_dist(0);
    tester.apply(smpl_dist);

    // state after sampling
    DUMP(&tester.cur_p);

    tester.check();
}
85+
7586
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
7687
sampler_tester tester(probs, probs_expected);
7788

@@ -311,7 +322,10 @@ int main(void) {
311322
ggml_time_init();
312323

313324
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
314-
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
325+
test_temp({0.4f, 0.3f, 0.2f, 0.1f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
326+
327+
test_temp_adaptive({0.1f, 0.2f, 0.3f, 0.4f}, {0.488836, 0.304651, 0.156445, 0.050068});
328+
test_temp_adaptive({0.7f, 0.1f, 0.1f, 0.1f}, {0.764643, 0.078452, 0.078452, 0.078452});
315329

316330
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
317331
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);

0 commit comments

Comments
 (0)