[PERFORMANCE] Top‑k windowed mask generation dominates end‑to‑end latency for large sequences #196

@LoserCheems

Description

Performance issue description

When enabling top‑k windowing (key_len > window_size), the mask generation step (topk + scatter) often dominates end‑to‑end latency, outweighing the attention kernel speedup. This is most pronounced for larger Lq/Lk and for bias layouts with more “head-like” planes (e.g., (B,H) vs (B,1)). With small windows, attention gets slightly faster, but the Python-side mask materialization increases latency significantly, making E2E slower.
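
For reference, a minimal sketch of the topk + scatter pattern described above (the helper name and the assumption that the scores live in a dense (B, H, Lq, Lk) bias tensor are mine, not the library's API). The point is that it materializes a full boolean mask on the Python side before the attention kernel ever runs:

```python
import torch

def build_topk_window_mask(bias: torch.Tensor, window_size: int) -> torch.Tensor:
    """Hypothetical sketch: keep the window_size highest-scoring keys per query.

    bias: (B, H, Lq, Lk) attention bias / importance scores.
    Returns a dense (B, H, Lq, Lk) boolean mask, which is what makes this step expensive.
    """
    topk_idx = bias.topk(window_size, dim=-1).indices   # (B, H, Lq, window_size)
    mask = torch.zeros_like(bias, dtype=torch.bool)     # full (B, H, Lq, Lk) allocation
    mask.scatter_(-1, topk_idx, True)                   # scatter the selected key positions
    return mask
```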

Current performance metrics

  • Device: CUDA
  • Dtype: float16
  • B=2, H=8, KVH=4, D=64, Lq=4096, Lk=4096
  • Numbers are per-iteration latency (ms). total = mask + attn.

bias_kind=B,H, mask_kind=None

| case | mask | attn | total |
| --- | --- | --- | --- |
| window=None | 0.006 | 0.826 | 0.832 |
| window=Lk(4096) | 0.002 | 0.860 | 0.862 |
| window=256 | 3.625 | 0.720 | 4.345 |
| window=512 | 4.031 | 0.749 | 4.780 |
| window=1024 | 4.899 | 0.849 | 5.748 |
| window=2048 | 6.861 | 1.002 | 7.863 |

bias_kind=B,H, mask_kind=2D

| case | mask | attn | total |
| --- | --- | --- | --- |
| window=None | 0.003 | 0.331 | 0.334 |
| window=Lk(4096) | 0.004 | 0.329 | 0.333 |
| window=256 | 6.136 | 0.657 | 6.793 |
| window=512 | 6.576 | 0.675 | 7.251 |
| window=1024 | 7.477 | 0.660 | 8.137 |
| window=2048 | 9.475 | 0.656 | 10.131 |

bias_kind=B,H, mask_kind=4D-like-bias

| case | mask | attn | total |
| --- | --- | --- | --- |
| window=None | 0.004 | 0.652 | 0.656 |
| window=Lk(4096) | 0.001 | 0.670 | 0.671 |
| window=256 | 7.202 | 0.670 | 7.872 |
| window=512 | 7.656 | 0.654 | 8.310 |
| window=1024 | 8.533 | 0.656 | 9.189 |
| window=2048 | 10.518 | 0.679 | 11.197 |

bias_kind=B,KVH, mask_kind=None

| case | mask | attn | total |
| --- | --- | --- | --- |
| window=None | 0.004 | 0.508 | 0.512 |
| window=Lk(4096) | 0.002 | 0.512 | 0.514 |
| window=256 | 1.743 | 0.426 | 2.169 |
| window=512 | 1.965 | 0.455 | 2.420 |
| window=1024 | 2.407 | 0.517 | 2.924 |
| window=2048 | 3.402 | 0.625 | 4.027 |

bias_kind=B,KVH, mask_kind=4D-like-bias

| case | mask | attn | total |
| --- | --- | --- | --- |
| window=None | 0.002 | 0.425 | 0.427 |
| window=Lk(4096) | 0.002 | 0.419 | 0.421 |
| window=256 | 3.570 | 0.436 | 4.006 |
| window=512 | 3.781 | 0.423 | 4.204 |
| window=1024 | 4.236 | 0.414 | 4.650 |
| window=2048 | 5.217 | 0.418 | 5.635 |

bias_kind=1,H, mask_kind=None

| case | mask | attn | total |
| --- | --- | --- | --- |
| window=None | 0.022 | 0.516 | 0.538 |
| window=Lk(4096) | 0.001 | 0.505 | 0.506 |
| window=256 | 1.725 | 0.441 | 2.166 |
| window=512 | 1.963 | 0.462 | 2.425 |
| window=1024 | 2.388 | 0.496 | 2.884 |
| window=2048 | 3.406 | 0.594 | 4.000 |

bias_kind=B,1, mask_kind=None

| case | mask | attn | total |
| --- | --- | --- | --- |
| window=None | 0.030 | 0.504 | 0.534 |
| window=Lk(4096) | 0.003 | 0.510 | 0.513 |
| window=256 | 0.410 | 0.307 | 0.717 |
| window=512 | 0.463 | 0.326 | 0.789 |
| window=1024 | 0.605 | 0.384 | 0.989 |
| window=2048 | 0.848 | 0.488 | 1.336 |

Observations

  • The attention kernel time drops modestly for small windows, but mask creation quickly dominates as window_size grows (and as the number of head-like planes increases).
  • B,1 is the cheapest case for mask creation; B,H with a 4D mask is the most expensive (a rough footprint estimate follows this list).
  • Baseline (no windowing) is consistently faster end‑to‑end in these runs.
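
As a rough order-of-magnitude check (assuming the mask is materialized densely per (B, H, Lq, Lk), as in the sketch above), the boolean mask alone is sizable at these shapes, which is consistent with the per-layout mask times:

```python
# Dense mask footprint at the benchmark shapes (assumption: a full bool mask is allocated).
B, H, Lq, Lk = 2, 8, 4096, 4096
full_mask_mib = B * H * Lq * Lk / 2**20   # bool = 1 byte/element -> 256.0 MiB for (B, H, Lq, Lk)
b1_mask_mib = B * 1 * Lq * Lk / 2**20     # broadcastable (B, 1) layout -> 32.0 MiB
print(full_mask_mib, b1_mask_mib)
```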

Expected performance

  • Enabling windowing should reduce end‑to‑end latency when window_size << Lk (e.g., window_size ≤ Lk/16, i.e. ≤ 256 for Lk=4096), or at least not degrade it.
  • Ideally, mask generation should cost no more than the attention-kernel savings it enables. A concrete target: end‑to‑end latency ≤ baseline (no windowing) for window_size ≤ Lk/16, across all bias/mask variants.

Environment information

PyTorch: 2.9.0a0+50eac811a6.nv25.09
CUDA: 13.0
GPU: NVIDIA GeForce RTX 4090

Benchmark code

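The original script is not attached, so the following is only a minimal sketch of how the per-iteration numbers above could be collected (CUDA-event timing; the attention call is plain PyTorch SDPA as a stand-in, not the library's kernel, and build_topk_window_mask is the hypothetical helper sketched earlier):

```python
import torch
import torch.nn.functional as F

def time_ms(fn, iters=50, warmup=10):
    # Mean per-iteration latency in milliseconds, measured with CUDA events.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Shapes match the report above.
B, H, Lq, Lk, D = 2, 8, 4096, 4096, 64
q = torch.randn(B, H, Lq, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, Lk, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, Lk, D, device="cuda", dtype=torch.float16)
bias = torch.randn(B, H, Lq, Lk, device="cuda", dtype=torch.float16)

mask_ms = time_ms(lambda: build_topk_window_mask(bias, 256))        # hypothetical helper from the sketch above
attn_ms = time_ms(lambda: F.scaled_dot_product_attention(q, k, v))  # stand-in attention call
print(f"mask={mask_ms:.3f} ms  attn={attn_ms:.3f} ms  total={mask_ms + attn_ms:.3f} ms")
```
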
Profiling information

No response

System information

No response

Additional context

No response
