3131DSv3 suite (Oink vs Quack, multi-shape):
3232 CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --dsv3 \\
3333 --json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json
34+
35+ Quack baseline note:
36+ - Oink exposes an **in-place** fused op (writes `x` and `residual` in-place).
37+ - Quack provides an equivalent fused kernel, but typically returns `out` and
38+ `residual_out` (out-of-place) and does not expose a public "update my input
39+ buffers in-place" API.
40+ - For integration realism (vLLM-style semantics) we default to timing:
41+ Quack fused kernel + 2 explicit copies to apply the in-place updates
42+ so the benchmark covers the full semantic cost.
3443"""
3544
3645from __future__ import annotations
@@ -177,6 +186,7 @@ def bench_one(
177186 warmup_ms : int ,
178187 iters_ms : int ,
179188 verify : bool ,
189+ quack_baseline : str ,
180190) -> Dict [str , Any ]:
181191 device = torch .device ("cuda" )
182192 x = torch .randn ((M , N ), device = device , dtype = dtype )
@@ -212,23 +222,40 @@ def fn():
212222 row .update (stats )
213223
214224 if quack_rmsnorm_fwd_mut is not None :
215- out_q = torch .empty_like (x )
216- res_out_q = torch .empty_like (residual )
225+ x_q = x .clone ()
226+ residual_q = residual .clone ()
227+ out_q = torch .empty_like (x_q )
228+ res_out_q = torch .empty_like (residual_q )
217229
218- def fn_q ():
230+ def fn_q_kernel ():
219231 quack_rmsnorm_fwd_mut (
220- x ,
232+ x_q ,
221233 w ,
222234 out_q ,
223235 None , # bias
224236 None , # rstd
225237 None , # mean
226- residual ,
238+ residual_q ,
227239 res_out_q ,
228240 1e-6 ,
229241 False , # is_layernorm
230242 )
231243
244+ if quack_baseline == "kernel" :
245+ fn_q = fn_q_kernel
246+ elif quack_baseline == "kernel_inplace" :
247+
248+ def fn_q ():
249+ fn_q_kernel ()
250+ # Apply the same in-place semantics as vLLM expects:
251+ # - x is overwritten with y
252+ # - residual is overwritten with z = x + residual
253+ x_q .copy_ (out_q )
254+ residual_q .copy_ (res_out_q )
255+
256+ else :
257+ raise ValueError (f"Unknown quack_baseline: { quack_baseline } " )
258+
232259 ms_q = do_bench_triton (fn_q , warmup_ms = warmup_ms , rep_ms = iters_ms )
233260 gbps_q = bytes_io / (ms_q * 1e-3 ) / 1e9
234261 row .update (
@@ -287,6 +314,18 @@ def main() -> None:
287314 p .add_argument (
288315 "--iters" , type = int , default = 200 , help = "rep_ms for do_bench (default: 200)"
289316 )
317+ p .add_argument (
318+ "--quack-baseline" ,
319+ type = str ,
320+ default = "kernel_inplace" ,
321+ choices = ["kernel" , "kernel_inplace" ],
322+ help = (
323+ "How to time Quack for the in-place fused op.\n "
324+ "- kernel: Quack fused kernel only (preallocated out/residual_out).\n "
325+ "- kernel_inplace: Quack fused kernel + 2 explicit copies to apply "
326+ "in-place semantics (integration-realistic)."
327+ ),
328+ )
290329 p .add_argument ("--skip-verify" , action = "store_true" )
291330 p .add_argument ("--json" , type = str , default = None )
292331 args = p .parse_args ()
@@ -309,6 +348,7 @@ def main() -> None:
309348 warmup_ms = int (args .warmup_ms ),
310349 iters_ms = int (args .iters ),
311350 verify = not bool (args .skip_verify ),
351+ quack_baseline = str (args .quack_baseline ),
312352 )
313353 )
314354
@@ -324,7 +364,10 @@ def main() -> None:
324364 warmup_ms = int (args .warmup_ms ),
325365 rep_ms = int (args .iters ),
326366 method = "triton.testing.do_bench(mean)" ,
327- note = "Oink fused_add_rmsnorm_inplace_ vs Quack quack::_rmsnorm_fwd(residual=..., residual_out=...) when available" ,
367+ note = (
368+ "Oink fused_add_rmsnorm_inplace_ vs Quack baseline "
369+ f"({ args .quack_baseline } ) when available"
370+ ),
328371 ),
329372 )
330373
0 commit comments