Commit cfb340d

Initial commit

File tree

7 files changed: +764 −0 lines changed

.github/workflows/release.yml

Lines changed: 56 additions & 0 deletions

```yaml
name: Release

on:
  workflow_dispatch:
    inputs:
      version:
        description: "Version to release (e.g. 0.2.0)"
        required: true
        type: string

jobs:
  release:
    name: Build and publish to PyPI
    runs-on: ubuntu-latest
    permissions:
      contents: write

    steps:
      - name: Checkout
        uses: actions/checkout@v4.2.2

      - name: Set up Python
        uses: actions/setup-python@v5.4.0
        with:
          python-version: "3.11"

      - name: Install build tools
        run: pip install hatch

      - name: Bump version in pyproject.toml
        run: |
          sed -i "s/^version = .*/version = \"${{ github.event.inputs.version }}\"/" pyproject.toml
          echo "Releasing version ${{ github.event.inputs.version }}"

      - name: Build distributions
        run: hatch build

      - name: Create GitHub release and tag
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git add pyproject.toml
          git diff --cached --quiet || git commit -m "chore: release v${{ github.event.inputs.version }}"
          git tag "v${{ github.event.inputs.version }}"
          git push origin HEAD --tags
          gh release create "v${{ github.event.inputs.version }}" dist/* \
            --title "v${{ github.event.inputs.version }}" \
            --generate-notes

      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.PYPI_API_TOKEN }}
          verbose: true
```

LICENSE

Lines changed: 21 additions & 0 deletions

MIT License

Copyright (c) 2026 Marian-Sergiu Nistor

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

Lines changed: 201 additions & 0 deletions

# cuda-kernel-verifier

**Runtime correctness checker for custom CUDA / Triton kernels - ~200 lines of logic.**

Attach a single decorator to any forward function and the library will periodically re-run the same inputs through a known-correct implementation in a background thread, comparing results with `torch.allclose`. Zero impact on the training graph. Works with raw kernels, Triton ops, `torch.autograd.Function`, or any `nn.Module`, including models and layers compiled with `torch.compile`. The enqueue call is decorated with `@torch.compiler.disable`, so it is always a clean graph break with no interference with compiled regions.

---

## How it works

```
forward(x) ──► kernel result ──► returned to caller immediately
      │
      ▼ (background thread, non-blocking)
 outlier check
      │
┌─────┴──────┐
│ outlier?   │ not outlier?
▼            ▼
enqueue      random sample gate
             (execution_sample_probability)
      │
      ▼
ground_truth(x)
      │
      ▼
torch.allclose?
   │        │
  yes       no
   │        │
discard   failure_callback(...)
```
### Sampling

The checker does **not** run the ground truth on every call. That would negate the point of writing a fast kernel. Instead, each call passes through two gates before work is enqueued:

1. **Outlier gate** - if the current input is detected as an outlier (see below), it is enqueued unconditionally, so unusual inputs are never skipped.
2. **Random gate** - otherwise, the call is enqueued with probability `execution_sample_probability` (default `0.5`). Tune this down for large models where verification overhead matters.

The comparison itself runs in a single daemon background thread, so the main training loop is never blocked. You can adjust the sampling rate at any point during a run with `EquivalenceChecker.set_execution_sample_probability(p)`, or stop verification entirely with `EquivalenceChecker.stop()`.
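The two gates boil down to a few lines. Here is a sketch in plain Python; `should_enqueue` is a hypothetical name for illustration, not part of the library API:

```python
import random

def should_enqueue(is_outlier: bool, sample_probability: float,
                   rng: random.Random) -> bool:
    """Two-gate decision: outliers always pass; the rest are sampled."""
    if is_outlier:
        return True  # gate 1: unusual inputs are never skipped
    return rng.random() < sample_probability  # gate 2: random sampling

rng = random.Random(0)
assert should_enqueue(True, 0.0, rng)  # outliers bypass the random gate
hits = sum(should_enqueue(False, 0.5, rng) for _ in range(10_000))
print(hits)  # ≈ 5_000: about half of non-outlier calls pass at p = 0.5
```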
### Outlier detection

`ExponentialRunningCentroidExecutionOutlierDetector` tracks the distribution of activations seen so far and flags batches that look statistically different from the norm.

**Algorithm:**

1. Maintain a **running centroid** via exponential moving average:
   `centroid ← α · mean(batch) + (1 − α) · centroid`
   Default `α = 0.01` (slow drift, stable reference).

2. Compute the **L2 distance** of each sample in the batch from the centroid.

3. Append distances to a rolling window of up to `max_distances` values (default 10 000).

4. A batch is an **outlier** when:
   `mean(distances) / quantile(all_distances, p) ≥ outlier_threshold`
   Default `p = 0.95`, `outlier_threshold = 0.8`.

5. The **first batch is always treated as an outlier** so the centroid can be seeded before any comparison.

This means the verifier is biased toward checking inputs that are unusual (the cases most likely to expose a kernel bug) while randomly sampling the rest.
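The five steps above can be sketched in pure Python. This is a toy re-implementation of the described rule with hypothetical names; the real detector operates on torch tensors:

```python
import math
from collections import deque

class CentroidOutlierSketch:
    """Toy version of the detection rule; not the library's implementation."""

    def __init__(self, alpha=0.01, percentile=0.95,
                 outlier_threshold=0.8, max_distances=10_000):
        self.alpha = alpha
        self.percentile = percentile
        self.threshold = outlier_threshold
        self.centroid = None
        self.distances = deque(maxlen=max_distances)  # step 3: rolling window

    def update(self, batch):
        """batch: list of equal-length feature vectors; returns outlier flag."""
        dim = len(batch[0])
        batch_mean = [sum(v[i] for v in batch) / len(batch) for i in range(dim)]
        if self.centroid is None:        # step 5: first batch seeds the
            self.centroid = batch_mean   # centroid and is always an outlier
            return True
        # Step 1: exponential moving average of the running centroid.
        self.centroid = [self.alpha * m + (1 - self.alpha) * c
                         for m, c in zip(batch_mean, self.centroid)]
        # Step 2: L2 distance of every sample from the centroid.
        dists = [math.dist(v, self.centroid) for v in batch]
        self.distances.extend(dists)
        # Step 4: mean batch distance relative to the p-quantile of history.
        hist = sorted(self.distances)
        scale = hist[min(int(self.percentile * len(hist)), len(hist) - 1)]
        mean_dist = sum(dists) / len(dists)
        return scale > 0 and mean_dist / scale >= self.threshold

det = CentroidOutlierSketch()
base = [[0.0, 0.0]] * 9 + [[10.0, 10.0]]
print(det.update(base))                # True: first batch seeds the centroid
print(det.update(base))                # False: typical batch, below threshold
print(det.update([[1000.0, 1000.0]]))  # True: far from the running centroid
```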
---
## Installation

**Requires CUDA.** Install PyTorch for CUDA first, then the package:

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu126
pip install cuda-kernel-verifier
```

---
## Quick start

```python
import torch
from cuda_kernel_verifier import equivalent, EquivalenceChecker, FailureCallbackArgs

def ground_truth(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=1)

def on_mismatch(args: FailureCallbackArgs) -> None:
    diff = (args.original_result - args.ground_truth_result).abs().max().item()
    raise AssertionError(f"Kernel diverged! max abs diff = {diff:.6f}")

@equivalent(ground_truth, on_mismatch, rtol=1e-1, atol=1e-6)
def my_fast_row_sum(x: torch.Tensor) -> torch.Tensor:
    return my_cuda_row_sum_kernel(x)

EquivalenceChecker.start(execution_sample_probability=0.5)

result = my_fast_row_sum(torch.randn(128, 512, device="cuda"))

EquivalenceChecker.stop()
```
### Attaching to `torch.autograd.Function`

```python
from torch.autograd import Function
from cuda_kernel_verifier import equivalent, FailureCallbackArgs

def sum_ground_truth(ctx, x):
    return x.sum(dim=1)

def on_mismatch(args: FailureCallbackArgs) -> None:
    raise AssertionError("kernel diverged!")

class FastRowSum(Function):
    @staticmethod
    @equivalent(sum_ground_truth, on_mismatch, rtol=1e-1, atol=1e-6)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return my_cuda_kernel(x)
```

The decorator wraps the static method, so `ctx` is passed through transparently. Just mirror the full signature in the ground truth and ignore `ctx` with `_` if needed.
### Custom outlier detector

```python
from cuda_kernel_verifier import (
    equivalent,
    ExponentialRunningCentroidExecutionOutlierDetector,
)

detector = ExponentialRunningCentroidExecutionOutlierDetector(
    percentile=0.99,
    outlier_threshold=0.9,
    exponential_alpha=5e-3,
)

@equivalent(ground_truth, outlier_detector=detector)
def my_kernel(x):
    ...
```

---
## API reference

### `equivalent(ground_truth_function, failure_callback=None, *, rtol=1e-2, atol=1e-8, outlier_detector=None)`

Decorator factory. Returns a decorator that wraps the target function.

| Parameter | Description |
| --- | --- |
| `ground_truth_function` | Known-correct implementation with the same signature. |
| `failure_callback` | Called with `FailureCallbackArgs` on mismatch. Optional (default `None`). |
| `rtol` | Relative tolerance for `torch.allclose` (default `1e-2`). |
| `atol` | Absolute tolerance for `torch.allclose` (default `1e-8`). |
| `outlier_detector` | Outlier strategy. Defaults to `ExponentialRunningCentroidExecutionOutlierDetector`. |

---
### `EquivalenceChecker`

Class-level singleton that manages the background thread and queue.

| Method | Description |
| --- | --- |
| `start(max_execution_queue_size=0, execution_sample_probability=0.5)` | Start the background thread. Resets all outlier detectors. |
| `stop()` | Stop the thread and drain the queue. |
| `is_running()` | Returns `True` if the checker is active. |
| `set_execution_sample_probability(p)` | Adjust sampling rate at runtime. |
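The single-worker queue pattern behind this table can be sketched with the standard library. Names here are hypothetical stand-ins, and the real checker compares tensors with `torch.allclose` rather than `!=`:

```python
import queue
import threading

work = queue.Queue()  # stands in for the checker's execution queue
mismatches = []       # stands in for failure_callback invocations

def worker():
    while True:
        item = work.get()
        if item is None:             # sentinel enqueued by "stop()"
            return
        fast_out, ground_truth, args = item
        expected = ground_truth(*args)
        if fast_out != expected:     # real checker: torch.allclose
            mismatches.append((fast_out, expected))

thread = threading.Thread(target=worker, daemon=True)
thread.start()                       # "start()": one daemon thread

# The "training loop" enqueues work and continues without blocking.
work.put((4, lambda x: x * x, (2,)))  # correct result, discarded
work.put((5, lambda x: x * x, (2,)))  # wrong result, recorded
work.put(None)                        # "stop()": drain and shut down
thread.join()
print(mismatches)  # [(5, 4)]
```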
---
### `ExponentialRunningCentroidExecutionOutlierDetector`

| Parameter | Default | Description |
| --- | --- | --- |
| `percentile` | `0.95` | Quantile used as the distance scale reference. |
| `max_distances` | `10_000` | Rolling window size for historical distances. |
| `exponential_alpha` | `1e-2` | EMA factor for the running centroid. |
| `outlier_threshold` | `0.8` | Fraction of the percentile scale that triggers outlier classification. |

---

### `FailureCallbackArgs`

Dataclass passed to the failure callback.

| Field | Type | Description |
| --- | --- | --- |
| `original_result` | `torch.Tensor` | Output of the kernel under test (detached). |
| `ground_truth_result` | `torch.Tensor` | Output of the reference function. |

---

## Full example

See [`examples/mnist_triton.py`](examples/mnist_triton.py) for a complete MNIST training loop using a Triton row-sum kernel validated in real time.

---
