Skip to content

Commit bed96e4

Browse files
jhou-jpg and facebook-github-bot
authored and committed
Entropy warmup enhancements
Summary: This diff optimizes the entropy warmup logic and should improve its performance with smaller/faster kernels. Since entropy tracks the frequency of appearance of each unique timing element, the rounding precision applied to individual latency elements can impact the characteristics of entropy convergence. This diff introduces logic to dynamically increase rounding precision to maintain a balance between entropy sensitivity and trend detection. Differential Revision: D87379814
1 parent 41b3d6f commit bed96e4

File tree

1 file changed

+33
-9
lines changed
  • tritonbench/components/do_bench

1 file changed

+33
-9
lines changed

tritonbench/components/do_bench/run.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -502,22 +502,25 @@ def _do_bench_entropy(
502502
assert return_mode in ["min", "max", "mean", "median", "all"]
503503

504504
# ENTROPY-BASED WARMUP
505-
criterion = EntropyCriterion(
505+
entropy_criterion = EntropyCriterion(
506506
max_angle=max_angle,
507507
min_r2=min_r2,
508508
window_size=window_size,
509509
min_warmup_samples=min_warmup_samples,
510510
)
511-
criterion.reset()
512-
BATCH_SIZE = 20
511+
entropy_criterion.reset()
512+
513+
rounding_factor = 3
514+
BATCH_SIZE = 50
513515
last_batch = [-1.00] * BATCH_SIZE
514516
counter = 0
515517
converged = False
518+
precision_increase = False
516519

517520
cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
518521

519522
# Adaptive warmup loop with batched synchronization
520-
while not criterion.is_finished():
523+
while True:
521524
remaining = max_samples - counter
522525
batch_size = min(BATCH_SIZE, remaining) if remaining > 0 else BATCH_SIZE
523526

@@ -540,20 +543,41 @@ def _do_bench_entropy(
540543
torch.cuda.synchronize()
541544

542545
for i in range(batch_size):
543-
v = round(batch_start_events[i].elapsed_time(batch_end_events[i]), 3)
544-
criterion.add_measurement(v)
546+
v = round(batch_start_events[i].elapsed_time(batch_end_events[i]), rounding_factor)
545547
last_batch[i] = v
548+
549+
entropy_criterion.add_measurement(v)
550+
551+
if entropy_criterion.is_finished():
552+
converged = True
553+
break
554+
546555
counter += batch_size
547556

557+
if converged:
558+
break
559+
548560
if counter >= max_samples:
549561
break
550-
else:
551-
converged = True
562+
563+
if counter >= 200 and not precision_increase:
564+
stats = entropy_criterion.get_stats()
565+
unique_count = stats.get('unique_measurements', 0)
566+
567+
# If we have < 20 unique values, this indicates quantization, increase rounding precision
568+
if unique_count < 20:
569+
rounding_factor = 4
570+
entropy_criterion.reset()
571+
entropy_criterion.entropy_window_size = 1000
572+
573+
logger.info(f"Quantization detected: only {unique_count} unique measurements. ")
574+
precision_increase = True
575+
552576

553577
# Log if warmup didn't converge
554578
if not converged:
555579
logger.warning(
556-
f"Entropy warmup did not converge after {counter} samples "
580+
f"Warmup did not converge after {counter} samples "
557581
f"(max_samples={max_samples})"
558582
)
559583

0 commit comments

Comments
 (0)