Commit 26aabcb

jhou-jpg authored and facebook-github-bot committed
Entropy warmup enhancements (#651)
Summary: This diff fixes an implementation error that was affecting the slope and r^2 computations. It also optimizes the entropy warmup logic to improve its behavior with smaller/faster kernels. Further numerical-stability improvements were brought in from the cross-PR on nvbench, NVIDIA/nvbench#286. Since entropy tracks the frequency of appearance of each unique timing value, the rounding precision applied to individual latency measurements can affect how the entropy converges. This diff introduces logic that dynamically increases the rounding precision to maintain a balance between entropy sensitivity and trend detection.

Reviewed By: xuzhao9

Differential Revision: D87379814
1 parent aa5ecdc commit 26aabcb
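
For context on the diff below: the criterion's entropy is the Shannon entropy of the frequency table of rounded timings. With counts c_i summing to n and p_i = c_i / n, H = -sum(p_i * log2(p_i)) = log2(n) - (1/n) * sum(c_i * log2(c_i)), so the code only needs to maintain the single running sum S = sum(c_i * log2(c_i)) (the `_sum_count_log_count` in the first file). A toy check of that identity, separate from the tritonbench classes:

import math
from collections import Counter

samples = [1.001, 1.001, 1.002, 1.003, 1.003, 1.003]
counts = Counter(samples)
n = len(samples)

# Direct Shannon entropy over the frequency table
h_direct = -sum((c / n) * math.log2(c / n) for c in counts.values())

# Running-sum form used by the criterion: H = log2(n) - S/n
s = sum(c * math.log2(c) for c in counts.values())
h_running = math.log2(n) - s / n

assert abs(h_direct - h_running) < 1e-12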

File tree

2 files changed: +52 −27


tritonbench/components/do_bench/entropy/entropy_criterion.py

Lines changed: 19 additions & 18 deletions
@@ -99,12 +99,17 @@ def _update_entropy_sum(self, old_count: int, new_count: int) -> None:
             new_count: New count (0 if removing unique value)
         """
         # Remove old contribution: S -= old_count * log2(old_count)
-        if old_count > 0:
-            self._sum_count_log_count -= old_count * math.log2(old_count)
-
-        # Add new contribution: S += new_count * log2(new_count)
-        if new_count > 0:
-            self._sum_count_log_count += new_count * math.log2(new_count)
+        # Optimization: n*log(n) - o*log(o) = n*log(1 + (n - o)/o) + (n - o)*log(o)
+        if old_count > 0 and new_count > 0:
+            delta = new_count - old_count
+            self._sum_count_log_count += (
+                new_count * math.log2(1 + delta / old_count) + delta * math.log2(old_count)
+            )
+        else:
+            if old_count > 0:
+                self._sum_count_log_count -= old_count * math.log2(old_count)
+            if new_count > 0:
+                self._sum_count_log_count += new_count * math.log2(new_count)
 
     def _compute_entropy(self) -> float:
         """
@@ -121,7 +126,7 @@ def _compute_entropy(self) -> float:
 
         # Entropy formula: H = log2(n) - S/n
         entropy = math.log2(n) - (self._sum_count_log_count / n)
-        return entropy
+        return max(0.0, entropy)
 
     def add_measurement(self, measurement: float) -> None:
         """
@@ -157,24 +162,20 @@ def add_measurement(self, measurement: float) -> None:
 
         # Update running statistics for linear regression
         # If entropy_tracker is full, remove oldest component from running stats
+        # removal index in the sliding window = 0
         if len(self.entropy_tracker) == self.window_size:
             old_entropy = self.entropy_tracker[0]
-            old_x = 0  # Oldest position in the sliding window
-            old_sum_x = self._sum_x
 
             # Remove old values from running sums
-            self._sum_x -= old_x
             self._sum_y -= old_entropy
-            self._sum_xy -= old_x * old_entropy
-            self._sum_x2 -= old_x * old_x
             self._sum_y2 -= old_entropy * old_entropy
-            self._n -= 1
 
             # Remove element's effect from sum of squares
-            n = self._n
-            self._sum_x2 -= 2 * old_sum_x + n  # Use saved old_sum_x
+            n = self._n - 1
             self._sum_x -= n
+            self._sum_x2 -= 2 * self._sum_x + n
             self._sum_xy -= self._sum_y
+            self._n -= 1
 
         # Add new entropy value to running stats
         x = self._n
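
The bookkeeping above works because evicting the oldest point (at x = 0) shifts every remaining x down by one; with n points remaining, sum(x - 1) = sum(x) - n, sum((x - 1)^2) = sum(x^2) - 2*sum(x) + n when applied after sum_x is updated (exactly the order the diff uses), and sum((x - 1)*y) = sum(x*y) - sum(y). The removed lines appear to have subtracted 2*old_sum_x + n with the pre-shift sum of x, over-subtracting from sum_x2 on every eviction — which matches the slope/r^2 error the summary describes. A standalone check of the corrected update, outside the class:

from collections import deque

window = deque([0.9, 1.1, 1.0, 1.2], maxlen=4)
n = len(window) - 1                      # points that remain after eviction
old_y = window[0]

sum_x = sum(range(len(window)))
sum_y = sum(window)
sum_x2 = sum(x * x for x in range(len(window)))
sum_xy = sum(x * y for x, y in enumerate(window))
sum_y2 = sum(y * y for y in window)

# Incremental eviction, mirroring the diff's order of operations
sum_y -= old_y
sum_y2 -= old_y * old_y
sum_x -= n
sum_x2 -= 2 * sum_x + n                  # uses the already-updated sum_x
sum_xy -= sum_y

# Recompute from scratch over the shrunk window to compare
window.popleft()
assert sum_x == sum(range(len(window)))
assert sum_x2 == sum(x * x for x in range(len(window)))
assert abs(sum_xy - sum(x * y for x, y in enumerate(window))) < 1e-12
assert abs(sum_y2 - sum(y * y for y in window)) < 1e-12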
@@ -217,7 +218,7 @@ def is_finished(self) -> bool:
         numerator = self._sum_xy - n * mean_x * mean_y
         denominator = self._sum_x2 - n * mean_x * mean_x
 
-        if denominator == 0:
+        if (denominator < 1e-9):
             return False
 
         slope = numerator / denominator
@@ -240,8 +241,8 @@ def is_finished(self) -> bool:
             + n * intercept * intercept
         )
 
-        # If ss_tot == 0, entropy values are identical => perfect stability
-        if ss_tot == 0:
+        # If ss_tot < epsilon, entropy values are identical => perfect stability
+        if (ss_tot < 1e-9):
             r2 = 1.0
         else:
             r2 = max(0.0, min(1.0, 1 - (ss_res / ss_tot)))
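
These two hunks replace exact zero tests with epsilon guards on the regression the criterion runs over the entropy window: fit entropy against sample index, then require a near-flat slope and a high r^2 before declaring warmup done. A minimal sketch of that check (illustrative names, not the tritonbench API; the degrees-from-slope conversion is an assumption, not lifted from this diff):

import math

def regression_check(xs, ys, max_angle=1.0, min_r2=0.9):
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    numerator = sum(x * y for x, y in zip(xs, ys)) - n * mean_x * mean_y
    denominator = sum(x * x for x in xs) - n * mean_x * mean_x
    if denominator < 1e-9:          # epsilon guard, as in the diff
        return False
    slope = numerator / denominator
    intercept = mean_y - slope * mean_x
    ss_res = sum((y - (slope * x + intercept)) ** 2 for x, y in zip(xs, ys))
    ss_tot = sum((y - mean_y) ** 2 for y in ys)
    # ss_tot ~ 0 means the entropy trace is flat => treat as a perfect fit
    r2 = 1.0 if ss_tot < 1e-9 else max(0.0, min(1.0, 1 - ss_res / ss_tot))
    angle = abs(math.degrees(math.atan(slope)))  # assumed angle convention
    return angle <= max_angle and r2 >= min_r2

# Flat, noiseless entropy trace: slope ~ 0 and ss_tot ~ 0 force r2 = 1.0
print(regression_check(list(range(10)), [2.5] * 10))  # True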

tritonbench/components/do_bench/run.py

Lines changed: 33 additions & 9 deletions
@@ -502,22 +502,25 @@ def _do_bench_entropy(
     assert return_mode in ["min", "max", "mean", "median", "all"]
 
     # ENTROPY-BASED WARMUP
-    criterion = EntropyCriterion(
+    entropy_criterion = EntropyCriterion(
         max_angle=max_angle,
         min_r2=min_r2,
         window_size=window_size,
         min_warmup_samples=min_warmup_samples,
     )
-    criterion.reset()
-    BATCH_SIZE = 20
+    entropy_criterion.reset()
+
+    rounding_factor = 3
+    BATCH_SIZE = 50
     last_batch = [-1.00] * BATCH_SIZE
     counter = 0
     converged = False
+    precision_increase = False
 
     cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
 
     # Adaptive warmup loop with batched synchronization
-    while not criterion.is_finished():
+    while True:
         remaining = max_samples - counter
         batch_size = min(BATCH_SIZE, remaining) if remaining > 0 else BATCH_SIZE

@@ -540,20 +543,41 @@ def _do_bench_entropy(
         torch.cuda.synchronize()
 
         for i in range(batch_size):
-            v = round(batch_start_events[i].elapsed_time(batch_end_events[i]), 3)
-            criterion.add_measurement(v)
+            v = round(batch_start_events[i].elapsed_time(batch_end_events[i]), rounding_factor)
             last_batch[i] = v
+
+            entropy_criterion.add_measurement(v)
+
+            if entropy_criterion.is_finished():
+                converged = True
+                break
+
         counter += batch_size
 
+        if converged:
+            break
+
         if counter >= max_samples:
             break
-        else:
-            converged = True
+
+        if counter >= 200 and not precision_increase:
+            stats = entropy_criterion.get_stats()
+            unique_count = stats.get('unique_measurements', 0)
+
+            # If we have < 20 unique values, this indicates quantization; increase rounding precision
+            if unique_count < 20:
+                rounding_factor = 4
+                entropy_criterion.reset()
+                entropy_criterion.entropy_window_size = 1000
+
+                logger.info(f"Quantization detected: only {unique_count} unique measurements.")
+                precision_increase = True
+
 
     # Log if warmup didn't converge
     if not converged:
         logger.warning(
-            f"Entropy warmup did not converge after {counter} samples "
+            f"Warmup did not converge after {counter} samples "
            f"(max_samples={max_samples})"
         )
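
The new block is the dynamic-precision logic from the summary: after 200 samples, if fewer than 20 unique rounded values have been seen, the timings are being quantized away by 3-decimal rounding, so the criterion is reset with rounding_factor = 4. A toy illustration with synthetic timings (the numbers are made up, not benchmark output):

import random

random.seed(0)
# A ~5 microsecond kernel, expressed in milliseconds with some jitter
timings_ms = [0.005 + random.gauss(0, 0.0004) for _ in range(200)]

unique_3 = len({round(t, 3) for t in timings_ms})
unique_4 = len({round(t, 4) for t in timings_ms})
print(unique_3, unique_4)  # e.g. a handful of buckets vs. well over 20

At 3 decimals nearly all samples collapse into a few buckets, so the entropy signal saturates and the trend fit has almost nothing to work with; one extra decimal restores enough resolution for convergence detection without making every sample unique.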
