Commit a96ddd7

re-write + change parameters + simplify
1 parent 67a7336 commit a96ddd7

File tree: 4 files changed, +130 -211 lines

common/common.h

Lines changed: 29 additions & 29 deletions
@@ -164,35 +164,35 @@ enum common_params_sampling_config : uint64_t {
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

-    int32_t n_prev = 64; // number of previous tokens to remember
-    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t top_k = 40; // <= 0 to use vocab size
-    float top_p = 0.95f; // 1.0 = disabled
-    float min_p = 0.05f; // 0.0 = disabled
-    float xtc_probability = 0.00f; // 0.0 = disabled
-    float xtc_threshold = 0.10f; // > 0.5 disables XTC
-    float typ_p = 1.00f; // typical_p, 1.0 = disabled
-    float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float dynatemp_range = 0.00f; // 0.0 = disabled
-    float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float penalty_repeat = 1.00f; // 1.0 = disabled
-    float penalty_freq = 0.00f; // 0.0 = disabled
-    float penalty_present = 0.00f; // 0.0 = disabled
-    float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
-    float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
-    int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
-    int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
-    float power_law_target = -1.0f; // target probability for Power Law sampling (valid range 0.0 to 1.0; <0 = disabled)
-    int32_t power_law_window_size = 10; // rolling window size for target adaptation in Power Law sampling (≤0 = fixed target)
-    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float top_n_sigma = -1.00f; // -1.0 = disabled
-    float mirostat_tau = 5.00f; // target entropy
-    float mirostat_eta = 0.10f; // learning rate
-    bool ignore_eos = false;
-    bool no_perf = false; // disable performance metrics
-    bool timing_per_token = false;
+    int32_t n_prev = 64; // number of previous tokens to remember
+    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t top_k = 40; // <= 0 to use vocab size
+    float top_p = 0.95f; // 1.0 = disabled
+    float min_p = 0.05f; // 0.0 = disabled
+    float xtc_probability = 0.00f; // 0.0 = disabled
+    float xtc_threshold = 0.10f; // > 0.5 disables XTC
+    float typ_p = 1.00f; // typical_p, 1.0 = disabled
+    float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float dynatemp_range = 0.00f; // 0.0 = disabled
+    float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float penalty_repeat = 1.00f; // 1.0 = disabled
+    float penalty_freq = 0.00f; // 0.0 = disabled
+    float penalty_present = 0.00f; // 0.0 = disabled
+    float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+    float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+    float power_law_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
+    float power_law_decay = 0.9f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
+    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float top_n_sigma = -1.00f; // -1.0 = disabled
+    float mirostat_tau = 5.00f; // target entropy
+    float mirostat_eta = 0.10f; // learning rate
+    bool ignore_eos = false;
+    bool no_perf = false; // disable performance metrics
+    bool timing_per_token = false;

     uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
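
The new power_law_target / power_law_decay pair replaces the old target / window_size pair. As a hedged sketch of how a caller might opt in (only the two field names and defaults come from this diff; the surrounding common_params plumbing is assumed from the rest of the codebase):

    // illustrative only: assumes a common_params instance with the usual
    // `sampling` member of type common_params_sampling
    common_params params;
    params.sampling.power_law_target = 0.5f; // aim selections near 50% probability
    params.sampling.power_law_decay  = 0.9f; // ~10-token effective history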

include/llama.h

Lines changed: 13 additions & 9 deletions
@@ -1289,24 +1289,28 @@ extern "C" {
             const char ** seq_breakers,
             size_t num_breakers);

-    /// @details power-law sampler - reshapes probability distribution to target specific probability ranges
+    /// power-law
+    ///
+    /// this sampler implements a power law probability transformation with adaptive
+    /// target tracking. it reshapes token probability distributions to favor tokens near a
+    /// configurable target probability, rather than always selecting from the highest probability
+    /// candidates. it is ideal for creative, unpredictable text generation.
     ///
     /// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
     /// rather than just transforming logits. therefore it must always be the last sampler in the
     /// sampler chain.
     ///
-    /// it is recommended to only perform minimal truncation before this sampler.
+    /// minimal truncation before this sampler is recommended.
     ///
-    /// @param target target probability (valid range 0.0 to 1.0; <0 = disabled)
-    /// @param window_size rolling window size for target adaptation (≤0 = fixed target)
-    /// @param seed RNG seed
+    /// @param target select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
+    /// @param decay decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
    ///
-    /// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation)
+    /// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
     /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)
     LLAMA_API struct llama_sampler * llama_sampler_init_power_law(
-            float target,
-            int32_t window_size,
-            uint32_t seed);
+            float target,
+            float decay,
+            uint32_t seed);

     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
             int32_t n_vocab,
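
Because power-law picks the token itself, it has to terminate the chain, with only light truncation before it. A minimal sketch, assuming the existing llama.h sampler-chain API (parameter values are illustrative, not recommendations from this commit):

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    // minimal truncation, per the recommendation above
    llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.05f, 1));
    // power-law must be last: it selects the token ID
    llama_sampler_chain_add(chain, llama_sampler_init_power_law(
        /* target */ 0.5f,
        /* decay  */ 0.9f,
        /* seed   */ LLAMA_DEFAULT_SEED));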

src/llama-sampling.cpp

Lines changed: 61 additions & 146 deletions
@@ -2315,133 +2315,62 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa

 // power-law
 //
+// this sampler implements a power law probability transformation with adaptive
+// target tracking. it reshapes token probability distributions to favor tokens near a
+// configurable target probability, rather than always selecting from the highest probability
+// candidates. it is ideal for creative, unpredictable text generation.
+//
 // this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
 // rather than just transforming logits. therefore it must always be the last sampler in the
 // sampler chain.
 //
-// it is recommended to only perform minimal truncation before this sampler.
+// minimal truncation before this sampler is recommended.
 //
-// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation)
+// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
 // ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)

 struct llama_sampler_power_law {
-    const float target;
-    const int32_t window_size;

-    const uint32_t seed;
-    std::mt19937 rng;
-    ring_buffer<float> window;
+    // the desired average probability for selected tokens (0.0 to 1.0)
+    // higher values favor more probable tokens (more deterministic)
+    // lower values favor less probable tokens (more creative)
+    // negative values disable Power Law sampling (sample from distribution as-is)
+    const float target;
+
+    // controls how quickly history influence fades (0.0 to 0.99)
+    // lower values = faster adaptation, more reactive to recent tokens
+    // higher values = slower adaptation, more stable over time
+    // effective history length ≈ 1/(1-decay) tokens
+    // examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20
+    // internally clamped to <= 0.99 to prevent unbounded accumulation
+    const float decay;
+
+    const uint32_t seed;
+    std::mt19937 rng;
+
+    // historical token probabilities weighted by recency
+    float weighted_sum;
+    // sum of weights, converges to 1/(1-decay)
+    float total_weight;
 };

 static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) {
     return "power-law";
 }

-// Computes the target probability for the current sampling step.
-//
-// The target determines which token probabilities the power law distribution
-// will favor. This function implements a dynamic feedback mechanism to maintain
-// an average selection probability close to the base target over time.
-//
-// When the window is empty:
-//   - Returns the base target value (ctx->target)
-//
-// When the window has entries:
-//   - Calculates what the next target should be to keep the weighted average
-//     of selected token probabilities equal to ctx->target
-//   - Uses exponential decay weighting: newer values have more influence
-//
-// Exponential Decay Weighting:
-//   After inserting the new value, the weights will be:
-//     new_value:  weight = 1        (age 0, newest)
-//     rat(0):     weight = decay    (age 1)
-//     rat(1):     weight = decay^2  (age 2)
-//     ...
-//     rat(sz-2):  weight = decay^(sz-1)
-//     rat(sz-1):  evicted (oldest)
-//
-//   The "effective window size" is approximately 1/(1-decay):
-//     decay=0.9  → effective window ≈ 10 tokens
-//     decay=0.95 → effective window ≈ 20 tokens
-//     decay=1.0  → no decay, equivalent to simple average (original behavior)
-//
-// Formula derivation:
-//   We want the weighted average after insertion to equal target:
-//
-//     (new_value * 1 + Σ rat(i) * decay^(i+1)) / total_weight = target
-//
-//   Where total_weight = 1 + decay + decay^2 + ... + decay^(sz-1)
-//                      = (1 - decay^sz) / (1 - decay)   [geometric series]
-//
-//   Solving for new_value:
-//     new_value = target * total_weight - decay * Σ rat(i) * decay^i
-//
-//   The factor of 'decay' on the sum accounts for all existing values
-//   shifting one position older when the new value is inserted.
-//
-// The exponential decay helps prevent "fishtailing" - a phenomenon where
-// forced high-probability selections (when the model is very confident)
-// cause the algorithm to overcorrect with many low-probability selections,
-// then swing back the other way. By decaying old values, the influence of
-// forced selections fades faster, reducing oscillation amplitude and
-// recovery time.
-//
-// Finally, the computed target is clamped to [min_target, max_target] to
-// prevent extreme values that could destabilize sampling.
-//
-static float llama_sampler_power_law_compute_target(
-        const llama_sampler_power_law * ctx,
-        float min_target,
-        float max_target,
-        float tail_decay) {
-
-    float computed_target = ctx->target;
-    size_t sz = ctx->window.size();
-
-    if (sz > 0) {
-        // Check if window is at capacity (oldest element will be evicted on next push)
-        // Use the window_size parameter from context, not a capacity() method
-        const bool window_full = (sz == (size_t)ctx->window_size);
-
-        // Compute weighted sum with exponential decay
-        // rat(0) = newest in buffer, gets weight 1
-        // rat(i) gets weight decay^i
-        //
-        // When window is full: exclude oldest element (it will be evicted)
-        // When window is not full: include all elements (nothing evicted)
-        float weighted_sum = 0.0f;
-        float weight = 1.0f;
-        size_t elements_to_sum = window_full ? (sz - 1) : sz;
-
-        for (size_t i = 0; i < elements_to_sum; ++i) {
-            weighted_sum += ctx->window.rat(i) * weight;
-            weight *= tail_decay;
-        }
-
-        // Shift weights to account for new value taking position 0
-        // All existing values age by 1, so multiply their weights by decay
-        float shifted_weighted_sum = weighted_sum * tail_decay;
-
-        // Compute total weight after new value is inserted
-        // When full: sz elements remain (oldest evicted, new added)
-        // When not full: sz + 1 elements (new added, nothing evicted)
-        size_t final_element_count = window_full ? sz : (sz + 1);
-
-        float total_weight;
-        if (std::abs(tail_decay - 1.0f) < FLT_EPSILON) {
-            total_weight = (float) final_element_count;
-        } else {
-            total_weight = (1.0f - std::pow(tail_decay, (float) final_element_count)) / (1.0f - tail_decay);
-        }
-
-        // Solve for the new value that achieves target weighted average
-        float next_value = (ctx->target * total_weight) - shifted_weighted_sum;
-
-        // Clamp to allowed range
-        computed_target = std::max(min_target, std::min(next_value, max_target));
+// compute the adaptive target probability for the current sampling step
+static float llama_sampler_power_law_compute_target(const llama_sampler_power_law * ctx, float decay) {
+    if (ctx->total_weight == 0.0f) {
+        // if there is no history, just use base target
+        return ctx->target;
     }

-    return computed_target;
+    // maintain a running weighted sum with exponential decay
+    float new_total_weight = 1.0f + decay * ctx->total_weight;
+    float next_value = ctx->target * new_total_weight - decay * ctx->weighted_sum;
+
+    // clamp to [0.0, 1.0]
+    return std::max(0.0f, std::min(next_value, 1.0f));
 }

 static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
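
The closed form in the new compute_target follows from a single constraint: the recency-weighted average of selected-token probabilities, once the upcoming value v is folded in, should equal the base target T. Writing d = decay, w = weighted_sum, and W = total_weight, the history update performed after each selection gives

    \frac{v + d \cdot w}{1 + d \cdot W} = T
    \quad\Longrightarrow\quad
    v = T (1 + d \cdot W) - d \cdot w

which is exactly next_value = target * new_total_weight - decay * weighted_sum, clamped to [0.0, 1.0]. Two running scalars thus replace the old geometric-series bookkeeping over an explicit ring buffer.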
@@ -2455,30 +2384,25 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
         return;
     }

+    // clamp decay to avoid degenerate case at 1.0 (unbounded accumulation)
+    const float decay = std::min(ctx->decay, 0.99f);
+
     // fixed power law transform parameters
     const float distribution_width = 0.3f;
     const float peak_logit_value = 5.0f;
     const float tail_heaviness = 2.0f;

-    // target computation parameters
-    const float min_target = 0.0f;
-    const float max_target = 1.0f;
-    const float tail_decay = 0.50f; // exponential decay factor for history weighting
-                                    // lower = faster response, higher = more stability
-                                    // effective window ≈ 1/(1-decay) ≈ 20 tokens
-
-    // compute probabilities to get the "original" values
+    // get the original probabilities
     llama_sampler_softmax_impl(cur_p, false);

-    // store original probabilities (used for future target adaptation)
+    // store the original probabilities (needed for history update after selection)
     std::vector<float> original_probs;
     original_probs.reserve(cur_p->size);
     for (size_t i = 0; i < cur_p->size; ++i) {
         original_probs.push_back(cur_p->data[i].p);
     }

-    // calculate adaptive target
-    float computed_target = llama_sampler_power_law_compute_target(ctx, min_target, max_target, tail_decay);
+    float computed_target = llama_sampler_power_law_compute_target(ctx, decay);

     //
     // power law transform
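
A note on the 0.99 clamp above: total_weight evolves as W ← 1 + d·W, a geometric series with limit

    \lim_{t \to \infty} W_t = \frac{1}{1 - d}

so d = 0.9 yields an effective history of about 10 tokens, and the clamped maximum d = 0.99 about 100. At d = 1.0 the weight would grow without bound, the degenerate case the clamp rules out.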
@@ -2492,40 +2416,30 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok

     llama_sampler_softmax_impl(cur_p, false);

-    // sample from the transformed distribution
+    // sample from transformed distribution
     const int idx = llama_sample_dist(cur_p, ctx->rng);
     cur_p->selected = idx;

-    // uncomment this to log the target values and history window contents for every token
-    //
-    // fprintf(stderr, "power_law: window_size=%zu/%d values=[",
-    //         ctx->window.size(), ctx->window_size);
-    // for (size_t i = 0; i < ctx->window.size(); ++i) {
-    //     fprintf(stderr, "%.1f", ctx->window.rat(i));
-    //     if (i < ctx->window.size() - 1) fprintf(stderr, ",");
-    // }
-    // fprintf(stderr, "] computed_target=%.4f selected_token=%d orig_prob=%.4f\n",
-    //         computed_target, cur_p->data[idx].id, original_probs[idx]);
-    // fflush(stderr);
-
-    // add the ORIGINAL probability to the rolling window
-    float original_p = original_probs[idx];
-
-    ctx->window.push_back(original_p);
+    // update running history with the original probability of the selected token
+    float original_p = original_probs[idx];
+    ctx->weighted_sum = original_p + decay * ctx->weighted_sum;
+    ctx->total_weight = 1.0f + decay * ctx->total_weight;
 }

 static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
-    auto * ctx = (llama_sampler_power_law *) smpl->ctx;
-    ctx->window = ring_buffer<float>(ctx->window_size);
+    auto * ctx = (llama_sampler_power_law *) smpl->ctx;
+    ctx->weighted_sum = 0.0f;
+    ctx->total_weight = 0.0f;
 }

 static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_power_law *) smpl->ctx;
-    auto * result = llama_sampler_init_power_law(ctx->target, ctx->window_size, ctx->seed);
+    auto * result = llama_sampler_init_power_law(ctx->target, ctx->decay, ctx->seed);
     auto * result_ctx = (llama_sampler_power_law *) result->ctx;

-    result_ctx->rng = ctx->rng;
-    result_ctx->window = ctx->window;
+    result_ctx->rng = ctx->rng;
+    result_ctx->weighted_sum = ctx->weighted_sum;
+    result_ctx->total_weight = ctx->total_weight;

     return result;
 }
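
To see the feedback loop end to end, here is a standalone toy (not part of the commit) that re-implements only the compute_target recurrence and the history update above, driving them with synthetic selection probabilities:

    // toy re-implementation for illustration; mirrors the two-scalar state above
    #include <algorithm>
    #include <cstdio>
    #include <random>

    int main() {
        const float target = 0.5f;
        const float decay  = 0.9f;   // ~10-token effective history

        float weighted_sum = 0.0f;   // recency-weighted sum of selected probabilities
        float total_weight = 0.0f;   // converges to 1/(1-decay) = 10

        std::mt19937 rng(42);
        std::uniform_real_distribution<float> noise(-0.2f, 0.2f);

        for (int t = 0; t < 20; ++t) {
            // adaptive target, as in llama_sampler_power_law_compute_target
            float computed = target;
            if (total_weight != 0.0f) {
                float new_total = 1.0f + decay * total_weight;
                computed = std::max(0.0f, std::min(target * new_total - decay * weighted_sum, 1.0f));
            }

            // pretend the sampler picks a token whose original probability
            // lands near the computed target, plus noise
            float selected_p = std::max(0.0f, std::min(computed + noise(rng), 1.0f));

            // history update, as in llama_sampler_power_law_apply
            weighted_sum = selected_p + decay * weighted_sum;
            total_weight = 1.0f + decay * total_weight;

            printf("t=%2d computed_target=%.3f selected_p=%.3f weighted_avg=%.3f\n",
                   t, computed, selected_p, weighted_sum / total_weight);
        }
        return 0;
    }

The weighted average in the last column should hover around 0.5 after a few tokens: an overshoot on one step is compensated by a lower computed target on the next.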
@@ -2545,18 +2459,19 @@ static struct llama_sampler_i llama_sampler_power_law_i = {

 struct llama_sampler * llama_sampler_init_power_law(
         float target,
-        int32_t window_size,
+        float decay,
         uint32_t seed
 ) {
     auto seed_cur = get_rng_seed(seed);
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_power_law_i,
         /* .ctx = */ new llama_sampler_power_law {
             /* .target = */ target,
-            /* .window_size = */ window_size,
+            /* .decay = */ decay,
             /* .seed = */ seed_cur,
             /* .rng = */ std::mt19937(seed_cur),
-            /* .window = */ ring_buffer<float>(window_size),
+            /* .weighted_sum = */ 0.0f,
+            /* .total_weight = */ 0.0f,
         }
     );
 }
