Skip to content

Commit 02b034f

Browse files
committed
refactor: simplify and correct RejectionSamplerRateController; return std::optional from get_fix_speculative_acceptance_rate.
1 parent 8be2797 commit 02b034f

File tree

5 files changed

+79
-199
lines changed

5 files changed

+79
-199
lines changed

xllm/core/framework/sampling/rejection_sampler.cpp

Lines changed: 27 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -41,164 +41,51 @@ torch::Tensor index_select_2d(const torch::Tensor& input,
4141

4242
RejectionSamplerRateController::RejectionSamplerRateController(
4343
double fixed_acceptance_rate)
44-
: window_size_(1000),
45-
history_buffer_(1000, 0),
46-
history_idx_(0),
47-
window_sum_(0),
48-
error_buffer_(20, 0.0),
49-
error_idx_(0),
50-
pid_adj_(0.0),
51-
cumulative_err_(0.0),
52-
last_target_(-1.0),
53-
total_batches_(0),
54-
accepted_batches_(0),
55-
gen_(std::random_device{}()),
56-
dist_(0.0, 1.0),
57-
fixed_acceptance_rate_(fixed_acceptance_rate) {}
44+
: fixed_acceptance_rate_(fixed_acceptance_rate),
45+
last_target_(fixed_acceptance_rate) {}
5846

5947
torch::Tensor RejectionSamplerRateController::filter_with_acceptance_rate(
6048
const torch::Tensor& token_ids) {
61-
// Check parameters
62-
if (fixed_acceptance_rate_ < 0.0 || fixed_acceptance_rate_ > 1.0)
49+
// Basic parameter validation
50+
if (fixed_acceptance_rate_ < 0.0 || fixed_acceptance_rate_ > 1.0 ||
51+
token_ids.size(0) == 0) {
6352
return token_ids.clone();
64-
if (token_ids.size(0) == 0) return token_ids.clone();
53+
}
6554

66-
// Reset state if target acceptance rate changed
55+
// Reset counters if the target rate differs from the last seen target by more than 1e-6
6756
if (std::abs(last_target_ - fixed_acceptance_rate_) > 1e-6) {
68-
reset_state(fixed_acceptance_rate_);
57+
total_batches_ = 0;
58+
accepted_batches_ = 0;
59+
last_target_ = fixed_acceptance_rate_;
6960
}
7061

71-
// Update window statistics
72-
total_batches_++;
73-
double global_rate =
74-
(total_batches_ > 0)
75-
? static_cast<double>(accepted_batches_) / total_batches_
76-
: 0.0;
77-
window_sum_ -= history_buffer_[history_idx_];
78-
history_buffer_[history_idx_] = 0;
79-
double window_rate = static_cast<double>(window_sum_) / window_size_;
80-
81-
// Calculate combined error between observed and target rates
82-
double batch_weight =
83-
1.0 - std::exp(-static_cast<double>(total_batches_) / 30.0);
84-
double win_weight =
85-
std::min(static_cast<double>(window_size_) / 100.0, 0.9) * batch_weight;
86-
double combined_err =
87-
(1.0 - win_weight) * (global_rate - fixed_acceptance_rate_) +
88-
win_weight * (window_rate - fixed_acceptance_rate_);
89-
90-
// PID controller for automatic correction
91-
if (total_batches_ > 50) {
92-
double curr_err = global_rate - fixed_acceptance_rate_;
93-
error_buffer_[error_idx_] = curr_err;
94-
double prev_err = error_buffer_[(error_idx_ + 19) % 20];
95-
error_idx_ = (error_idx_ + 1) % 20;
96-
97-
double i_term =
98-
std::accumulate(error_buffer_.begin(), error_buffer_.end(), 0.0);
99-
double d_term = curr_err - prev_err;
100-
101-
// PID: kp=0.05, ki=0.001, kd=0.01
102-
pid_adj_ = 0.05 * curr_err + 0.001 * i_term + 0.01 * d_term;
103-
104-
// Clamp adjustment
105-
double limit =
106-
0.02 +
107-
0.03 * (1.0 - std::exp(-static_cast<double>(total_batches_) / 500.0));
108-
pid_adj_ = std::clamp(pid_adj_, -limit, limit);
109-
}
62+
// Compute drift: expected number of accepted batches minus the number actually accepted so far
63+
double expected_hits = total_batches_ * fixed_acceptance_rate_;
64+
double drift = expected_hits - accepted_batches_;
11065

111-
// Get corrected acceptance rate
112-
double adj_rate =
113-
calculate_adjusted_rate(fixed_acceptance_rate_, combined_err);
66+
// Calculate adjusted probability
67+
// If drift > 0 (we accepted too few), increase probability.
68+
// The factor 0.1 acts as a gentle gain to correct long-term error.
69+
double adj_rate = fixed_acceptance_rate_ + (drift * 0.1);
70+
adj_rate = std::clamp(adj_rate, 0.0, 1.0);
11471

115-
// Decide accept or reject this batch
72+
// Perform rejection sampling
11673
bool accept = dist_(gen_) < adj_rate;
11774

118-
torch::Tensor out_tensor = token_ids.clone();
75+
// Update statistics
76+
total_batches_++;
11977
if (accept) {
120-
// Accept: keep as-is
12178
accepted_batches_++;
122-
history_buffer_[history_idx_] = 1;
123-
window_sum_ += 1;
124-
} else {
125-
// Reject: mask out tokens after the first (dimension 1)
126-
out_tensor.slice(1, 1).fill_(PLACEHOLDER_TOKEN_ID);
12779
}
12880

129-
// Move window index forward
130-
history_idx_ = (history_idx_ + 1) % window_size_;
131-
132-
// Update cumulative EMA error
133-
double actual_rate = static_cast<double>(accepted_batches_) / total_batches_;
134-
double alpha =
135-
0.05 * std::exp(-static_cast<double>(total_batches_) / 200.0) + 0.01;
136-
cumulative_err_ = alpha * (actual_rate - fixed_acceptance_rate_) +
137-
(1.0 - alpha) * cumulative_err_;
138-
139-
return out_tensor;
140-
}
141-
142-
void RejectionSamplerRateController::reset_state(double new_rate) {
143-
pid_adj_ = 0.0;
144-
std::fill(error_buffer_.begin(), error_buffer_.end(), 0.0);
145-
last_target_ = new_rate;
146-
}
147-
148-
double RejectionSamplerRateController::calculate_adjusted_rate(double target,
149-
double error) {
150-
// 1. Nonlinear correction based on error magnitude
151-
double err_abs = std::abs(error);
152-
double factor = 1.0;
153-
if (err_abs > 0.0005) {
154-
double strength = 2.0 + (1.0 - std::exp(-err_abs * 50.0)) * 1.5;
155-
double sign = (error > 0) ? 1.0 : -1.0;
156-
factor = 1.0 + (strength * err_abs * sign);
157-
}
158-
159-
// 2. Apply PID adjustment
160-
double rate_base = target * (factor == 0.0 ? 1.0 : 1.0 / factor);
161-
double rate = std::clamp(rate_base - pid_adj_, 0.0, 1.0);
162-
163-
// 3. Edge cases and periodic force accept for low rate
164-
if (target == 0.0) return 0.0;
165-
if (target == 1.0) return 1.0;
166-
if (target > 0 && target < 0.05) {
167-
int period = static_cast<int>(1.0 / target);
168-
if (total_batches_ % period == 0) return 1.0;
169-
}
170-
171-
// 4. Gap correction to enforce long-term target
172-
long target_accepted = std::llround(total_batches_ * target);
173-
long gap = target_accepted - accepted_batches_;
174-
long gap_threshold = std::max(1L, static_cast<long>(total_batches_ * 0.005));
175-
176-
if (std::abs(gap) >= gap_threshold) {
177-
double gap_relative =
178-
static_cast<double>(std::abs(gap)) / std::max(1L, total_batches_);
179-
double importance = 1.0 - std::exp(-gap_relative * 50.0);
180-
double strength =
181-
0.2 + 0.8 * std::exp(-static_cast<double>(total_batches_) / 1000.0);
182-
double boost = importance * strength;
183-
184-
if (gap > 0) {
185-
// Need to accept more
186-
rate = std::min(1.0, rate + (1.0 - rate) * boost);
187-
} else {
188-
// Need to reject more
189-
rate = std::max(0.0, rate * (1.0 - boost));
190-
}
191-
}
192-
193-
// 5. Random noise to prevent local optima
194-
if (rate > 0.01 && rate < 0.99 && total_batches_ > 100) {
195-
double noise_amp =
196-
0.01 * std::exp(-static_cast<double>(total_batches_) / 500.0);
197-
double noise = (dist_(gen_) * 2.0 - 1.0) * noise_amp;
198-
rate = std::clamp(rate + noise, 0.0, 1.0);
81+
// Generate output
82+
torch::Tensor out_tensor = token_ids.clone();
83+
if (!accept) {
84+
// Reject: Mask out tokens after the first one (dimension 1)
85+
out_tensor.slice(1, 1).fill_(kPlaceholderTokenId);
19986
}
20087

201-
return rate;
88+
return out_tensor;
20289
}
20390

20491
RejectionSampler::RejectionSampler(

xllm/core/framework/sampling/rejection_sampler.h

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,50 +24,36 @@ limitations under the License.
2424

2525
namespace xllm {
2626

27-
// Provide a default placeholder token ID
28-
#ifndef PLACEHOLDER_TOKEN_ID
29-
#define PLACEHOLDER_TOKEN_ID -1
30-
#endif
31-
3227
class RejectionSamplerRateController {
3328
public:
3429
explicit RejectionSamplerRateController(double fixed_acceptance_rate);
3530

36-
// Core filtering function, decides whether to accept a batch based on target
37-
// acceptance rate
31+
// Core filtering function.
32+
// Maintains the target acceptance rate using a cumulative drift correction.
3833
torch::Tensor filter_with_acceptance_rate(const torch::Tensor& token_ids);
3934

4035
private:
41-
// Reset internal state (call when the target acceptance rate changes
42-
// significantly)
43-
void reset_state(double new_rate);
44-
45-
// Compute the final acceptance rate after PID and error correction
46-
double calculate_adjusted_rate(double target, double error);
47-
size_t window_size_;
48-
49-
// History state (using circular buffer logic)
50-
std::vector<int> history_buffer_;
51-
size_t history_idx_;
52-
long window_sum_;
53-
54-
// PID controller state
55-
std::vector<double> error_buffer_;
56-
size_t error_idx_;
57-
double pid_adj_;
58-
59-
// Global statistics and error state
60-
double cumulative_err_;
61-
double last_target_;
62-
long total_batches_;
63-
long accepted_batches_;
36+
// Placeholder value for rejected tokens
37+
static constexpr int32_t kPlaceholderTokenId = -1;
38+
39+
// Tolerance for detecting when the user changes the target rate
40+
static constexpr double kTargetRateChangeTolerance = 1e-6;
6441

65-
// Random number generator
66-
std::mt19937 gen_;
67-
std::uniform_real_distribution<double> dist_;
42+
// Gain factor for drift correction (higher = faster correction, lower =
43+
// smoother)
44+
static constexpr double kDriftCorrectionGain = 0.1;
6845

69-
// acceptance rate
46+
// Target configuration
7047
double fixed_acceptance_rate_;
48+
double last_target_;
49+
50+
// Global statistics for long-term rate tracking
51+
int64_t total_batches_ = 0;
52+
int64_t accepted_batches_ = 0;
53+
54+
// Random number generator components
55+
std::mt19937 gen_{std::random_device{}()};
56+
std::uniform_real_distribution<double> dist_{0.0, 1.0};
7157
};
7258

7359
class RejectionSampler final {

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,11 @@ SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
171171
// Performance debugging: optionally fix the speculative acceptance rate.
172172
// NOTE: This is for performance debugging only; it will
173173
// influence the model accuracy and should not be used in production.
174-
double fixed_acceptance_rate =
175-
xllm::util::get_fix_speculative_acceptance_rate();
176-
if (fixed_acceptance_rate >= 0.0) {
177-
rate_controller_ =
178-
std::make_shared<RejectionSamplerRateController>(fixed_acceptance_rate);
174+
std::optional<double> fixed_acceptance_rate =
175+
util::get_fix_speculative_acceptance_rate();
176+
if (fixed_acceptance_rate.has_value()) {
177+
rate_controller_ = std::make_shared<RejectionSamplerRateController>(
178+
*fixed_acceptance_rate);
179179
}
180180
}
181181

xllm/core/util/env_var.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ std::string get_string_env(const std::string& name) {
6262
return std::string(val);
6363
}
6464

65-
double get_double_env(const std::string& key, double defaultValue) {
65+
double get_double_env(const std::string& key, double defaultValue = -1) {
6666
const char* val = std::getenv(key.c_str());
6767
if (val == nullptr) {
6868
return defaultValue;
@@ -89,21 +89,27 @@ int64_t get_process_group_test_timeout_seconds() {
8989
return get_int_env(kTimeoutEnvVar, kDefaultTimeoutSeconds);
9090
}
9191

92-
double get_fix_speculative_acceptance_rate() {
93-
// Default is -1.0 to disable fixing the speculative acceptance rate.
94-
// Set XLLM_FIX_SPECULATIVE_ACCEPTANCE_RATE to a valid float value to fix the
95-
// speculative acceptance rate.
96-
constexpr double kDefaultValue = -1.0;
92+
std::optional<double> get_fix_speculative_acceptance_rate() {
93+
// XLLM_FIX_SPECULATIVE_ACCEPTANCE_RATE:
94+
// Defines a fixed acceptance rate for speculative decoding to simulate
95+
// specific performance scenarios.
96+
// Valid values are in the range [0.0, 1.0].
97+
// If not set, or set to an invalid value, the fixed rate logic is disabled
98+
// (returns std::nullopt).
9799
constexpr const char* kAcceptanceRateEnvVar =
98100
"XLLM_FIX_SPECULATIVE_ACCEPTANCE_RATE";
99-
double value = get_double_env(kAcceptanceRateEnvVar, kDefaultValue);
100-
// Ensure that if set, it is between 0 and 1
101-
if (value != kDefaultValue && (value < 0.0 || value > 1.0)) {
102-
LOG(WARNING) << "Invalid value for " << kAcceptanceRateEnvVar << ": "
103-
<< value << ". Must be in [0, 1]. Using default (-1)";
104-
return kDefaultValue;
101+
double value = get_double_env(kAcceptanceRateEnvVar, -1.0);
102+
if (value == -1.0) {
103+
return std::nullopt;
105104
}
106-
return value;
105+
// Validate the range. It must be a probability between 0 and 1.
106+
if (value < 0.0 || value > 1.0) {
107+
LOG(WARNING) << "Warning: Invalid value for " << kAcceptanceRateEnvVar
108+
<< ": " << value << ". Must be in [0, 1]. Ignoring setting."
109+
<< std::endl;
110+
return std::nullopt;
111+
}
112+
return std::make_optional(value);
107113
}
108114

109115
} // namespace util

xllm/core/util/env_var.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#pragma once
1717

18+
#include <optional>
1819
#include <string>
1920

2021
namespace xllm {
@@ -39,13 +40,13 @@ std::string get_string_env(const std::string& name);
3940
// insufficient.
4041
int64_t get_process_group_test_timeout_seconds();
4142

42-
// Check if the speculative acceptance rate should be fixed.
43-
// NOTE: This variable is for performance debugging only, it will
44-
// influence the model accuracy and should not be used in production.
45-
// Returns the fixed acceptance rate if the XLLM_FIX_SPECULATIVE_ACCEPTANCE_RATE
46-
// environment variable is set to a valid float value, -1.0 otherwise. This is
47-
// used to control whether to fix the speculative acceptance rate.
48-
double get_fix_speculative_acceptance_rate();
43+
// Returns an optional fixed acceptance rate for speculative decoding (for
44+
// performance debugging only). If the XLLM_FIX_SPECULATIVE_ACCEPTANCE_RATE
45+
// environment variable is set to a value in [0.0, 1.0], returns
46+
// std::optional<double> with that value; otherwise returns std::nullopt.
47+
// WARNING: Using this will influence model accuracy and should not be used in
48+
// production.
49+
std::optional<double> get_fix_speculative_acceptance_rate();
4950

5051
} // namespace util
5152
} // namespace xllm

0 commit comments

Comments
 (0)