@@ -2315,133 +2315,62 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
23152315
23162316// power-law
23172317//
2318+ // this sampler implements a power law probability transformation with adaptive
2319+ // target tracking. it reshapes token probability distributions to favor tokens near a
2320+ // configurable target probability, rather than always selecting from the highest probability
2321+ // candidates. it is ideal for creative, unpredictable text generation.
2322+ //
23182323// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
23192324// rather than just transforming logits. therefore it must always be the last sampler in the
23202325// sampler chain.
23212326//
2322- // it is recommended to only perform minimal truncation before this sampler.
2327+ // minimal truncation before this sampler is recommended.
23232328//
2324- // ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl, documentation )
2329+ // ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
23252330// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)
23262331
// state of the power-law sampler
//
// NOTE: the field order matters - the sampler is created with positional aggregate initialization
struct llama_sampler_power_law {
    // base target: the average probability that selected tokens should have (0.0 to 1.0)
    //   - larger values pull the selection towards more probable tokens (more deterministic)
    //   - smaller values pull the selection towards less probable tokens (more creative)
    //   - a negative value turns Power Law sampling off (the distribution is sampled as-is)
    const float target;

    // how strongly older selections still count (0.0 to 0.99)
    //   - smaller values: history fades quickly, the target reacts to the most recent tokens
    //   - larger values: history fades slowly, the target is more stable over time
    // the effective history length is roughly 1/(1-decay) tokens:
    //   decay=0.5 → ~2 tokens, decay=0.9 → ~10 tokens, decay=0.95 → ~20 tokens
    // the value is capped at 0.99 when it is used, to prevent unbounded accumulation
    const float decay;

    // seed the generator was constructed with, and the generator itself
    const uint32_t seed;
    std::mt19937   rng;

    // recency-weighted sum of the original probabilities of the selected tokens
    float weighted_sum;

    // sum of the recency weights themselves, converges to 1/(1-decay)
    float total_weight;
};
23352356
// name callback of the power-law sampler
// the returned string is a constant and does not depend on the sampler instance
static const char * llama_sampler_power_law_name(const struct llama_sampler * smpl) {
    (void) smpl; // unused

    static const char * const sampler_name = "power-law";
    return sampler_name;
}
23392360
2340- // Computes the target probability for the current sampling step.
2341- //
2342- // The target determines which token probabilities the power law distribution
2343- // will favor. This function implements a dynamic feedback mechanism to maintain
2344- // an average selection probability close to the base target over time.
2345- //
2346- // When the window is empty:
2347- // - Returns the base target value (ctx->target)
2348- //
2349- // When the window has entries:
2350- // - Calculates what the next target should be to keep the weighted average
2351- // of selected token probabilities equal to ctx->target
2352- // - Uses exponential decay weighting: newer values have more influence
2353- //
2354- // Exponential Decay Weighting:
2355- // After inserting the new value, the weights will be:
2356- // new_value: weight = 1 (age 0, newest)
2357- // rat(0): weight = decay (age 1)
2358- // rat(1): weight = decay^2 (age 2)
2359- // ...
2360- // rat(sz-2): weight = decay^(sz-1)
2361- // rat(sz-1): evicted (oldest)
2362- //
2363- // The "effective window size" is approximately 1/(1-decay):
2364- // decay=0.9 → effective window ≈ 10 tokens
2365- // decay=0.95 → effective window ≈ 20 tokens
2366- // decay=1.0 → no decay, equivalent to simple average (original behavior)
2367- //
2368- // Formula derivation:
2369- // We want the weighted average after insertion to equal target:
2370- //
2371- // (new_value * 1 + Σ rat(i) * decay^(i+1)) / total_weight = target
2372- //
2373- // Where total_weight = 1 + decay + decay^2 + ... + decay^(sz-1)
2374- // = (1 - decay^sz) / (1 - decay) [geometric series]
2375- //
2376- // Solving for new_value:
2377- // new_value = target * total_weight - decay * Σ rat(i) * decay^i
2378- //
2379- // The factor of 'decay' on the sum accounts for all existing values
2380- // shifting one position older when the new value is inserted.
2381- //
2382- // The exponential decay helps prevent "fishtailing" - a phenomenon where
2383- // forced high-probability selections (when the model is very confident)
2384- // cause the algorithm to overcorrect with many low-probability selections,
2385- // then swing back the other way. By decaying old values, the influence of
2386- // forced selections fades faster, reducing oscillation amplitude and
2387- // recovery time.
2388- //
2389- // Finally, the computed target is clamped to [min_target, max_target] to
2390- // prevent extreme values that could destabilize sampling.
2391- //
2392- static float llama_sampler_power_law_compute_target (
2393- const llama_sampler_power_law * ctx,
2394- float min_target,
2395- float max_target,
2396- float tail_decay) {
2397-
2398- float computed_target = ctx->target ;
2399- size_t sz = ctx->window .size ();
2400-
2401- if (sz > 0 ) {
2402- // Check if window is at capacity (oldest element will be evicted on next push)
2403- // Use the window_size parameter from context, not a capacity() method
2404- const bool window_full = (sz == (size_t )ctx->window_size );
2405-
2406- // Compute weighted sum with exponential decay
2407- // rat(0) = newest in buffer, gets weight 1
2408- // rat(i) gets weight decay^i
2409- //
2410- // When window is full: exclude oldest element (it will be evicted)
2411- // When window is not full: include all elements (nothing evicted)
2412- float weighted_sum = 0 .0f ;
2413- float weight = 1 .0f ;
2414- size_t elements_to_sum = window_full ? (sz - 1 ) : sz;
2415-
2416- for (size_t i = 0 ; i < elements_to_sum; ++i) {
2417- weighted_sum += ctx->window .rat (i) * weight;
2418- weight *= tail_decay;
2419- }
2420-
2421- // Shift weights to account for new value taking position 0
2422- // All existing values age by 1, so multiply their weights by decay
2423- float shifted_weighted_sum = weighted_sum * tail_decay;
2424-
2425- // Compute total weight after new value is inserted
2426- // When full: sz elements remain (oldest evicted, new added)
2427- // When not full: sz + 1 elements (new added, nothing evicted)
2428- size_t final_element_count = window_full ? sz : (sz + 1 );
2429-
2430- float total_weight;
2431- if (std::abs (tail_decay - 1 .0f ) < FLT_EPSILON) {
2432- total_weight = (float ) final_element_count;
2433- } else {
2434- total_weight = (1 .0f - std::pow (tail_decay, (float ) final_element_count)) / (1 .0f - tail_decay);
2435- }
2436-
2437- // Solve for the new value that achieves target weighted average
2438- float next_value = (ctx->target * total_weight) - shifted_weighted_sum;
2439-
2440- // Clamp to allowed range
2441- computed_target = std::max (min_target, std::min (next_value, max_target));
2361+ // compute the adaptive target probability for the current sampling step
2362+ static float llama_sampler_power_law_compute_target (const llama_sampler_power_law * ctx, float decay) {
2363+ if (ctx->total_weight == 0 .0f ) {
2364+ // if there is no history, just use base target
2365+ return ctx->target ;
24422366 }
24432367
2444- return computed_target;
2368+ // maintain a running weighted sum with exponential decay
2369+ float new_total_weight = 1 .0f + decay * ctx->total_weight ;
2370+ float next_value = ctx->target * new_total_weight - decay * ctx->weighted_sum ;
2371+
2372+ // clamp to [0.0, 1.0]
2373+ return std::max (0 .0f , std::min (next_value, 1 .0f ));
24452374}
24462375
24472376static void llama_sampler_power_law_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -2455,30 +2384,25 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
24552384 return ;
24562385 }
24572386
2387+ // clamp decay to avoid degenerate case at 1.0 (unbounded accumulation)
2388+ const float decay = std::min (ctx->decay , 0 .99f );
2389+
24582390 // fixed power law transform parameters
24592391 const float distribution_width = 0 .3f ;
24602392 const float peak_logit_value = 5 .0f ;
24612393 const float tail_heaviness = 2 .0f ;
24622394
2463- // target computation parameters
2464- const float min_target = 0 .0f ;
2465- const float max_target = 1 .0f ;
2466- const float tail_decay = 0 .50f ; // exponential decay factor for history weighting
2467- // lower = faster response, higher = more stability
2468- // effective window ≈ 1/(1-decay) ≈ 20 tokens
2469-
2470- // compute probabilities to get the "original" values
2395+ // get the original probabilities
24712396 llama_sampler_softmax_impl (cur_p, false );
24722397
2473- // store original probabilities (used for future target adaptation )
2398+ // store the original probabilities (needed for history update after selection )
24742399 std::vector<float > original_probs;
24752400 original_probs.reserve (cur_p->size );
24762401 for (size_t i = 0 ; i < cur_p->size ; ++i) {
24772402 original_probs.push_back (cur_p->data [i].p );
24782403 }
24792404
2480- // calculate adaptive target
2481- float computed_target = llama_sampler_power_law_compute_target (ctx, min_target, max_target, tail_decay);
2405+ float computed_target = llama_sampler_power_law_compute_target (ctx, decay);
24822406
24832407 //
24842408 // power law transform
@@ -2492,40 +2416,30 @@ static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_tok
24922416
24932417 llama_sampler_softmax_impl (cur_p, false );
24942418
2495- // sample from the transformed distribution
2419+ // sample from transformed distribution
24962420 const int idx = llama_sample_dist (cur_p, ctx->rng );
24972421 cur_p->selected = idx;
24982422
2499- // uncomment this to log the target values and history window contents for every token
2500- //
2501- // fprintf(stderr, "power_law: window_size=%zu/%d values=[",
2502- // ctx->window.size(), ctx->window_size);
2503- // for (size_t i = 0; i < ctx->window.size(); ++i) {
2504- // fprintf(stderr, "%.1f", ctx->window.rat(i));
2505- // if (i < ctx->window.size() - 1) fprintf(stderr, ",");
2506- // }
2507- // fprintf(stderr, "] computed_target=%.4f selected_token=%d orig_prob=%.4f\n",
2508- // computed_target, cur_p->data[idx].id, original_probs[idx]);
2509- // fflush(stderr);
2510-
2511- // add the ORIGINAL probability to the rolling window
2512- float original_p = original_probs[idx];
2513-
2514- ctx->window .push_back (original_p);
2423+ // update running history with the original probability of the selected token
2424+ float original_p = original_probs[idx];
2425+ ctx->weighted_sum = original_p + decay * ctx->weighted_sum ;
2426+ ctx->total_weight = 1 .0f + decay * ctx->total_weight ;
25152427}
25162428
25172429static void llama_sampler_power_law_reset (struct llama_sampler * smpl) {
2518- auto * ctx = (llama_sampler_power_law *) smpl->ctx ;
2519- ctx->window = ring_buffer<float >(ctx->window_size );
2430+ auto * ctx = (llama_sampler_power_law *) smpl->ctx ;
2431+ ctx->weighted_sum = 0 .0f ;
2432+ ctx->total_weight = 0 .0f ;
25202433}
25212434
25222435static struct llama_sampler * llama_sampler_power_law_clone (const struct llama_sampler * smpl) {
25232436 const auto * ctx = (const llama_sampler_power_law *) smpl->ctx ;
2524- auto * result = llama_sampler_init_power_law (ctx->target , ctx->window_size , ctx->seed );
2437+ auto * result = llama_sampler_init_power_law (ctx->target , ctx->decay , ctx->seed );
25252438 auto * result_ctx = (llama_sampler_power_law *) result->ctx ;
25262439
2527- result_ctx->rng = ctx->rng ;
2528- result_ctx->window = ctx->window ;
2440+ result_ctx->rng = ctx->rng ;
2441+ result_ctx->weighted_sum = ctx->weighted_sum ;
2442+ result_ctx->total_weight = ctx->total_weight ;
25292443
25302444 return result;
25312445}
@@ -2545,18 +2459,19 @@ static struct llama_sampler_i llama_sampler_power_law_i = {
25452459
25462460struct llama_sampler * llama_sampler_init_power_law (
25472461 float target,
2548- int32_t window_size ,
2462+ float decay ,
25492463 uint32_t seed
25502464) {
25512465 auto seed_cur = get_rng_seed (seed);
25522466 return llama_sampler_init (
25532467 /* .iface = */ &llama_sampler_power_law_i,
25542468 /* .ctx = */ new llama_sampler_power_law {
25552469 /* .target = */ target,
2556- /* .window_size = */ window_size ,
2470+ /* .decay = */ decay ,
25572471 /* .seed = */ seed_cur,
25582472 /* .rng = */ std::mt19937 (seed_cur),
2559- /* .window = */ ring_buffer<float >(window_size),
2473+ /* .weighted_sum = */ 0 .0f ,
2474+ /* .total_weight = */ 0 .0f ,
25602475 }
25612476 );
25622477}
0 commit comments