@@ -41,164 +41,51 @@ torch::Tensor index_select_2d(const torch::Tensor& input,
4141
4242RejectionSamplerRateController::RejectionSamplerRateController (
4343 double fixed_acceptance_rate)
44- : window_size_(1000 ),
45- history_buffer_ (1000 , 0 ),
46- history_idx_(0 ),
47- window_sum_(0 ),
48- error_buffer_(20 , 0.0 ),
49- error_idx_(0 ),
50- pid_adj_(0.0 ),
51- cumulative_err_(0.0 ),
52- last_target_(-1.0 ),
53- total_batches_(0 ),
54- accepted_batches_(0 ),
55- gen_(std::random_device{}()),
56- dist_(0.0 , 1.0 ),
57- fixed_acceptance_rate_(fixed_acceptance_rate) {}
44+ : fixed_acceptance_rate_(fixed_acceptance_rate),
45+ last_target_ (fixed_acceptance_rate) {}
5846
5947torch::Tensor RejectionSamplerRateController::filter_with_acceptance_rate (
6048 const torch::Tensor& token_ids) {
61- // Check parameters
62- if (fixed_acceptance_rate_ < 0.0 || fixed_acceptance_rate_ > 1.0 )
49+ // Basic parameter validation
50+ if (fixed_acceptance_rate_ < 0.0 || fixed_acceptance_rate_ > 1.0 ||
51+ token_ids.size (0 ) == 0 ) {
6352 return token_ids.clone ();
64- if (token_ids. size ( 0 ) == 0 ) return token_ids. clone ();
53+ }
6554
66- // Reset state if target acceptance rate changed
55+ // Reset counters if the target rate has changed significantly
6756 if (std::abs (last_target_ - fixed_acceptance_rate_) > 1e-6 ) {
68- reset_state (fixed_acceptance_rate_);
57+ total_batches_ = 0 ;
58+ accepted_batches_ = 0 ;
59+ last_target_ = fixed_acceptance_rate_;
6960 }
7061
71- // Update window statistics
72- total_batches_++;
73- double global_rate =
74- (total_batches_ > 0 )
75- ? static_cast <double >(accepted_batches_) / total_batches_
76- : 0.0 ;
77- window_sum_ -= history_buffer_[history_idx_];
78- history_buffer_[history_idx_] = 0 ;
79- double window_rate = static_cast <double >(window_sum_) / window_size_;
80-
81- // Calculate combined error between observed and target rates
82- double batch_weight =
83- 1.0 - std::exp (-static_cast <double >(total_batches_) / 30.0 );
84- double win_weight =
85- std::min (static_cast <double >(window_size_) / 100.0 , 0.9 ) * batch_weight;
86- double combined_err =
87- (1.0 - win_weight) * (global_rate - fixed_acceptance_rate_) +
88- win_weight * (window_rate - fixed_acceptance_rate_);
89-
90- // PID controller for automatic correction
91- if (total_batches_ > 50 ) {
92- double curr_err = global_rate - fixed_acceptance_rate_;
93- error_buffer_[error_idx_] = curr_err;
94- double prev_err = error_buffer_[(error_idx_ + 19 ) % 20 ];
95- error_idx_ = (error_idx_ + 1 ) % 20 ;
96-
97- double i_term =
98- std::accumulate (error_buffer_.begin (), error_buffer_.end (), 0.0 );
99- double d_term = curr_err - prev_err;
100-
101- // PID: kp=0.05, ki=0.001, kd=0.01
102- pid_adj_ = 0.05 * curr_err + 0.001 * i_term + 0.01 * d_term;
103-
104- // Clamp adjustment
105- double limit =
106- 0.02 +
107- 0.03 * (1.0 - std::exp (-static_cast <double >(total_batches_) / 500.0 ));
108- pid_adj_ = std::clamp (pid_adj_, -limit, limit);
109- }
62+ // Calculate Drift: Difference between expected hits and actual hits
63+ double expected_hits = total_batches_ * fixed_acceptance_rate_;
64+ double drift = expected_hits - accepted_batches_;
11065
111- // Get corrected acceptance rate
112- double adj_rate =
113- calculate_adjusted_rate (fixed_acceptance_rate_, combined_err);
66+ // Calculate adjusted probability
67+ // If drift > 0 (we accepted too few), increase probability.
68+ // The factor 0.1 acts as a gentle gain to correct long-term error.
69+ double adj_rate = fixed_acceptance_rate_ + (drift * 0.1 );
70+ adj_rate = std::clamp (adj_rate, 0.0 , 1.0 );
11471
115- // Decide accept or reject this batch
72+ // Perform rejection sampling
11673 bool accept = dist_ (gen_) < adj_rate;
11774
118- torch::Tensor out_tensor = token_ids.clone ();
75+ // Update statistics
76+ total_batches_++;
11977 if (accept) {
120- // Accept: keep as-is
12178 accepted_batches_++;
122- history_buffer_[history_idx_] = 1 ;
123- window_sum_ += 1 ;
124- } else {
125- // Reject: mask out tokens after the first (dimension 1)
126- out_tensor.slice (1 , 1 ).fill_ (PLACEHOLDER_TOKEN_ID);
12779 }
12880
129- // Move window index forward
130- history_idx_ = (history_idx_ + 1 ) % window_size_;
131-
132- // Update cumulative EMA error
133- double actual_rate = static_cast <double >(accepted_batches_) / total_batches_;
134- double alpha =
135- 0.05 * std::exp (-static_cast <double >(total_batches_) / 200.0 ) + 0.01 ;
136- cumulative_err_ = alpha * (actual_rate - fixed_acceptance_rate_) +
137- (1.0 - alpha) * cumulative_err_;
138-
139- return out_tensor;
140- }
141-
142- void RejectionSamplerRateController::reset_state (double new_rate) {
143- pid_adj_ = 0.0 ;
144- std::fill (error_buffer_.begin (), error_buffer_.end (), 0.0 );
145- last_target_ = new_rate;
146- }
147-
148- double RejectionSamplerRateController::calculate_adjusted_rate (double target,
149- double error) {
150- // 1. Nonlinear correction based on error magnitude
151- double err_abs = std::abs (error);
152- double factor = 1.0 ;
153- if (err_abs > 0.0005 ) {
154- double strength = 2.0 + (1.0 - std::exp (-err_abs * 50.0 )) * 1.5 ;
155- double sign = (error > 0 ) ? 1.0 : -1.0 ;
156- factor = 1.0 + (strength * err_abs * sign);
157- }
158-
159- // 2. Apply PID adjustment
160- double rate_base = target * (factor == 0.0 ? 1.0 : 1.0 / factor);
161- double rate = std::clamp (rate_base - pid_adj_, 0.0 , 1.0 );
162-
163- // 3. Edge cases and periodic force accept for low rate
164- if (target == 0.0 ) return 0.0 ;
165- if (target == 1.0 ) return 1.0 ;
166- if (target > 0 && target < 0.05 ) {
167- int period = static_cast <int >(1.0 / target);
168- if (total_batches_ % period == 0 ) return 1.0 ;
169- }
170-
171- // 4. Gap correction to enforce long-term target
172- long target_accepted = std::llround (total_batches_ * target);
173- long gap = target_accepted - accepted_batches_;
174- long gap_threshold = std::max (1L , static_cast <long >(total_batches_ * 0.005 ));
175-
176- if (std::abs (gap) >= gap_threshold) {
177- double gap_relative =
178- static_cast <double >(std::abs (gap)) / std::max (1L , total_batches_);
179- double importance = 1.0 - std::exp (-gap_relative * 50.0 );
180- double strength =
181- 0.2 + 0.8 * std::exp (-static_cast <double >(total_batches_) / 1000.0 );
182- double boost = importance * strength;
183-
184- if (gap > 0 ) {
185- // Need to accept more
186- rate = std::min (1.0 , rate + (1.0 - rate) * boost);
187- } else {
188- // Need to reject more
189- rate = std::max (0.0 , rate * (1.0 - boost));
190- }
191- }
192-
193- // 5. Random noise to prevent local optima
194- if (rate > 0.01 && rate < 0.99 && total_batches_ > 100 ) {
195- double noise_amp =
196- 0.01 * std::exp (-static_cast <double >(total_batches_) / 500.0 );
197- double noise = (dist_ (gen_) * 2.0 - 1.0 ) * noise_amp;
198- rate = std::clamp (rate + noise, 0.0 , 1.0 );
81+ // Generate output
82+ torch::Tensor out_tensor = token_ids.clone ();
83+ if (!accept) {
84+ // Reject: Mask out tokens after the first one (dimension 1)
85+ out_tensor.slice (1 , 1 ).fill_ (kPlaceholderTokenId );
19986 }
20087
201- return rate ;
88+ return out_tensor ;
20289}
20390
20491RejectionSampler::RejectionSampler (
0 commit comments