Skip to content

Commit 5d195d6

Browse files
committed
update
1 parent 7e818ee commit 5d195d6

22 files changed

+7996
-4432
lines changed

oink/benchmarks/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsn
9696
--json /tmp/fused_add_rmsnorm_sm100_bf16.json
9797
```
9898

99+
Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x` and `residual`).
100+
Quack’s fused kernel produces `out` and `residual_out` out-of-place, so by default the benchmark
101+
times `quack::_rmsnorm_fwd` **plus** two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`)
102+
to match the in-place semantics (integration-realistic). Use `--quack-baseline kernel` to time only
103+
the Quack fused kernel with preallocated outputs.
104+
99105
### RMSNorm backward
100106

101107
```bash

oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@
3131
DSv3 suite (Oink vs Quack, multi-shape):
3232
CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --dsv3 \\
3333
--json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json
34+
35+
Quack baseline note:
36+
- Oink exposes an **in-place** fused op (writes `x` and `residual` in-place).
37+
- Quack provides an equivalent fused kernel, but typically returns `out` and
38+
`residual_out` (out-of-place) and does not expose a public "update my input
39+
buffers in-place" API.
40+
- For integration realism (vLLM-style semantics) we default to timing:
41+
Quack fused kernel + 2 explicit copies to apply the in-place updates
42+
so the benchmark covers the full semantic cost.
3443
"""
3544

3645
from __future__ import annotations
@@ -177,6 +186,7 @@ def bench_one(
177186
warmup_ms: int,
178187
iters_ms: int,
179188
verify: bool,
189+
quack_baseline: str,
180190
) -> Dict[str, Any]:
181191
device = torch.device("cuda")
182192
x = torch.randn((M, N), device=device, dtype=dtype)
@@ -212,23 +222,40 @@ def fn():
212222
row.update(stats)
213223

214224
if quack_rmsnorm_fwd_mut is not None:
215-
out_q = torch.empty_like(x)
216-
res_out_q = torch.empty_like(residual)
225+
x_q = x.clone()
226+
residual_q = residual.clone()
227+
out_q = torch.empty_like(x_q)
228+
res_out_q = torch.empty_like(residual_q)
217229

218-
def fn_q():
230+
def fn_q_kernel():
219231
quack_rmsnorm_fwd_mut(
220-
x,
232+
x_q,
221233
w,
222234
out_q,
223235
None, # bias
224236
None, # rstd
225237
None, # mean
226-
residual,
238+
residual_q,
227239
res_out_q,
228240
1e-6,
229241
False, # is_layernorm
230242
)
231243

244+
if quack_baseline == "kernel":
245+
fn_q = fn_q_kernel
246+
elif quack_baseline == "kernel_inplace":
247+
248+
def fn_q():
249+
fn_q_kernel()
250+
# Apply the same in-place semantics as vLLM expects:
251+
# - x is overwritten with y
252+
# - residual is overwritten with z = x + residual
253+
x_q.copy_(out_q)
254+
residual_q.copy_(res_out_q)
255+
256+
else:
257+
raise ValueError(f"Unknown quack_baseline: {quack_baseline}")
258+
232259
ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms)
233260
gbps_q = bytes_io / (ms_q * 1e-3) / 1e9
234261
row.update(
@@ -287,6 +314,18 @@ def main() -> None:
287314
p.add_argument(
288315
"--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)"
289316
)
317+
p.add_argument(
318+
"--quack-baseline",
319+
type=str,
320+
default="kernel_inplace",
321+
choices=["kernel", "kernel_inplace"],
322+
help=(
323+
"How to time Quack for the in-place fused op.\n"
324+
"- kernel: Quack fused kernel only (preallocated out/residual_out).\n"
325+
"- kernel_inplace: Quack fused kernel + 2 explicit copies to apply "
326+
"in-place semantics (integration-realistic)."
327+
),
328+
)
290329
p.add_argument("--skip-verify", action="store_true")
291330
p.add_argument("--json", type=str, default=None)
292331
args = p.parse_args()
@@ -309,6 +348,7 @@ def main() -> None:
309348
warmup_ms=int(args.warmup_ms),
310349
iters_ms=int(args.iters),
311350
verify=not bool(args.skip_verify),
351+
quack_baseline=str(args.quack_baseline),
312352
)
313353
)
314354

@@ -324,7 +364,10 @@ def main() -> None:
324364
warmup_ms=int(args.warmup_ms),
325365
rep_ms=int(args.iters),
326366
method="triton.testing.do_bench(mean)",
327-
note="Oink fused_add_rmsnorm_inplace_ vs Quack quack::_rmsnorm_fwd(residual=..., residual_out=...) when available",
367+
note=(
368+
"Oink fused_add_rmsnorm_inplace_ vs Quack baseline "
369+
f"({args.quack_baseline}) when available"
370+
),
328371
),
329372
)
330373

oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import argparse
1818
import csv
1919
import os
20-
import sys
2120
from dataclasses import dataclass
2221
from typing import List, Optional, Tuple
2322

@@ -30,19 +29,17 @@
3029
# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
3130
os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
3231

33-
# Make the in-repo KernelAgent Oink package importable without an editable install.
34-
_HERE = os.path.dirname(os.path.abspath(__file__))
35-
_OINK_SRC = os.path.abspath(os.path.join(_HERE, "..", "src"))
36-
if _OINK_SRC not in sys.path:
37-
sys.path.insert(0, _OINK_SRC)
38-
3932
from bench_utils import ( # noqa: E402
4033
ErrorStatsAccumulator,
4134
collect_device_meta,
35+
ensure_oink_src_on_path,
4236
error_stats_to_row,
4337
iter_row_blocks,
4438
write_json,
4539
)
40+
41+
ensure_oink_src_on_path()
42+
4643
from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402
4744

4845
try:

0 commit comments

Comments
 (0)