Skip to content

Commit 87384fb

Browse files
committed
K-Shift commit
1 parent ff252ea commit 87384fb

File tree

10 files changed

+114
-11
lines changed

10 files changed

+114
-11
lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
922922
params.sparams.temp = std::max(params.sparams.temp, 0.0f);
923923
}
924924
).set_sparam());
925+
add_opt(common_arg(
926+
{"--k-shift"}, "N",
927+
string_format("k-shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift),
928+
[](common_params & params, int value) {
929+
params.sparams.k_shift = value;
930+
}
931+
).set_sparam());
925932
add_opt(common_arg(
926933
{"--top-k"}, "N",
927934
string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2092,6 +2092,7 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
20922092

20932093
fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
20942094
fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
2095+
fprintf(stream, "k_shift: %d # default: 0\n", sparams.k_shift);
20952096
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
20962097
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
20972098
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);

common/common.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,15 @@ enum llama_example {
8585
enum common_sampler_type {
8686
COMMON_SAMPLER_TYPE_NONE = 0,
8787
COMMON_SAMPLER_TYPE_DRY = 1,
88-
COMMON_SAMPLER_TYPE_TOP_K = 2,
89-
COMMON_SAMPLER_TYPE_TOP_P = 3,
90-
COMMON_SAMPLER_TYPE_MIN_P = 4,
91-
COMMON_SAMPLER_TYPE_TFS_Z = 5,
92-
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
93-
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
94-
COMMON_SAMPLER_TYPE_XTC = 8,
95-
COMMON_SAMPLER_TYPE_INFILL = 9,
88+
COMMON_SAMPLER_TYPE_K_SHIFT = 2,
89+
COMMON_SAMPLER_TYPE_TOP_K = 3,
90+
COMMON_SAMPLER_TYPE_TOP_P = 4,
91+
COMMON_SAMPLER_TYPE_MIN_P = 5,
92+
COMMON_SAMPLER_TYPE_TFS_Z = 6,
93+
COMMON_SAMPLER_TYPE_TYPICAL_P = 7,
94+
COMMON_SAMPLER_TYPE_TEMPERATURE = 8,
95+
COMMON_SAMPLER_TYPE_XTC = 9,
96+
COMMON_SAMPLER_TYPE_INFILL = 10,
9697
};
9798

9899
// dimensionality reduction methods, used by cvector-generator
@@ -108,6 +109,7 @@ struct common_sampler_params {
108109
int32_t n_prev = 64; // number of previous tokens to remember
109110
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
110111
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
112+
int32_t k_shift = 0; // 0 = disabled
111113
int32_t top_k = 40; // <= 0 to use vocab size
112114
float top_p = 0.95f; // 1.0 = disabled
113115
float min_p = 0.05f; // 0.0 = disabled
@@ -138,6 +140,7 @@ struct common_sampler_params {
138140

139141
std::vector<enum common_sampler_type> samplers = {
140142
COMMON_SAMPLER_TYPE_DRY,
143+
COMMON_SAMPLER_TYPE_K_SHIFT,
141144
COMMON_SAMPLER_TYPE_TOP_K,
142145
COMMON_SAMPLER_TYPE_TFS_Z,
143146
COMMON_SAMPLER_TYPE_TYPICAL_P,

common/sampling.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,11 @@ std::string common_sampler_params::print() const {
131131
snprintf(result, sizeof(result),
132132
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
133133
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
134-
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
134+
"\tk_shift = %d, top_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
135135
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
136136
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
137137
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
138-
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
138+
k_shift, top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
139139
mirostat, mirostat_eta, mirostat_tau);
140140

141141
return std::string(result);
@@ -187,6 +187,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
187187
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
188188
}
189189
break;
190+
case COMMON_SAMPLER_TYPE_K_SHIFT:
191+
llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift (params.k_shift));
192+
break;
190193
case COMMON_SAMPLER_TYPE_TOP_K:
191194
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
192195
break;
@@ -372,6 +375,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
372375
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
373376
switch (cnstr) {
374377
case COMMON_SAMPLER_TYPE_DRY: return 'd';
378+
case COMMON_SAMPLER_TYPE_K_SHIFT: return 's';
375379
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
376380
case COMMON_SAMPLER_TYPE_TFS_Z: return 'f';
377381
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
@@ -387,6 +391,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
387391
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
388392
switch (cnstr) {
389393
case COMMON_SAMPLER_TYPE_DRY: return "dry";
394+
case COMMON_SAMPLER_TYPE_K_SHIFT: return "k_shift";
390395
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
391396
case COMMON_SAMPLER_TYPE_TFS_Z: return "tfs_z";
392397
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
@@ -403,6 +408,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
403408
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
404409
{ "dry", COMMON_SAMPLER_TYPE_DRY },
405410
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
411+
{ "k_shift", COMMON_SAMPLER_TYPE_K_SHIFT },
406412
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
407413
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
408414
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
@@ -416,6 +422,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
416422
// make it ready for both system names and input names
417423
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
418424
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
425+
{ "k-shift", COMMON_SAMPLER_TYPE_K_SHIFT },
419426
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
420427
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
421428
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -451,6 +458,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
451458
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
452459
std::unordered_map<char, common_sampler_type> sampler_name_map = {
453460
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
461+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_K_SHIFT), COMMON_SAMPLER_TYPE_K_SHIFT },
454462
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
455463
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z), COMMON_SAMPLER_TYPE_TFS_Z },
456464
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },

examples/main/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,14 @@ DRY sampling provides more nuanced control over text generation, particularly fo
211211

212212
Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"`
213213

214+
### K-Shift Sampling
215+
216+
- `--k-shift N`: Shift the first token selection by cutting out N tokens from the top once (default: 0).
217+
218+
K-Shift is a sampling method that guides models away from the most obvious output, eliciting reasoning and analysis. It cuts out the top k tokens once at the beginning of inference, making sure that the dialog will start from a less obvious path without guiding the model too much. The method was mentioned in the paper [Chain-of-Thought Reasoning without Prompting](https://arxiv.org/pdf/2402.10200) as a simple trick for guiding a model towards reasoning. In practice, K-Shift can improve the quality of reasoning, help bypass bias or censorship in certain cases, and may also be used as a diagnostic tool. K-Shift is intended to be used with greedy sampling (`--k-shift 10 --top-k 1`), but can help with creative writing too - albeit not as much as XTC. The default value is 0.
219+
220+
Example usage: `--k-shift 10`
221+
214222
### Top-K Sampling
215223

216224
- `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40).

examples/server/public/index-new.html

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
dry_base: 1.75, // 0.0 = disabled
4545
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
4646
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
47-
top_k: 0, // <= 0 to use vocab size
47+
k_shift: 0, // 0 = disabled
48+
top_k: 0, // <= 0 to use vocab size
4849
top_p: 1.0, // 1.0 = disabled
4950
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
5051
xtc_probability: 0.0, // 0 = disabled;
@@ -835,6 +836,7 @@
835836
<details>
836837
<summary><span class="summary-title">Further Options</span></summary>
837838
<fieldset class="params">
839+
${IntField({ label: "K-Shift", title: "Cuts out first k tokens once at the start of sampling. Intended to use with greedy sampling.", max: 100, min: 0, step: 1, name: "k_shift", value: params.value.k_shift })}
838840
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
839841
${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
840842
${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}

examples/server/public/index.html

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@
308308
dry_base: 1.75, // 0.0 = disabled
309309
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
310310
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
311+
k_shift: 0, // 0 = disabled
311312
top_k: 40, // <= 0 to use vocab size
312313
top_p: 0.95, // 1.0 = disabled
313314
min_p: 0.05, // 0 = disabled
@@ -1008,6 +1009,7 @@
10081009
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
10091010
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
10101011
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
1012+
${IntField({ label: "K-shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })}
10111013
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
10121014
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
10131015
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

examples/server/server.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ struct server_context {
804804
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
805805
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
806806
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
807+
slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift);
807808
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
808809
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
809810
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
@@ -1144,6 +1145,7 @@ struct server_context {
11441145
{"temperature", slot.sparams.temp},
11451146
{"dynatemp_range", slot.sparams.dynatemp_range},
11461147
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
1148+
{"k_shift", slot.sparams.k_shift},
11471149
{"top_k", slot.sparams.top_k},
11481150
{"top_p", slot.sparams.top_p},
11491151
{"min_p", slot.sparams.min_p},

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,9 @@ extern "C" {
11021102
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
11031103
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
11041104

1105+
1106+
/// @details K-Shift sampler: removes the k highest-logit tokens once, on the first sampling step after creation or reset (k <= 0 disables it). Mentioned in https://arxiv.org/pdf/2402.10200
LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k);
1107+
11051108
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
11061109
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
11071110
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.

src/llama-sampling.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,17 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
188188
cur_p->size = k;
189189
}
190190

191+
// Drops the k highest-logit candidates from cur_p.
//
// The candidates are sorted by logit in descending order, then the view
// (data pointer and size) is advanced past the first k entries, so the token
// that was ranked #k becomes the first candidate.
//
// The call is a no-op when k <= 0 or when k >= cur_p->size: shifting by the
// whole array (or more) would move `data` past the end of the candidate
// buffer and make `size` underflow, and a negative k would move `data` in
// front of the buffer, so in those cases the candidates are left untouched.
static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k) {
    if (k <= 0 || (size_t) k >= cur_p->size) {
        return;
    }

    // sort before shifting
    std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
        return a.logit > b.logit;
    });

    // shift to a token #[k]
    cur_p->data += k;
    cur_p->size -= k;
}
201+
191202
static uint32_t get_rng_seed(uint32_t seed) {
192203
if (seed == LLAMA_DEFAULT_SEED) {
193204
// use system clock if std::random_device is not a true RNG
@@ -1177,6 +1188,62 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
11771188
};
11781189
}
11791190

1191+
// k-shift
1192+
1193+
// Per-instance state of the k-shift sampler.
struct llama_sampler_k_shift {
    // number of top candidates to cut; the apply callback does nothing when this is <= 0
    const int32_t k;

    // true once the shift has been applied; the reset callback clears it again
    bool k_set;
};
1197+
1198+
// Name callback of the k-shift sampler; the sampler instance is not consulted.
static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*smpl*/) {
    constexpr const char * sampler_name = "k-shift";
    return sampler_name;
}
1201+
1202+
static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1203+
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
1204+
1205+
if (ctx->k <= 0 || ctx->k_set == true) {
1206+
return;
1207+
}
1208+
1209+
llama_sampler_top_shift_impl(cur_p, ctx->k);
1210+
ctx->k_set = true;
1211+
}
1212+
1213+
static struct llama_sampler * llama_sampler_k_shift_clone(const struct llama_sampler * smpl) {
1214+
auto * ctx = (const llama_sampler_k_shift *) smpl->ctx;
1215+
1216+
return llama_sampler_init_k_shift(ctx->k);
1217+
}
1218+
1219+
static void llama_sampler_k_shift_free(struct llama_sampler * smpl) {
1220+
delete (llama_sampler_k_shift *) smpl->ctx;
1221+
}
1222+
1223+
static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) {
1224+
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
1225+
ctx->k_set = false;
1226+
}
1227+
1228+
// Callback table of the k-shift sampler. There is no accept callback: the
// only state the sampler keeps (k_set) is updated by apply and reset.
static struct llama_sampler_i llama_sampler_k_shift_i = {
    /* .name   = */ llama_sampler_k_shift_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_k_shift_apply,
    /* .reset  = */ llama_sampler_k_shift_reset,
    /* .clone  = */ llama_sampler_k_shift_clone,
    /* .free   = */ llama_sampler_k_shift_free,
};
1236+
1237+
struct llama_sampler * llama_sampler_init_k_shift(int32_t k) {
1238+
return new llama_sampler {
1239+
/* .iface = */ &llama_sampler_k_shift_i,
1240+
/* .ctx = */ new llama_sampler_k_shift {
1241+
/* .k = */ k,
1242+
/* .k_set = */ false,
1243+
},
1244+
};
1245+
}
1246+
11801247
// mirostat
11811248

11821249
struct llama_sampler_mirostat {

0 commit comments

Comments
 (0)