Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions iris/bench/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,25 +427,46 @@ def _run_benchmarks_worker(
return_mode="all",
)

mean_ms = statistics.mean(times)
# Per-rank: median across iterations (robust to outliers)
local_median_ms = statistics.median(times)

# Cross-rank: gather every rank's median to compute
# max (true collective latency), min, and skew.
gather_device = "cuda" if backend == "nccl" else "cpu"
local_t = torch.tensor([local_median_ms], device=gather_device)
gathered = [torch.zeros(1, device=gather_device) for _ in range(world_size)]
Comment on lines +436 to +437

This comment was marked as resolved.

dist.all_gather(gathered, local_t)
rank_medians = [t.item() for t in gathered]
Comment on lines +433 to +439

This comment was marked as resolved.


max_ms = max(rank_medians)
min_ms = min(rank_medians)
Comment on lines +433 to +442

This comment was marked as resolved.

skew_pct = ((max_ms - min_ms) / max_ms * 100) if max_ms > 0 else 0.0

# Headline time = max across ranks (slowest rank
# determines when the collective is actually done).
gpu_time_ms = max_ms

bw = None
if state._bytes is not None and mean_ms > 0:
bw = (state._bytes / 1e9) / (mean_ms * 1e-3)
if state._bytes is not None and gpu_time_ms > 0:
bw = (state._bytes / 1e9) / (gpu_time_ms * 1e-3)

tflops = None
if state._flops is not None and mean_ms > 0:
tflops = (state._flops / 1e12) / (mean_ms * 1e-3)
if state._flops is not None and gpu_time_ms > 0:
tflops = (state._flops / 1e12) / (gpu_time_ms * 1e-3)

counters = dict(state._counters)
counters["min_ms"] = min_ms
counters["skew_pct"] = skew_pct

all_results.append(
Result(
benchmark_name=bdef.name,
params=params,
gpu_time_ms=mean_ms,
gpu_time_ms=gpu_time_ms,
all_times_ms=times,
bandwidth_gbps=bw,
tflops=tflops,
counters=dict(state._counters),
counters=counters,
world_size=world_size,
)
)
Expand Down
Loading