diff --git a/tritonbench/components/do_bench/entropy/entropy_criterion.py b/tritonbench/components/do_bench/entropy/entropy_criterion.py
index d2872b1aa..753cf8c41 100644
--- a/tritonbench/components/do_bench/entropy/entropy_criterion.py
+++ b/tritonbench/components/do_bench/entropy/entropy_criterion.py
@@ -99,12 +99,17 @@ def _update_entropy_sum(self, old_count: int, new_count: int) -> None:
             new_count: New count (0 if removing unique value)
         """
         # Remove old contribution: S -= old_count * log2(old_count)
-        if old_count > 0:
-            self._sum_count_log_count -= old_count * math.log2(old_count)
-
-        # Add new contribution: S += new_count * log2(new_count)
-        if new_count > 0:
-            self._sum_count_log_count += new_count * math.log2(new_count)
+        # Optimization: n*log(n) - o*log(o) = n*log(1 + (n - o)/o) + (n - o)*log(o)
+        if old_count > 0 and new_count > 0:
+            delta = new_count - old_count
+            self._sum_count_log_count += (
+                new_count * math.log2(1 + delta / old_count) + delta * math.log2(old_count)
+            )
+        else:
+            if old_count > 0:
+                self._sum_count_log_count -= old_count * math.log2(old_count)
+            if new_count > 0:
+                self._sum_count_log_count += new_count * math.log2(new_count)
 
     def _compute_entropy(self) -> float:
         """
@@ -121,7 +126,7 @@ def _compute_entropy(self) -> float:
 
         # Entropy formula: H = log2(n) - S/n
        entropy = math.log2(n) - (self._sum_count_log_count / n)
-        return entropy
+        return max(0.0, entropy)
 
     def add_measurement(self, measurement: float) -> None:
         """
@@ -157,24 +162,20 @@ def add_measurement(self, measurement: float) -> None:
 
         # Update running statistics for linear regression
         # If entropy_tracker is full, remove oldest component from running stats
+        # The removed (oldest) element sits at index 0 of the sliding window
         if len(self.entropy_tracker) == self.window_size:
             old_entropy = self.entropy_tracker[0]
-            old_x = 0  # Oldest position in the sliding window
-            old_sum_x = self._sum_x
 
             # Remove old values from running sums
-            self._sum_x -= old_x
             self._sum_y -= old_entropy
-            self._sum_xy -= old_x * old_entropy
-            self._sum_x2 -= old_x * old_x
             self._sum_y2 -= old_entropy * old_entropy
-            self._n -= 1
 
             # Remove element's effect from sum of squares
-            n = self._n
-            self._sum_x2 -= 2 * old_sum_x + n  # Use saved old_sum_x
+            n = self._n - 1
             self._sum_x -= n
+            self._sum_x2 -= 2 * self._sum_x + n
             self._sum_xy -= self._sum_y
+            self._n -= 1
 
         # Add new entropy value to running stats
         x = self._n
@@ -217,7 +218,7 @@ def is_finished(self) -> bool:
 
         numerator = self._sum_xy - n * mean_x * mean_y
         denominator = self._sum_x2 - n * mean_x * mean_x
-        if denominator == 0:
+        if denominator < 1e-9:
             return False
 
         slope = numerator / denominator
@@ -240,8 +241,8 @@ def is_finished(self) -> bool:
             + n * intercept * intercept
         )
 
-        # If ss_tot == 0, entropy values are identical => perfect stability
-        if ss_tot == 0:
+        # If ss_tot < epsilon, entropy values are identical => perfect stability
+        if ss_tot < 1e-9:
             r2 = 1.0
         else:
             r2 = max(0.0, min(1.0, 1 - (ss_res / ss_tot)))
diff --git a/tritonbench/components/do_bench/run.py b/tritonbench/components/do_bench/run.py
index 217b532a5..3f45ddecc 100644
--- a/tritonbench/components/do_bench/run.py
+++ b/tritonbench/components/do_bench/run.py
@@ -502,22 +502,25 @@ def _do_bench_entropy(
     assert return_mode in ["min", "max", "mean", "median", "all"]
 
     # ENTROPY-BASED WARMUP
-    criterion = EntropyCriterion(
+    entropy_criterion = EntropyCriterion(
        max_angle=max_angle,
        min_r2=min_r2,
        window_size=window_size,
        min_warmup_samples=min_warmup_samples,
     )
-    criterion.reset()
-    BATCH_SIZE = 20
+    entropy_criterion.reset()
+
+    rounding_factor = 3
+    BATCH_SIZE = 50
     last_batch = [-1.00] * BATCH_SIZE
     counter = 0
     converged = False
+    precision_increase = False
 
     cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
 
     # Adaptive warmup loop with batched synchronization
-    while not criterion.is_finished():
+    while True:
         remaining = max_samples - counter
         batch_size = min(BATCH_SIZE, remaining) if remaining > 0 else BATCH_SIZE
 
@@ -540,20 +543,41 @@ def _do_bench_entropy(
         torch.cuda.synchronize()
 
         for i in range(batch_size):
-            v = round(batch_start_events[i].elapsed_time(batch_end_events[i]), 3)
-            criterion.add_measurement(v)
+            v = round(batch_start_events[i].elapsed_time(batch_end_events[i]), rounding_factor)
             last_batch[i] = v
+
+            entropy_criterion.add_measurement(v)
+
+            if entropy_criterion.is_finished():
+                converged = True
+                break
+
         counter += batch_size
 
+        if converged:
+            break
+
         if counter >= max_samples:
             break
-    else:
-        converged = True
+
+        if counter >= 200 and not precision_increase:
+            stats = entropy_criterion.get_stats()
+            unique_count = stats.get("unique_measurements", 0)
+
+            # Fewer than 20 unique values indicates quantization; increase rounding precision
+            if unique_count < 20:
+                rounding_factor = 4
+                entropy_criterion.reset()
+                entropy_criterion.entropy_window_size = 1000
+
+                logger.info(f"Quantization detected: only {unique_count} unique measurements.")
+                precision_increase = True
 
     # Log if warmup didn't converge
     if not converged:
         logger.warning(
-            f"Entropy warmup did not converge after {counter} samples "
+            f"Warmup did not converge after {counter} samples "
            f"(max_samples={max_samples})"
        )
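Note on the rewritten _update_entropy_sum above: it relies on the identity n*log2(n) - o*log2(o) = n*log2(1 + (n - o)/o) + (n - o)*log2(o), which holds because n/o = 1 + (n - o)/o. Below is a minimal standalone sanity check of that identity; it is not part of the patch, and the helper names direct_update/rewritten_update are illustrative only.

import math

def direct_update(old_count: int, new_count: int) -> float:
    """Direct form: new*log2(new) - old*log2(old), treating 0*log2(0) as 0."""
    total = 0.0
    if old_count > 0:
        total -= old_count * math.log2(old_count)
    if new_count > 0:
        total += new_count * math.log2(new_count)
    return total

def rewritten_update(old_count: int, new_count: int) -> float:
    """Rewritten form used when both counts are positive."""
    delta = new_count - old_count
    return new_count * math.log2(1 + delta / old_count) + delta * math.log2(old_count)

# Both forms agree for typical count transitions (increment, decrement, no change, jump).
for old, new in [(1, 2), (5, 4), (7, 7), (3, 100), (250, 251)]:
    assert math.isclose(direct_update(old, new), rewritten_update(old, new), abs_tol=1e-9)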