CUDA library hijacking introduces ~82% performance overhead on training workloads #164

@nishitnshah

Description

Problem

The HAMi-core CUDA library hijacking layer introduces significant performance overhead on kernel-launch-intensive workloads. Testing with a benchmark that issues ~147,000 CUDA kernel launches per step (balanced_assign with K=8192, 3 stages) shows an ~82% increase in step time compared to running without HAMi.

The overhead is not in the resource limiting logic itself, but in per-call bookkeeping (logging, status checks, shared memory reads, mutex acquisition) that executes on every intercepted CUDA API call regardless of whether limiting is active.


Reproducing the Issue

Save the script below as test_hami_slowdown.py and run:

# Without HAMi (baseline):
python test_hami_slowdown.py

# With HAMi:
LD_PRELOAD=/path/to/libvgpu.so python test_hami_slowdown.py

# Multi-GPU:
torchrun --nproc_per_node=N test_hami_slowdown.py

Compare the ms/step values between runs. The balanced_assign pattern launches ~147K kernels per step and is highly sensitive to per-kernel overhead.

test_hami_slowdown.py:
"""
HAMi-core Slowdown Test — Single and Multi-GPU

Single GPU:  python test_hami_slowdown.py
Multi GPU:   torchrun --nproc_per_node=N test_hami_slowdown.py

Patterns:
  1. balanced_assign vs argmin       — per-GPU kernel launch overhead
  2. dist.all_reduce sustained       — DDP all_reduce overhead (multi-GPU only)
  3. balanced_assign + all_reduce    — combined simulated training step

Usage:
    # Without HAMi:
    torchrun --nproc_per_node=2 test_hami_slowdown.py

    # With HAMi:
    LD_PRELOAD=/path/to/libvgpu.so torchrun --nproc_per_node=2 test_hami_slowdown.py
"""

import os
import time
import threading
import subprocess
import statistics
import torch
import torch.distributed as dist

# ── Distributed setup ─────────────────────────────────────────────────────────
IS_DIST = "LOCAL_RANK" in os.environ

if IS_DIST:
    dist.init_process_group(backend="nccl")
    LOCAL_RANK  = int(os.environ["LOCAL_RANK"])
    WORLD_SIZE  = dist.get_world_size()
    RANK        = dist.get_rank()
else:
    LOCAL_RANK  = 0
    WORLD_SIZE  = 1
    RANK        = 0

torch.cuda.set_device(LOCAL_RANK)
DEVICE = f"cuda:{LOCAL_RANK}"

# ── Config ────────────────────────────────────────────────────────────────────
BATCH_SIZE   = 8192
K            = 8192
D            = 50
NUM_STAGES   = 3
MEASURE_SECS = 30
# ─────────────────────────────────────────────────────────────────────────────


class PCIePoller:
    def __init__(self):
        self.rx_kb = []
        self.tx_kb = []
        self._proc   = None
        self._thread = None
        self._active = False
        self._rx_col = 1
        self._tx_col = 2
        try:
            r = subprocess.run(
                ["nvidia-smi", "dmon", "-s", "t", "-d", "1", "-c", "1"],
                capture_output=True, text=True, timeout=5)
            self.available = (r.returncode == 0)
        except Exception:
            self.available = False

    def start(self):
        if not self.available:
            return
        self.rx_kb = []
        self.tx_kb = []
        self._active = True
        self._thread = threading.Thread(target=self._read, daemon=True)
        self._thread.start()

    def stop(self):
        if not self.available:
            return
        self._active = False
        if self._proc:
            try:
                self._proc.terminate()
                self._proc.wait(timeout=2)
            except Exception:
                pass
        if self._thread:
            self._thread.join(timeout=3)

    def _read(self):
        try:
            cmd = ["nvidia-smi", "dmon", "-s", "t", "-d", "1",
                   "-i", str(LOCAL_RANK)]
            self._proc = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                text=True, bufsize=1)
            for line in self._proc.stdout:
                if not self._active:
                    break
                line = line.strip()
                if not line:
                    continue
                if line.startswith("#"):
                    cols = line.lstrip("# ").split()
                    if "rxpci" in cols:
                        self._rx_col = cols.index("rxpci")
                    if "txpci" in cols:
                        self._tx_col = cols.index("txpci")
                    continue
                parts = line.split()
                try:
                    self.rx_kb.append(float(parts[self._rx_col]))
                    self.tx_kb.append(float(parts[self._tx_col]))
                except (ValueError, IndexError):
                    pass
        except Exception:
            self.available = False

    def h2d_avg_mb(self): return statistics.mean(self.rx_kb) / 1024 if self.rx_kb else 0.0
    def h2d_peak_mb(self): return max(self.rx_kb) / 1024 if self.rx_kb else 0.0
    def d2h_avg_mb(self):  return statistics.mean(self.tx_kb) / 1024 if self.tx_kb else 0.0
    def n_samples(self):   return len(self.rx_kb)


def gather_and_print(label, stats: dict):
    local_t = torch.tensor([
        stats["ms_per_step"],
        float(stats["steps"]),
        stats["h2d_avg"],
        stats["h2d_peak"],
        stats["d2h_avg"],
        float(stats["n_samples"]),
        float(stats["pcie_available"]),
    ], dtype=torch.float64, device=DEVICE)

    if IS_DIST:
        all_tensors = [torch.zeros_like(local_t) for _ in range(WORLD_SIZE)]
        dist.all_gather(all_tensors, local_t)
    else:
        all_tensors = [local_t]

    if RANK == 0:
        print(f"\n  {label}")
        for r, t in enumerate(all_tensors):
            ms       = t[0].item()
            steps    = int(t[1].item())
            h2d_avg  = t[2].item()
            h2d_peak = t[3].item()
            n_samp   = int(t[5].item())
            pcie_ok  = bool(t[6].item())
            pcie_str = (f"  H2D avg={h2d_avg:.1f} MB/s  peak={h2d_peak:.1f} MB/s  ({n_samp} samples)"
                        if pcie_ok else "  PCIe: dmon unavailable")
            print(f"    rank {r}:  {ms:8.1f} ms/step"
                  f"  ({steps} steps in {MEASURE_SECS}s)"
                  f"  {pcie_str}")


def make_stats(ms_per_step, steps, poller):
    return {
        "ms_per_step":    ms_per_step,
        "steps":          steps,
        "h2d_avg":        poller.h2d_avg_mb(),
        "h2d_peak":       poller.h2d_peak_mb(),
        "d2h_avg":        poller.d2h_avg_mb(),
        "n_samples":      poller.n_samples(),
        "pcie_available": poller.available,
    }


def barrier():
    if IS_DIST:
        dist.barrier()


def compute_distances(inputs, codebook):
    return (
          torch.sum(inputs ** 2, dim=1, keepdim=True)
        - 2 * torch.matmul(inputs, codebook.T)
        + torch.sum(codebook ** 2, dim=1)
    )


def balanced_assign(inputs, codebook):
    N = inputs.shape[0]
    w = max(1, N // codebook.shape[0])
    d = compute_distances(inputs, codebook)
    unassigned  = torch.ones(N, dtype=torch.bool,  device=inputs.device)
    assignments = torch.full((N,), -1, dtype=torch.long, device=inputs.device)
    for k in range(codebook.shape[0]):
        col = d[:, k].clone()
        col[~unassigned] = float("inf")
        idx  = torch.argsort(col)
        cand = idx[:w]
        cand = cand[col[cand] != float("inf")]
        assignments[cand] = k
        unassigned[cand]  = False
    return assignments


def argmin_assign(inputs, codebook):
    return torch.argmin(compute_distances(inputs, codebook), dim=1)


def measure_pattern1(inputs, codebooks, assign_fn):
    for _ in range(2):
        r = inputs
        for cb in codebooks:
            r = r - cb[assign_fn(r, cb)]
        torch.cuda.synchronize()

    barrier()
    poller = PCIePoller()
    poller.start()

    steps = 0
    t0 = time.perf_counter()
    while time.perf_counter() - t0 < MEASURE_SECS:
        r = inputs
        for cb in codebooks:
            r = r - cb[assign_fn(r, cb)]
        torch.cuda.synchronize()
        steps += 1

    poller.stop()
    elapsed = time.perf_counter() - t0
    return make_stats(elapsed / steps * 1000, steps, poller)


def measure_allreduce():
    batch_counts = torch.ones(K,    dtype=torch.float32, device=DEVICE)
    batch_sums   = torch.ones(K, D, dtype=torch.float32, device=DEVICE)
    stop_flag    = torch.zeros(1,    dtype=torch.float32, device=DEVICE)
    CHECK_EVERY = 100

    for _ in range(10):
        for _ in range(NUM_STAGES):
            dist.all_reduce(batch_counts)
            dist.all_reduce(batch_sums)
    torch.cuda.synchronize()

    barrier()
    poller = PCIePoller()
    poller.start()

    steps = 0
    t0 = time.perf_counter()
    while True:
        for _ in range(CHECK_EVERY):
            for _ in range(NUM_STAGES):
                dist.all_reduce(batch_counts)
                dist.all_reduce(batch_sums)
            torch.cuda.synchronize()
            steps += 1

        stop_flag.fill_(1.0 if time.perf_counter() - t0 >= MEASURE_SECS else 0.0)
        dist.all_reduce(stop_flag, op=dist.ReduceOp.MAX)
        if stop_flag.item() > 0.5:
            break

    poller.stop()
    elapsed = time.perf_counter() - t0
    return make_stats(elapsed / steps * 1000, steps, poller)


def measure_combined(inputs, codebooks):
    batch_counts = torch.ones(K,    dtype=torch.float32, device=DEVICE)
    batch_sums   = torch.ones(K, D, dtype=torch.float32, device=DEVICE)
    stop_flag    = torch.zeros(1,   dtype=torch.float32, device=DEVICE)

    for _ in range(2):
        r = inputs
        for cb in codebooks:
            r = r - cb[balanced_assign(r, cb)]
        if IS_DIST:
            for _ in range(NUM_STAGES):
                dist.all_reduce(batch_counts)
                dist.all_reduce(batch_sums)
        torch.cuda.synchronize()

    barrier()
    poller = PCIePoller()
    poller.start()

    steps = 0
    t0 = time.perf_counter()
    while True:
        r = inputs
        for cb in codebooks:
            r = r - cb[balanced_assign(r, cb)]
        if IS_DIST:
            for _ in range(NUM_STAGES):
                dist.all_reduce(batch_counts)
                dist.all_reduce(batch_sums)
        torch.cuda.synchronize()
        steps += 1

        if IS_DIST:
            stop_flag.fill_(1.0 if time.perf_counter() - t0 >= MEASURE_SECS else 0.0)
            dist.all_reduce(stop_flag, op=dist.ReduceOp.MAX)
            if stop_flag.item() > 0.5:
                break
        elif time.perf_counter() - t0 >= MEASURE_SECS:
            break

    poller.stop()
    elapsed = time.perf_counter() - t0
    return make_stats(elapsed / steps * 1000, steps, poller)


def main():
    if not torch.cuda.is_available():
        print("ERROR: CUDA not available.")
        return

    torch.manual_seed(42 + RANK)

    if RANK == 0:
        print(f"\n{'='*65}")
        print(f"  HAMi-core Slowdown Test")
        print(f"{'='*65}")
        print(f"  GPUs:          {WORLD_SIZE}")
        print(f"  Device rank 0: {torch.cuda.get_device_name(0)}")
        print(f"  Distributed:   {'yes' if IS_DIST else 'no'}")
        print(f"  K={K}  D={D}  batch={BATCH_SIZE}  stages={NUM_STAGES}")
        print(f"  Each pattern runs {MEASURE_SECS}s per rank")

    inputs    = torch.randn(BATCH_SIZE, D, device=DEVICE)
    codebooks = [torch.randn(K, D, device=DEVICE) for _ in range(NUM_STAGES)]

    balanced_kernels = (7 + 6 * K) * NUM_STAGES
    argmin_kernels   = (7 + 1)     * NUM_STAGES
    allreduce_calls  = 2 * NUM_STAGES

    if RANK == 0:
        print(f"\n  Pattern 1: balanced_assign vs argmin")

    stats_argmin = measure_pattern1(inputs, codebooks, argmin_assign)
    gather_and_print(f"argmin x {NUM_STAGES} stages  ({argmin_kernels} kernels/step)", stats_argmin)

    stats_balanced = measure_pattern1(inputs, codebooks, balanced_assign)
    gather_and_print(f"balanced x {NUM_STAGES} stages  ({balanced_kernels:,} kernels/step)", stats_balanced)

    if IS_DIST:
        if RANK == 0:
            print(f"\n  Pattern 2: dist.all_reduce")
        stats_ar = measure_allreduce()
        gather_and_print(f"all_reduce x {allreduce_calls} calls/step", stats_ar)
    else:
        if RANK == 0:
            print(f"\n  Pattern 2: skipped (single GPU)")

    if RANK == 0:
        print(f"\n  Pattern 3: balanced + all_reduce combined")

    stats_combined = measure_combined(inputs, codebooks)
    gather_and_print("balanced + all_reduce (combined step)", stats_combined)

    if RANK == 0:
        print(f"\n{'='*65}")
        print(f"  Summary (rank 0)")
        print(f"{'='*65}")
        print(f"  argmin:    {stats_argmin['ms_per_step']:8.1f} ms/step  ({argmin_kernels} kernels)")
        print(f"  balanced:  {stats_balanced['ms_per_step']:8.1f} ms/step  ({balanced_kernels:,} kernels)")
        if IS_DIST:
            print(f"  all_reduce:{stats_ar['ms_per_step']:8.3f} ms/step  ({allreduce_calls} calls)")
            print(f"  combined:  {stats_combined['ms_per_step']:8.1f} ms/step")
        print(f"{'─'*65}")
        print(f"  Compare with and without LD_PRELOAD to isolate HAMi overhead.")
        print(f"{'='*65}\n")

    if IS_DIST:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

Overhead Breakdown

What HAMi does on every cuLaunchKernel

Every intercepted kernel launch executes (in src/cuda/memory.c):

ENSURE_RUNNING();      // → wait_status_self(): linear scan through shared memory process slots
pre_launch_kernel();   // → time(NULL) syscall + pthread_mutex_lock/unlock
rate_limiter(...)      // → 4 shared memory reads (sm_limit ×2, util_switch, recent_kernel)
                       //   + 2 ensure_initialized() calls
cuLaunchKernel(...)    // actual kernel dispatch

Additionally, every LOG_* macro in the interception path calls getenv("LIBCUDA_LOG_LEVEL") + atoi().

Cause 1: Per-kernel overhead on balanced_assign (+33%)

The balanced_assign pattern in the test script iterates over K=8192 centroids in a Python loop, each iteration issuing ~6 CUDA kernels. With 3 stages, this produces ~147,000 kernel launches per step. Each launch pays the full HAMi interception cost (~5 µs).

Pattern            Kernels/step   Without HAMi   With HAMi   Overhead
argmin             24             3.1 ms         3.2 ms      +3%
balanced_assign    147,000        4,357 ms       5,786 ms    +33%

Cause 2: CPU cache pressure compounding with all_reduce (+82% combined)

When balanced_assign and dist.all_reduce run in the same step (Pattern 3 in the test), the overhead is super-additive — far worse than the sum of individual overheads:

Pattern                       Without HAMi   With HAMi   Overhead
balanced_assign alone         4,357 ms       5,786 ms    +33%
all_reduce alone (6 calls)    0.306 ms       0.331 ms    +8%
Combined step                 4,363 ms       7,942 ms    +82%
Expected (sum of parts)       -              ~5,786 ms   -
Unexplained gap               -              2,156 ms    -

The 147K rounds of shared memory reads from HAMi's interception layer evict NCCL's ring-buffer and coordination data from CPU cache. When all_reduce runs immediately after, it suffers compounding cache-miss penalties. Additionally, the pthread_mutex_lock/unlock in pre_launch_kernel() on every kernel launch prevents the CUDA driver from pipelining kernel dispatches, causing NCCL to wait longer for the GPU stream to drain before initiating the collective.


Root Cause Analysis

P0 — Critical

  1. getenv()/atoi() called on every LOG macro (log_utils.h)

    • Every LOG_DEBUG, LOG_INFO, LOG_WARN, LOG_MSG calls getenv("LIBCUDA_LOG_LEVEL") + atoi() on every invocation. getenv() does a linear scan of the environment block. These macros appear in every intercepted CUDA function.
  2. wait_status_self() linear scan (multiprocess_memory_limit.c)

    • Called via ENSURE_RUNNING() on every kernel launch and memory operation. Scans up to 1024 process slots comparing PIDs, despite the slot pointer already being cached in region_info.my_slot.

P1 — High Impact

  1. pre_launch_kernel() uses time() syscall + unconditional mutex (multiprocess_memory_limit.c)

    • Every cuLaunchKernel call invokes time(NULL) (a syscall) and acquires pthread_mutex_lock, even though the recording interval is 1 second and kernels fire thousands of times per second. The mutex is taken even when no update is needed.
  2. rate_limiter() redundant shared memory operations (multiprocess_utilization_watcher.c)

    • On every kernel launch: 3 shared memory reads + 2 ensure_initialized() calls, even when SM limiting is inactive (sm_limit >= 100 or == 0).
  3. oom_check() dead cuDeviceGetCount call (allocator.c)

    • Calls cuDeviceGetCount() on every memory allocation but never uses the result.

Proposed Fix

Each fix is independent and can be tested/applied separately:

Priority   Fix                                                                                 What it eliminates
P0         Cache LIBCUDA_LOG_LEVEL in a global int                                             getenv() + atoi() per LOG macro
P0         Use cached my_slot pointer in wait_status_self()                                    O(n) linear scan → O(1)
P1         Replace time() with clock_gettime(CLOCK_REALTIME_COARSE) + double-checked locking   syscall + unconditional mutex per kernel
P1         Cache sm_limit/utilization_switch in local statics                                  3 shared memory reads + 2 ensure_initialized() per kernel
P1         Remove dead cuDeviceGetCount call                                                   1 driver API call per allocation

Results With Fix

Pattern                          No HAMi    HAMi (original)   HAMi (fixed)
argmin (24 kernels)              3.1 ms     3.2 ms (+3%)      3.2 ms (+3%)
balanced_assign (147K kernels)   4,357 ms   5,786 ms (+33%)   4,335 ms (~0%)
all_reduce only (6 calls)        0.306 ms   0.331 ms (+8%)    0.326 ms (+7%)
Combined step                    4,363 ms   7,942 ms (+82%)   4,527 ms (+3.8%)

The fix brings the combined overhead from +82% → +3.8%.


Related

A working implementation of all fixes is available on branch perf/reduce-hijack-overhead.
