|
| 1 | +# cuda-kernel-verifier |
| 2 | + |
| 3 | +**Runtime correctness checker for custom CUDA / Triton kernels - ~200 lines of logic.** |
| 4 | + |
| 5 | +Attach a single decorator to any forward function and the library will periodically re-run the same inputs through a known-correct implementation in a background thread, comparing results with `torch.allclose`. The check has no effect on the training graph. It works with raw kernels, Triton ops, `torch.autograd.Function`, or any `nn.Module`, including models and layers compiled with `torch.compile`. The enqueue call is decorated with `@torch.compiler.disable`, so it always produces a clean graph break and does not interfere with compiled regions. |
| 6 | + |
| 7 | +--- |
| 8 | + |
| 9 | +## How it works |
| 10 | + |
| 11 | +``` |
| 12 | +forward(x) ──► kernel result ──► returned to caller immediately |
| 13 | + │ |
| 14 | + ▼ (background thread, non-blocking) |
| 15 | + outlier check |
| 16 | + │ |
| 17 | + ┌──────┴──────┐ |
| 18 | + │ outlier? │ not outlier? |
| 19 | + │ │ |
| 20 | + ▼ ▼ |
| 21 | + enqueue random sample gate |
| 22 | + (execution_sample_probability) |
| 23 | + │ |
| 24 | + ▼ |
| 25 | + ground_truth(x) |
| 26 | + │ |
| 27 | + ▼ |
| 28 | + torch.allclose? |
| 29 | + │ │ |
| 30 | + yes no |
| 31 | + │ │ |
| 32 | + discard failure_callback(...) |
| 33 | +``` |
| 34 | + |
| 35 | +### Sampling |
| 36 | + |
| 37 | +The checker does **not** run the ground truth on every call. That would negate the point of writing a fast kernel. Instead, each call passes through two gates before work is enqueued: |
| 38 | + |
| 39 | +1. **Outlier gate** - if the current input is detected as an outlier (see below), it is enqueued unconditionally, so unusual inputs are never skipped. |
| 40 | +2. **Random gate** - otherwise, the call is enqueued with probability `execution_sample_probability` (default `0.5`). Tune this down for large models where verification overhead matters. |
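The two gates above can be sketched in a few lines of plain Python. This is an illustration of the decision logic only, not the library's code: `should_enqueue` and its parameters are hypothetical names chosen for this example.

```python
import random


def should_enqueue(is_outlier: bool, sample_probability: float, rng: random.Random) -> bool:
    """Hypothetical helper: decide whether a call is sent for verification."""
    if is_outlier:
        # Outlier gate: unusual inputs are always verified.
        return True
    # Random gate: verify with the configured probability.
    return rng.random() < sample_probability


# With p = 0.5, roughly half of the non-outlier calls are verified.
rng = random.Random(0)
decisions = [should_enqueue(False, 0.5, rng) for _ in range(10_000)]
sampled_fraction = sum(decisions) / len(decisions)
```

Lowering `sample_probability` reduces only the random share of the work; calls flagged as outliers are still verified every time.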
| 41 | + |
| 42 | +The comparison itself runs in a single daemon background thread so the main training loop is never blocked. You can adjust the sampling rate at any point during a run with `EquivalenceChecker.set_execution_sample_probability(p)`, or stop verification entirely with `EquivalenceChecker.stop()`. |
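The background-thread pattern described above could look roughly like the following. It is a minimal sketch using only the standard library, with plain integers standing in for tensors and `!=` standing in for the `torch.allclose` comparison; `work_queue`, `failures`, and `_STOP` are names invented for this example and say nothing about the library's internals.

```python
import queue
import threading

work_queue: queue.Queue = queue.Queue()
failures = []      # stands in for invoking failure_callback
_STOP = object()   # sentinel that tells the worker to exit


def worker() -> None:
    """Drain the queue, re-run the reference function, record mismatches."""
    while True:
        item = work_queue.get()
        if item is _STOP:
            break
        kernel_result, reference_fn, args = item
        expected = reference_fn(*args)
        if kernel_result != expected:  # the real checker would use torch.allclose
            failures.append((kernel_result, expected))


# A single daemon thread: it never keeps the process alive on its own.
thread = threading.Thread(target=worker, daemon=True)
thread.start()

# The caller only pays for a queue.put; the comparison happens off-thread.
work_queue.put((6, sum, ([1, 2, 3],)))  # matches the reference
work_queue.put((7, sum, ([1, 2, 3],)))  # deliberately wrong result
work_queue.put(_STOP)
thread.join()
```

After `join()` returns, `failures` holds the single mismatching pair `(7, 6)`.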
| 43 | + |
| 44 | +### Outlier detection |
| 45 | + |
| 46 | +`ExponentialRunningCentroidExecutionOutlierDetector` tracks the distribution of activations seen so far and flags batches that look statistically different from the norm. |
| 47 | + |
| 48 | +**Algorithm:** |
| 49 | + |
| 50 | +1. Maintain a **running centroid** via exponential moving average: |
| 51 | + `centroid ← α · mean(batch) + (1 − α) · centroid` |
| 52 | + Default `α = 0.01` (slow drift, stable reference). |
| 53 | + |
| 54 | +2. Compute the **L2 distance** of each sample in the batch from the centroid. |
| 55 | + |
| 56 | +3. Append distances to a rolling window of up to `max_distances` values (default 10 000). |
| 57 | + |
| 58 | +4. A batch is an **outlier** when: |
| 59 | + `mean(distances) / quantile(all_distances, p) ≥ outlier_threshold` |
| 60 | + Default `p = 0.95`, `outlier_threshold = 0.8`. |
| 61 | + |
| 62 | +5. The **first batch is always treated as an outlier** so the centroid can be seeded before any comparison. |
| 63 | + |
| 64 | +This means the verifier is biased toward checking inputs that are unusual (the cases most likely to expose a kernel bug) while randomly sampling the rest. |
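To make the five steps concrete, here is a small pure-Python re-implementation of the idea. It is a sketch under several assumptions and not the library's code: the class name `CentroidOutlierSketch` is hypothetical, samples are plain lists of floats rather than tensors, the centroid is updated before distances are measured, the current batch's distances are added to the window before the quantile is taken, and a nearest-rank quantile replaces an interpolating one.

```python
import math
from collections import deque


class CentroidOutlierSketch:
    """Illustrative version of the algorithm above (not the library's class)."""

    def __init__(self, percentile=0.95, max_distances=10_000,
                 alpha=1e-2, outlier_threshold=0.8):
        self.percentile = percentile
        self.alpha = alpha
        self.outlier_threshold = outlier_threshold
        self.centroid = None
        self.distances = deque(maxlen=max_distances)  # rolling window (step 3)

    def is_outlier(self, batch):
        dim = len(batch[0])
        batch_mean = [sum(row[i] for row in batch) / len(batch) for i in range(dim)]
        first = self.centroid is None
        if first:
            self.centroid = batch_mean  # step 5: seed the centroid
        else:
            # Step 1: exponential moving average of the batch mean.
            self.centroid = [self.alpha * m + (1 - self.alpha) * c
                             for m, c in zip(batch_mean, self.centroid)]
        # Step 2: L2 distance of each sample from the centroid.
        dists = [math.dist(row, self.centroid) for row in batch]
        self.distances.extend(dists)
        if first:
            return True
        # Step 4: compare the batch's mean distance with the historical quantile.
        ordered = sorted(self.distances)
        idx = min(len(ordered) - 1, int(self.percentile * len(ordered)))
        scale = ordered[idx]
        if scale == 0:
            return False
        return (sum(dists) / len(dists)) / scale >= self.outlier_threshold


detector = CentroidOutlierSketch()
typical = [[-1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [1.0]]
far = [[10.0]] * 8
results = [detector.is_outlier(typical),  # first batch: always an outlier
           detector.is_outlier(typical),  # same distribution again
           detector.is_outlier(far)]      # every sample far from the centroid
# results == [True, False, True]
```

In the second call the mean distance (0.25) is well below the 95th-percentile distance (1.0), so the ratio stays under the `0.8` threshold; in the third call every sample sits about 9.9 away from the centroid and the ratio reaches 1.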
| 65 | + |
| 66 | +--- |
| 67 | + |
| 68 | +## Installation |
| 69 | + |
| 70 | +**Requires CUDA.** Install a CUDA build of PyTorch first, then the package: |
| 71 | + |
| 72 | +```bash |
| 73 | +pip install torch --index-url https://download.pytorch.org/whl/cu126 |
| 74 | +pip install cuda-kernel-verifier |
| 75 | +``` |
| 76 | + |
| 77 | +--- |
| 78 | + |
| 79 | +## Quick start |
| 80 | + |
| 81 | +```python |
| 82 | +import torch |
| 83 | +from cuda_kernel_verifier import equivalent, EquivalenceChecker, FailureCallbackArgs |
| 84 | + |
| 85 | +def ground_truth(x: torch.Tensor) -> torch.Tensor: |
| 86 | + return x.sum(dim=1) |
| 87 | + |
| 88 | +def on_mismatch(args: FailureCallbackArgs) -> None: |
| 89 | + diff = (args.original_result - args.ground_truth_result).abs().max().item() |
| 90 | + raise AssertionError(f"Kernel diverged! max abs diff = {diff:.6f}") |
| 91 | + |
| 92 | +@equivalent(ground_truth, on_mismatch, rtol=1e-1, atol=1e-6) |
| 93 | +def my_fast_row_sum(x: torch.Tensor) -> torch.Tensor: |
| 94 | +    return my_cuda_row_sum_kernel(x)  # placeholder for your own custom kernel |
| 95 | + |
| 96 | +EquivalenceChecker.start(execution_sample_probability=0.5) |
| 97 | + |
| 98 | +result = my_fast_row_sum(torch.randn(128, 512, device="cuda")) |
| 99 | + |
| 100 | +EquivalenceChecker.stop() |
| 101 | +``` |
| 102 | + |
| 103 | +### Attaching to `torch.autograd.Function` |
| 104 | + |
| 105 | +```python |
| 106 | +from torch.autograd import Function |
| 107 | +from cuda_kernel_verifier import equivalent, FailureCallbackArgs |
| 108 | + |
| 109 | +def sum_ground_truth(ctx, x): |
| 110 | + return x.sum(dim=1) |
| 111 | + |
| 112 | +def on_mismatch(args: FailureCallbackArgs) -> None: |
| 113 | + raise AssertionError("kernel diverged!") |
| 114 | + |
| 115 | +class FastRowSum(Function): |
| 116 | + @staticmethod |
| 117 | + @equivalent(sum_ground_truth, on_mismatch, rtol=1e-1, atol=1e-6) |
| 118 | + def forward(ctx, x): |
| 119 | + ctx.save_for_backward(x) |
| 120 | + return my_cuda_kernel(x) |
| 121 | +``` |
| 122 | + |
| 123 | +The decorator wraps the static method, so `ctx` is passed through to the ground truth unchanged. Give the ground truth the same full signature as `forward`, and name the context parameter `_` if you do not use it. |
| 124 | + |
| 125 | +### Custom outlier detector |
| 126 | + |
| 127 | +```python |
| 128 | +from cuda_kernel_verifier import ( |
| 129 | + equivalent, |
| 130 | + ExponentialRunningCentroidExecutionOutlierDetector, |
| 131 | +) |
| 132 | + |
| 133 | +detector = ExponentialRunningCentroidExecutionOutlierDetector( |
| 134 | + percentile=0.99, |
| 135 | + outlier_threshold=0.9, |
| 136 | + exponential_alpha=5e-3, |
| 137 | +) |
| 138 | + |
| 139 | +@equivalent(ground_truth, on_mismatch, outlier_detector=detector) |
| 140 | +def my_kernel(x): |
| 141 | + ... |
| 142 | +``` |
| 143 | + |
| 144 | +--- |
| 145 | + |
| 146 | +## API reference |
| 147 | + |
| 148 | +### `equivalent(ground_truth_function, failure_callback=None, *, rtol=1e-2, atol=1e-8, outlier_detector=None)` |
| 149 | + |
| 150 | +Decorator factory. Returns a decorator that wraps the target function. |
| 151 | + |
| 152 | +| Parameter | Description | |
| 153 | +| ----------------------- | ----------------------------------------------------------------------------------- | |
| 154 | +| `ground_truth_function` | Known-correct implementation with the same signature. | |
| 155 | +| `failure_callback` | Called with `FailureCallbackArgs` on mismatch. Required. | |
| 156 | +| `rtol` | Relative tolerance for `torch.allclose` (default `1e-2`). | |
| 157 | +| `atol` | Absolute tolerance for `torch.allclose` (default `1e-8`). | |
| 158 | +| `outlier_detector` | Outlier strategy. Defaults to `ExponentialRunningCentroidExecutionOutlierDetector`. | |
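For reference, `torch.allclose(input, other, rtol, atol)` passes when every element satisfies `|input - other| <= atol + rtol * |other|`. The pure-Python sketch below reproduces that rule on lists of floats so the effect of the defaults is easy to see. `within_tolerance` is a hypothetical helper, and treating the ground-truth output as `other` is an assumption of this example, not something this README specifies.

```python
def within_tolerance(actual, expected, rtol=1e-2, atol=1e-8):
    """Element-wise |actual - expected| <= atol + rtol * |expected|."""
    return all(abs(a - e) <= atol + rtol * abs(e)
               for a, e in zip(actual, expected))


# 0.5% relative error on a large value passes with rtol=1e-2.
large_ok = within_tolerance([100.5], [100.0])
# 2% relative error does not.
large_bad = within_tolerance([102.0], [100.0])
# Near zero the relative term vanishes, so only atol=1e-8 applies.
near_zero = within_tolerance([1e-6], [0.0])
```

The last case is why kernels whose outputs cluster around zero often need a larger `atol` than the default, as in the Quick start example, which uses `atol=1e-6`.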
| 159 | + |
| 160 | +--- |
| 161 | + |
| 162 | +### `EquivalenceChecker` |
| 163 | + |
| 164 | +Class-level singleton that manages the background thread and queue. |
| 165 | + |
| 166 | +| Method | Description | |
| 167 | +| --------------------------------------------------------------------- | ---------------------------------------------------------- | |
| 168 | +| `start(max_execution_queue_size=0, execution_sample_probability=0.5)` | Start the background thread. Resets all outlier detectors. | |
| 169 | +| `stop()` | Stop the thread and drain the queue. | |
| 170 | +| `is_running()` | Returns `True` if the checker is active. | |
| 171 | +| `set_execution_sample_probability(p)` | Adjust sampling rate at runtime. | |
| 172 | + |
| 173 | +--- |
| 174 | + |
| 175 | +### `ExponentialRunningCentroidExecutionOutlierDetector` |
| 176 | + |
| 177 | +| Parameter | Default | Description | |
| 178 | +| ------------------- | -------- | ---------------------------------------------------------------------- | |
| 179 | +| `percentile` | `0.95` | Quantile used as the distance scale reference. | |
| 180 | +| `max_distances` | `10_000` | Rolling window size for historical distances. | |
| 181 | +| `exponential_alpha` | `1e-2` | EMA factor for the running centroid. | |
| 182 | +| `outlier_threshold` | `0.8` | Fraction of the percentile scale that triggers outlier classification. | |
| 183 | + |
| 184 | +--- |
| 185 | + |
| 186 | +### `FailureCallbackArgs` |
| 187 | + |
| 188 | +Dataclass passed to the failure callback. |
| 189 | + |
| 190 | +| Field | Type | Description | |
| 191 | +| --------------------- | -------------- | ------------------------------------------- | |
| 192 | +| `original_result` | `torch.Tensor` | Output of the kernel under test (detached). | |
| 193 | +| `ground_truth_result` | `torch.Tensor` | Output of the reference function. | |
| 194 | + |
| 195 | +--- |
| 196 | + |
| 197 | +## Full example |
| 198 | + |
| 199 | +See [`examples/mnist_triton.py`](examples/mnist_triton.py) for a complete MNIST training loop using a Triton row-sum kernel validated in real time. |
| 200 | + |
| 201 | +--- |