CUDA library hijacking introduces ~82% performance overhead on training workloads #164
Description
Problem
The HAMi-core CUDA library hijacking layer introduces significant performance overhead on kernel-launch-intensive workloads. A benchmark that issues ~147,000 CUDA kernel launches per step (balanced_assign with K=8192, 3 stages) shows an ~82% increase in step time when run under HAMi compared to running without it.
The overhead is not in the resource limiting logic itself, but in per-call bookkeeping (logging, status checks, shared memory reads, mutex acquisition) that executes on every intercepted CUDA API call regardless of whether limiting is active.
Reproducing the Issue
Save the script below as test_hami_slowdown.py and run:
# Without HAMi (baseline):
python test_hami_slowdown.py
# With HAMi:
LD_PRELOAD=/path/to/libvgpu.so python test_hami_slowdown.py
# Multi-GPU:
torchrun --nproc_per_node=N test_hami_slowdown.py

Compare the ms/step values between runs. The balanced_assign pattern launches ~147K kernels per step and is highly sensitive to per-kernel overhead.
test_hami_slowdown.py
"""
HAMi-core Slowdown Test — Single and Multi-GPU
Single GPU: python test_hami_slowdown.py
Multi GPU: torchrun --nproc_per_node=N test_hami_slowdown.py
Patterns:
1. balanced_assign vs argmin — per-GPU kernel launch overhead
2. dist.all_reduce sustained — DDP all_reduce overhead (multi-GPU only)
3. balanced_assign + all_reduce — combined simulated training step
Usage:
# Without HAMi:
torchrun --nproc_per_node=2 test_hami_slowdown.py
# With HAMi:
LD_PRELOAD=/path/to/libvgpu.so torchrun --nproc_per_node=2 test_hami_slowdown.py
"""
import os
import time
import threading
import subprocess
import statistics
import torch
import torch.nn.functional as F
import torch.distributed as dist
# ── Distributed setup ─────────────────────────────────────────────────────────
IS_DIST = "LOCAL_RANK" in os.environ
if IS_DIST:
    dist.init_process_group(backend="nccl")
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    WORLD_SIZE = dist.get_world_size()
    RANK = dist.get_rank()
else:
    LOCAL_RANK = 0
    WORLD_SIZE = 1
    RANK = 0
torch.cuda.set_device(LOCAL_RANK)
DEVICE = f"cuda:{LOCAL_RANK}"
# ── Config ────────────────────────────────────────────────────────────────────
BATCH_SIZE = 8192
K = 8192
D = 50
NUM_STAGES = 3
MEASURE_SECS = 30
# ─────────────────────────────────────────────────────────────────────────────
class PCIePoller:
    def __init__(self):
        self.rx_kb = []
        self.tx_kb = []
        self._proc = None
        self._thread = None
        self._active = False
        self._rx_col = 1
        self._tx_col = 2
        try:
            r = subprocess.run(
                ["nvidia-smi", "dmon", "-s", "t", "-d", "1", "-c", "1"],
                capture_output=True, text=True, timeout=5)
            self.available = (r.returncode == 0)
        except Exception:
            self.available = False

    def start(self):
        if not self.available:
            return
        self.rx_kb = []
        self.tx_kb = []
        self._active = True
        self._thread = threading.Thread(target=self._read, daemon=True)
        self._thread.start()

    def stop(self):
        if not self.available:
            return
        self._active = False
        if self._proc:
            try:
                self._proc.terminate()
                self._proc.wait(timeout=2)
            except Exception:
                pass
        if self._thread:
            self._thread.join(timeout=3)

    def _read(self):
        try:
            cmd = ["nvidia-smi", "dmon", "-s", "t", "-d", "1",
                   "-i", str(LOCAL_RANK)]
            self._proc = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                text=True, bufsize=1)
            for line in self._proc.stdout:
                if not self._active:
                    break
                line = line.strip()
                if not line:
                    continue
                if line.startswith("#"):
                    cols = line.lstrip("# ").split()
                    if "rxpci" in cols:
                        self._rx_col = cols.index("rxpci")
                    if "txpci" in cols:
                        self._tx_col = cols.index("txpci")
                    continue
                parts = line.split()
                try:
                    self.rx_kb.append(float(parts[self._rx_col]))
                    self.tx_kb.append(float(parts[self._tx_col]))
                except (ValueError, IndexError):
                    pass
        except Exception:
            self.available = False

    def h2d_avg_mb(self): return statistics.mean(self.rx_kb) / 1024 if self.rx_kb else 0.0
    def h2d_peak_mb(self): return max(self.rx_kb) / 1024 if self.rx_kb else 0.0
    def d2h_avg_mb(self): return statistics.mean(self.tx_kb) / 1024 if self.tx_kb else 0.0
    def n_samples(self): return len(self.rx_kb)
def gather_and_print(label, stats: dict):
    local_t = torch.tensor([
        stats["ms_per_step"],
        float(stats["steps"]),
        stats["h2d_avg"],
        stats["h2d_peak"],
        stats["d2h_avg"],
        float(stats["n_samples"]),
        float(stats["pcie_available"]),
    ], dtype=torch.float64, device=DEVICE)
    if IS_DIST:
        all_tensors = [torch.zeros_like(local_t) for _ in range(WORLD_SIZE)]
        dist.all_gather(all_tensors, local_t)
    else:
        all_tensors = [local_t]
    if RANK == 0:
        print(f"\n {label}")
        for r, t in enumerate(all_tensors):
            ms = t[0].item()
            steps = int(t[1].item())
            h2d_avg = t[2].item()
            h2d_peak = t[3].item()
            n_samp = int(t[5].item())
            pcie_ok = bool(t[6].item())
            pcie_str = (f" H2D avg={h2d_avg:.1f} MB/s peak={h2d_peak:.1f} MB/s ({n_samp} samples)"
                        if pcie_ok else " PCIe: dmon unavailable")
            print(f" rank {r}: {ms:8.1f} ms/step"
                  f" ({steps} steps in {MEASURE_SECS}s)"
                  f" {pcie_str}")

def make_stats(ms_per_step, steps, poller):
    return {
        "ms_per_step": ms_per_step,
        "steps": steps,
        "h2d_avg": poller.h2d_avg_mb(),
        "h2d_peak": poller.h2d_peak_mb(),
        "d2h_avg": poller.d2h_avg_mb(),
        "n_samples": poller.n_samples(),
        "pcie_available": poller.available,
    }
def barrier():
    if IS_DIST:
        dist.barrier()

def compute_distances(inputs, codebook):
    return (
        torch.sum(inputs ** 2, dim=1, keepdim=True)
        - 2 * torch.matmul(inputs, codebook.T)
        + torch.sum(codebook ** 2, dim=1)
    )

def balanced_assign(inputs, codebook):
    N = inputs.shape[0]
    w = max(1, N // codebook.shape[0])
    d = compute_distances(inputs, codebook)
    unassigned = torch.ones(N, dtype=torch.bool, device=inputs.device)
    assignments = torch.full((N,), -1, dtype=torch.long, device=inputs.device)
    for k in range(codebook.shape[0]):
        col = d[:, k].clone()
        col[~unassigned] = float("inf")
        idx = torch.argsort(col)
        cand = idx[:w]
        cand = cand[col[cand] != float("inf")]
        assignments[cand] = k
        unassigned[cand] = False
    return assignments

def argmin_assign(inputs, codebook):
    return torch.argmin(compute_distances(inputs, codebook), dim=1)
def measure_pattern1(inputs, codebooks, assign_fn):
    for _ in range(2):  # warmup
        r = inputs
        for cb in codebooks:
            r = r - cb[assign_fn(r, cb)]
    torch.cuda.synchronize()
    barrier()
    poller = PCIePoller()
    poller.start()
    steps = 0
    t0 = time.perf_counter()
    while time.perf_counter() - t0 < MEASURE_SECS:
        r = inputs
        for cb in codebooks:
            r = r - cb[assign_fn(r, cb)]
        torch.cuda.synchronize()
        steps += 1
    poller.stop()
    elapsed = time.perf_counter() - t0
    return make_stats(elapsed / steps * 1000, steps, poller)

def measure_allreduce():
    batch_counts = torch.ones(K, dtype=torch.float32, device=DEVICE)
    batch_sums = torch.ones(K, D, dtype=torch.float32, device=DEVICE)
    stop_flag = torch.zeros(1, dtype=torch.float32, device=DEVICE)
    CHECK_EVERY = 100
    for _ in range(10):  # warmup
        for _ in range(NUM_STAGES):
            dist.all_reduce(batch_counts)
            dist.all_reduce(batch_sums)
    torch.cuda.synchronize()
    barrier()
    poller = PCIePoller()
    poller.start()
    steps = 0
    t0 = time.perf_counter()
    while True:
        for _ in range(CHECK_EVERY):
            for _ in range(NUM_STAGES):
                dist.all_reduce(batch_counts)
                dist.all_reduce(batch_sums)
            torch.cuda.synchronize()
            steps += 1
        stop_flag.fill_(1.0 if time.perf_counter() - t0 >= MEASURE_SECS else 0.0)
        dist.all_reduce(stop_flag, op=dist.ReduceOp.MAX)
        if stop_flag.item() > 0.5:
            break
    poller.stop()
    elapsed = time.perf_counter() - t0
    return make_stats(elapsed / steps * 1000, steps, poller)
def measure_combined(inputs, codebooks):
    batch_counts = torch.ones(K, dtype=torch.float32, device=DEVICE)
    batch_sums = torch.ones(K, D, dtype=torch.float32, device=DEVICE)
    stop_flag = torch.zeros(1, dtype=torch.float32, device=DEVICE)
    for _ in range(2):  # warmup
        r = inputs
        for cb in codebooks:
            r = r - cb[balanced_assign(r, cb)]
        if IS_DIST:
            for _ in range(NUM_STAGES):
                dist.all_reduce(batch_counts)
                dist.all_reduce(batch_sums)
    torch.cuda.synchronize()
    barrier()
    poller = PCIePoller()
    poller.start()
    steps = 0
    t0 = time.perf_counter()
    while True:
        r = inputs
        for cb in codebooks:
            r = r - cb[balanced_assign(r, cb)]
        if IS_DIST:
            for _ in range(NUM_STAGES):
                dist.all_reduce(batch_counts)
                dist.all_reduce(batch_sums)
        torch.cuda.synchronize()
        steps += 1
        if IS_DIST:
            stop_flag.fill_(1.0 if time.perf_counter() - t0 >= MEASURE_SECS else 0.0)
            dist.all_reduce(stop_flag, op=dist.ReduceOp.MAX)
            if stop_flag.item() > 0.5:
                break
        elif time.perf_counter() - t0 >= MEASURE_SECS:
            break
    poller.stop()
    elapsed = time.perf_counter() - t0
    return make_stats(elapsed / steps * 1000, steps, poller)
def main():
    if not torch.cuda.is_available():
        print("ERROR: CUDA not available.")
        return
    torch.manual_seed(42 + RANK)
    probe = PCIePoller()
    if RANK == 0:
        print(f"\n{'='*65}")
        print(f" HAMi-core Slowdown Test")
        print(f"{'='*65}")
        print(f" GPUs: {WORLD_SIZE}")
        print(f" Device rank 0: {torch.cuda.get_device_name(0)}")
        print(f" Distributed: {'yes' if IS_DIST else 'no'}")
        print(f" K={K} D={D} batch={BATCH_SIZE} stages={NUM_STAGES}")
        print(f" Each pattern runs {MEASURE_SECS}s per rank")
    inputs = torch.randn(BATCH_SIZE, D, device=DEVICE)
    codebooks = [torch.randn(K, D, device=DEVICE) for _ in range(NUM_STAGES)]
    balanced_kernels = (7 + 6 * K) * NUM_STAGES
    argmin_kernels = (7 + 1) * NUM_STAGES
    allreduce_calls = 2 * NUM_STAGES
    if RANK == 0:
        print(f"\n Pattern 1: balanced_assign vs argmin")
    stats_argmin = measure_pattern1(inputs, codebooks, argmin_assign)
    gather_and_print(f"argmin x {NUM_STAGES} stages ({argmin_kernels} kernels/step)", stats_argmin)
    stats_balanced = measure_pattern1(inputs, codebooks, balanced_assign)
    gather_and_print(f"balanced x {NUM_STAGES} stages ({balanced_kernels:,} kernels/step)", stats_balanced)
    if IS_DIST:
        if RANK == 0:
            print(f"\n Pattern 2: dist.all_reduce")
        stats_ar = measure_allreduce()
        gather_and_print(f"all_reduce x {allreduce_calls} calls/step", stats_ar)
    else:
        if RANK == 0:
            print(f"\n Pattern 2: skipped (single GPU)")
    if RANK == 0:
        print(f"\n Pattern 3: balanced + all_reduce combined")
    stats_combined = measure_combined(inputs, codebooks)
    gather_and_print("balanced + all_reduce (combined step)", stats_combined)
    if RANK == 0:
        print(f"\n{'='*65}")
        print(f" Summary (rank 0)")
        print(f"{'='*65}")
        print(f" argmin: {stats_argmin['ms_per_step']:8.1f} ms/step ({argmin_kernels} kernels)")
        print(f" balanced: {stats_balanced['ms_per_step']:8.1f} ms/step ({balanced_kernels:,} kernels)")
        if IS_DIST:
            print(f" all_reduce:{stats_ar['ms_per_step']:8.3f} ms/step ({allreduce_calls} calls)")
        print(f" combined: {stats_combined['ms_per_step']:8.1f} ms/step")
        print(f"{'─'*65}")
        print(f" Compare with and without LD_PRELOAD to isolate HAMi overhead.")
        print(f"{'='*65}\n")
    if IS_DIST:
        dist.destroy_process_group()
if __name__ == "__main__":
    main()

Overhead Breakdown
What HAMi does on every cuLaunchKernel
Every intercepted kernel launch executes (in src/cuda/memory.c):

ENSURE_RUNNING();      // → wait_status_self(): linear scan through shared memory process slots
pre_launch_kernel();   // → time(NULL) syscall + pthread_mutex_lock/unlock
rate_limiter(...)      // → 4 shared memory reads (sm_limit ×2, util_switch, recent_kernel)
                       //   + 2 ensure_initialized() calls
cuLaunchKernel(...)    // actual kernel dispatch

Additionally, every LOG_* macro in the interception path calls getenv("LIBCUDA_LOG_LEVEL") + atoi().
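To make the LOG macro cost concrete, the cheap alternative is to parse the environment variable once and reuse the result. This is a minimal sketch of that pattern, not HAMi-core's actual code; the name `cached_log_level` is hypothetical.

```c
#include <stdlib.h>

/* Sketch (illustrative names): parse LIBCUDA_LOG_LEVEL once and cache it,
 * instead of calling getenv() + atoi() inside every LOG_* macro expansion. */
static int g_log_level = -1;   /* -1 = not yet read */

int cached_log_level(void) {
    if (g_log_level < 0) {
        /* Only the first call pays the getenv() linear scan of the
         * environment block; every later call is one load + compare. */
        const char *s = getenv("LIBCUDA_LOG_LEVEL");
        int v = s ? atoi(s) : 0;
        g_log_level = v < 0 ? 0 : v;   /* treat negative/garbage as unset */
    }
    return g_log_level;
}
```

If two threads race on the first call they compute the same value, so the race is benign; a stricter variant could use `pthread_once`.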
Cause 1: Per-kernel overhead on balanced_assign (+33%)
The balanced_assign pattern in the test script iterates over K=8192 centroids in a Python loop, each iteration issuing ~6 CUDA kernels. With 3 stages, this produces ~147,000 kernel launches per step. Each launch pays the full HAMi interception cost (~5 µs).
| Pattern | Kernels/step | Without HAMi | With HAMi | Overhead |
|---|---|---|---|---|
| argmin | 24 | 3.1 ms | 3.2 ms | ~0% |
| balanced_assign | 147,000 | 4,357 ms | 5,786 ms | +33% |
Cause 2: CPU cache pressure compounding with all_reduce (+82% combined)
When balanced_assign and dist.all_reduce run in the same step (Pattern 3 in the test), the overhead is super-additive — far worse than the sum of individual overheads:
| Pattern | Without HAMi | With HAMi | Overhead |
|---|---|---|---|
| balanced_assign alone | 4,357 ms | 5,786 ms | +33% |
| all_reduce alone (6 calls) | 0.306 ms | 0.331 ms | +8% |
| Combined step | 4,363 ms | 7,942 ms | +82% |
| Expected (sum of parts) | — | ~5,786 ms | — |
| Unexplained gap | — | 2,156 ms | — |
The 147K rounds of shared memory reads from HAMi's interception layer evict NCCL's ring-buffer and coordination data from CPU cache. When all_reduce runs immediately after, it suffers compounding cache-miss penalties. Additionally, the pthread_mutex_lock/unlock in pre_launch_kernel() on every kernel launch prevents the CUDA driver from pipelining kernel dispatches, causing NCCL to wait longer for the GPU stream to drain before initiating the collective.
Root Cause Analysis
P0 — Critical

- `getenv()`/`atoi()` called on every LOG macro (log_utils.h)
  Every `LOG_DEBUG`, `LOG_INFO`, `LOG_WARN`, and `LOG_MSG` call invokes `getenv("LIBCUDA_LOG_LEVEL")` + `atoi()` on every invocation. `getenv()` does a linear scan of the environment block, and these macros appear in every intercepted CUDA function.
- `wait_status_self()` linear scan (multiprocess_memory_limit.c)
  Called via `ENSURE_RUNNING()` on every kernel launch and memory operation. Scans up to 1024 process slots comparing PIDs, despite the slot pointer already being cached in `region_info.my_slot`.
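The difference between the scan and the cached-pointer lookup can be sketched as follows. The struct and function names here are simplified stand-ins, not HAMi-core's actual definitions.

```c
#include <stddef.h>
#include <sys/types.h>

/* Simplified stand-ins for HAMi-core's shared-region structures
 * (illustrative names and fields only). */
#define MAX_SLOTS 1024

typedef struct { pid_t pid; int status; } proc_slot_t;

typedef struct {
    proc_slot_t slots[MAX_SLOTS];
    proc_slot_t *my_slot;   /* already cached when the process registers */
} region_info_t;

/* What the per-launch path effectively does today: O(n) scan per call. */
int status_by_scan(region_info_t *r, pid_t pid) {
    for (int i = 0; i < MAX_SLOTS; i++)
        if (r->slots[i].pid == pid)
            return r->slots[i].status;
    return -1;
}

/* Proposed: reuse the pointer cached at registration, O(1) per call. */
int status_cached(region_info_t *r) {
    return r->my_slot ? r->my_slot->status : -1;
}
```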
P1 — High Impact

- `pre_launch_kernel()` uses a `time()` syscall + unconditional mutex (multiprocess_memory_limit.c)
  Every `cuLaunchKernel` call invokes `time(NULL)` (a syscall) and acquires a `pthread_mutex_lock`, even though the recording interval is 1 second and kernels fire thousands of times per second. The mutex is taken even when no update is needed.
- `rate_limiter()` redundant shared memory operations (multiprocess_utilization_watcher.c)
  On every kernel launch: 3 shared memory reads + 2 `ensure_initialized()` calls, even when SM limiting is inactive (sm_limit >= 100 or == 0).
- `oom_check()` dead `cuDeviceGetCount` call (allocator.c)
  Calls `cuDeviceGetCount()` on every memory allocation but never uses the result.
Proposed Fix
Each fix is independent and can be tested/applied separately:
| Priority | Fix | What it eliminates |
|---|---|---|
| P0 | Cache `LIBCUDA_LOG_LEVEL` in a global int | `getenv()` + `atoi()` per LOG macro |
| P0 | Use cached `my_slot` pointer in `wait_status_self()` | O(n) linear scan → O(1) |
| P1 | Replace `time()` with `clock_gettime(CLOCK_REALTIME_COARSE)` + double-checked locking | syscall + unconditional mutex per kernel |
| P1 | Cache `sm_limit`/`utilization_switch` in local statics | 3 shared memory reads + 2 `ensure_initialized()` per kernel |
| P1 | Remove dead `cuDeviceGetCount` call | 1 driver API call per allocation |
Results With Fix
| Pattern | No HAMi | HAMi (original) | HAMi (fixed) |
|---|---|---|---|
| argmin (24 kernels) | 3.1 ms | 3.2 ms (+3%) | 3.2 ms (+3%) |
| balanced_assign (147K kernels) | 4,357 ms | 5,786 ms (+33%) | 4,335 ms (~0%) |
| all_reduce only (6 calls) | 0.306 ms | 0.331 ms (+8%) | 0.326 ms (+7%) |
| Combined step | 4,363 ms | 7,942 ms (+82%) | 4,527 ms (+3.8%) |
The fix brings the combined overhead from +82% → +3.8%.
Related
A working implementation of all fixes is available on branch perf/reduce-hijack-overhead.