Co-authored-by: Ted Zadouri <tz6037@princeton.edu>
Easier to benchmark without having to install FA2
* Tweaks
* better errors
* Switch to new API
* Implement split KV
* Remove modal bench harness
* Fixes
- Create block_sparse_utils.py with SM80/90 block-sparse logic
- Refactor flash_fwd.py to use extracted utilities
- Clean up whitespace in block_sparsity.py

This extracts the block-sparse consumer loop and related utilities from flash_fwd.py into a reusable module for SM80/90 architectures.
* add gqa for sm100 bwd
* remove mha guard for test
* change to cluster size 1
* begin block sparsity computation kernel
* block sparsity computation kernel and benchmark working
* loop range_constexpr
* add fast kernel
* merge fast and regular kernel
* use TensorSSA approach to mask mod
* update with OOB check
* tests and benchmarks for block sparsity working
* remove extraneous files
* Revert mask.py to previous state, removing unintended changes from block sparsity work
* remove flex attn test stub
* add sleeps to benchmark
* correct block sparsity benchmark to use torch.compile
* Restore missing mask definitions and fix benchmark window_size handling
* move benchmarks into new directory
* compute_block_sparsity docstring
* streamline compute block sparsity benchmark script
Credit: Ben Spector
Two fixes:

1. test_mask_mod.py: Update _flash_attn_fwd calls to use tile_mn=(m, n) instead of the removed m_block_size/n_block_size parameters. The API was changed in 99d0148 but the tests were never updated.
2. flash_bwd_sm90.py: Fix kwarg name mismatch in the dQaccum_store_block_sparse_bwd_sm90 call — num_wg_mma → num_mma_warp_groups to match the function signature in block_sparse_utils.py.
…tures and remove from call sites (#2369)

Follow-up to the clean-up at `10bbfd0e246b99feabfe620a355f9213eeb6c9b5`, which switched to `cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)` instead of relying on a concrete CUDA stream at kernel compilation time. However, the tvm_ffi EnvStream feature excludes the `stream` parameter from the TVM FFI function signature, so calling a compiled tvm_ffi function that expects EnvStream with an actual CUDA stream causes a signature mismatch: TVM FFI obtains the stream implicitly via TVMFFIEnvGetStream at runtime, so passing a concrete stream as a positional argument causes a parameter count mismatch (the compiled function expects N args, but N+1 are passed).

This PR:
- Moves stream to the last parameter with default None in all kernel __call__ methods (fwd, bwd, preprocess, postprocess, combine)
- Removes current_stream from runtime call sites in interface.py
- Keeps current_stream at compile sites (moved to the end) so the DSL compiler can generate the EnvStream stub

Did a full-sweep test with `pip install nvidia-cutlass-dsl==4.4.1 apache-tvm-ffi==0.1.9`:

```
FLASH_ATTENTION_FAKE_TENSOR=1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 256 -x tests/cute/test_flash_attn.py
# 58371 passed, 34272 skipped, 306 xfailed, 394005 warnings in 719.14s (0:11:59)
FLASH_ATTENTION_FAKE_TENSOR=0 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 8 -x tests/cute/test_flash_attn.py -k 'not test_flash_attn_kvcache'
# 55875 passed, 33696 skipped, 306 xfailed, 6 warnings in 181.55s (0:03:01)
FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 8 -x tests/cute/ --ignore=tests/cute/test_flash_attn.py
# 24588 passed, 18971 skipped, 78 xfailed, 162951 warnings in 1296.16s (0:21:36)
```

Co-authored-by: Gefei Zuo <gzuo@fb.com>
With the latest CuTeDSL 4.4.2 release, the AOT + tvm_ffi feature got some updates:

1. No longer any need to export to `.o` and then manually link to `.so`; tvm_ffi works with the exported `.o` objects directly.
2. No additional CuTeDSL patches are needed to work around the dynamic-linking symbol resolution issue.

So this diff:

1. Bumps to "nvidia-cutlass-dsl>=4.4.2"
2. Removes the dependency on and usage of setuptools + distutils (i.e., the prior linking workarounds).
3. Exports/loads ".o" instead of ".so"

Did a full-sweep test with `pip install nvidia-cutlass-dsl==4.4.2 apache-tvm-ffi==0.1.9` (and with the recent #2369 patched):

```
FLASH_ATTENTION_FAKE_TENSOR=1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 256 -x tests/cute/test_flash_attn.py
# 58371 passed, 34272 skipped, 306 xfailed, 403941 warnings in 731.89s (0:12:11)
FLASH_ATTENTION_FAKE_TENSOR=0 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 8 -x tests/cute/test_flash_attn.py -k 'not test_flash_attn_kvcache'
# 55875 passed, 33696 skipped, 306 xfailed, 6 warnings in 180.83s (0:03:00)
FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 8 -x tests/cute/ --ignore=tests/cute/test_flash_attn.py
# 24588 passed, 18971 skipped, 78 xfailed, 163546 warnings in 1292.44s (0:21:32)
```
…00 (#2368)

* fix
* fix
* cleaner fix
* kernel fix
* rm interface changes
* fp8 support
* add test
* update test style
* [Fwd,Sm90] Add paged KV attention support

  Implement paged KV cache support for SM90 (Hopper) forward attention, matching SM100's feature set with arbitrary page sizes via two paths:

  - TMA path (page_size == n_block_size): fast, uses TMA with page index indirection from the page table
  - cp.async path (page_size != n_block_size): general, uses PagedKVManager with row-by-row cp.async copies and PipelineAsync barriers

  Key changes:

  - flash_fwd_sm90.py: Add use_tma_KV flag, dual pipeline types (PipelineTmaAsync vs PipelineAsync), separate Q mbarrier for non-TMA path, n_block clamping for causal+pack_gqa edge case
  - paged_kv.py: Add arch parameter (default 100), SM90 V layout handling (not transposed), smem layout flattening via group_modes
  - interface.py: Remove SM90 paged KV assert, pass paged_kv_non_tma, fix hardcoded 128 to tile_n in compile key and SM100 constructor

  Test Plan:

  ```
  CUDA_VISIBLE_DEVICES=0 pytest tests/cute/test_flash_attn.py -k "test_flash_attn_kvcache" -x
  # 2496 passed, 576 skipped, 0 failed
  ```

* Restore accidentally removed comments in flash_fwd_sm90.py

  The paged KV commit (459b94d) accidentally removed several commented-out code blocks and TODO comments during refactoring. Restore them to preserve debug aids and development notes.

* [Paged KV] Remove redundant v_gmem_transposed check in load_KV

  Inside the else branch (arch != 90), v_gmem_transposed is always True since it's defined as (arch != 90). Simplify the condition to just check K_or_V == "V".

* [Fwd,Sm90] Refactor KV loading into load_KV method

  Extract a load_KV method on FlashAttentionForwardSm90 (following flash_fwd_sm100.py's pattern) that encapsulates the TMA vs cp_async dispatch for K/V loads. This eliminates duplicated page_idx branching (8 occurrences) in the TMA path and the repeated 5-line cp_async commit sequence (4 occurrences). The producer loop for non-block-sparsity now uses partial(self.load_KV, ...) for both TMA and cp_async paths, and the non-overlap iteration body is shared between them (gated only by const_expr for load_page_table). Block sparsity still passes raw TMA closures to produce_block_sparse_loads.

* [Fwd,Sm90] Always load Q on its own mbarrier, never piggyback on pipeline_k

  Decouple Q loading from K's pipeline barrier. Previously, when both Q and K/V used TMA, Q was piggybacked onto pipeline_k via extra_tx_count so the consumer waited on a single barrier for both. Now Q always uses its own mbar_ptr_Q regardless of the K/V loading path (TMA or cp.async). This simplifies the code by making Q loading uniform across all paths and removes Q-related parameters (load_Q, use_tma_q, tma_q_bytes, load_q_with_first) from block_sparse_utils. No measurable perf impact.

* address comments
* ruff lint
* limit vec_size to 2 for score mod when not on Sm100
* properly handle arch for vec sizes to check equality
* Support 2CTA for sliding window hdim 192
* Remove local 2CTA restriction in SM100 backward
* Enable SM100 backward local tests for hdim 192
Pull request overview
Adds an initial CuTeDSL-based FlashAttention implementation and supporting utilities/configuration within flash_sparse_attn/ops/cute.
Changes:
- Introduces CuTeDSL kernels/utilities for attention (softmax, paged KV, PackGQA, barriers/pipelines, seqlen helpers).
- Adds developer tooling (benchmarking, config search, logging, ptxas hook, compilation/JIT caching).
- Adds packaging metadata and docs for distributing the CuTeDSL implementation.
Reviewed changes
Copilot reviewed 43 out of 48 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| flash_sparse_attn/ops/cute/testing.py | Adds PyTorch reference helpers (masking/padding/attention) and FakeTensorMode utilities. |
| flash_sparse_attn/ops/cute/softmax.py | Implements CuTeDSL softmax (including SM100 specialization) and score-mod hooks. |
| flash_sparse_attn/ops/cute/sm90_config_search.py | Adds a CLI tool to search feasible SM90 attention configs. |
| flash_sparse_attn/ops/cute/seqlen_info.py | Adds CuTeDSL sequence-length metadata structs and offset helpers. |
| flash_sparse_attn/ops/cute/pyproject.toml | Defines the standalone package metadata/deps for the CuTeDSL implementation. |
| flash_sparse_attn/ops/cute/pipeline.py | Adds pipeline wrappers/mixins with index/phase helpers and elect-one variants. |
| flash_sparse_attn/ops/cute/paged_kv.py | Adds paged KV manager for loading K/V pages into shared memory. |
| flash_sparse_attn/ops/cute/pack_gqa.py | Adds utilities and a kernel helper class for Pack-GQA layout folding/unfolding. |
| flash_sparse_attn/ops/cute/named_barrier.py | Defines named barrier enums for fwd/bwd across architectures. |
| flash_sparse_attn/ops/cute/mma_sm100_desc.py | Ports SM100 UMMA instruction descriptor encoding helpers to Python. |
| flash_sparse_attn/ops/cute/flash_fwd_sm120.py | Adds SM120 forward kernel specialization with SMEM capacity checks. |
| flash_sparse_attn/ops/cute/flash_bwd_sm120.py | Adds SM120 backward kernel specialization with SMEM capacity checks. |
| flash_sparse_attn/ops/cute/flash_bwd_preprocess.py | Adds bwd preprocess kernel (PdPsum/LSE log2/dQaccum clear). |
| flash_sparse_attn/ops/cute/flash_bwd_postprocess.py | Adds bwd postprocess kernel (reduce/scale dQaccum to dQ). |
| flash_sparse_attn/ops/cute/fast_math.py | Adds a small CuTeDSL fast-math helper (clz). |
| flash_sparse_attn/ops/cute/fa_logging.py | Adds unified host/device logging controlled via FA_LOG_LEVEL. |
| flash_sparse_attn/ops/cute/cute_dsl_utils.py | Adds compile hooks, alignment assumptions, and torch→CuTe tensor helpers. |
| flash_sparse_attn/ops/cute/cute_dsl_ptxas.py | Adds an optional system-ptxas hook for CUTLASS DSL PTX compilation. |
| flash_sparse_attn/ops/cute/copy_utils.py | Adds CuTeDSL copy helpers and low-level inline-asm primitives. |
| flash_sparse_attn/ops/cute/compute_block_sparsity.py | Adds a kernel to compute block sparsity indices/counts from a mask_mod. |
| flash_sparse_attn/ops/cute/cache_utils.py | Adds in-memory + optional persistent JIT cache for compiled kernels. |
| flash_sparse_attn/ops/cute/block_sparsity.py | Adds block-sparsity tensor normalization, broadcasting, and conversion helpers. |
| flash_sparse_attn/ops/cute/block_info.py | Adds block-range computations (causal/local/split-kv/new-K) for tiling. |
| flash_sparse_attn/ops/cute/benchmark.py | Adds benchmark wrappers for forward/backward/combined profiling. |
| flash_sparse_attn/ops/cute/bench_utils.py | Adds shared benchmark utilities (ref attention, FLOPS, cuDNN graph helpers). |
| flash_sparse_attn/ops/cute/barrier.py | Adds inline-asm helpers for acquire/release reductions and spin-wait barriers. |
| flash_sparse_attn/ops/cute/ampere_helpers.py | Adds SM80 helper layouts and GEMM loops for tiled MMA. |
| `flash_sparse_attn/ops/cute/__init__.py` | Exposes public APIs and monkey-patches cute.compile for optional SASS dumping. |
| flash_sparse_attn/ops/cute/README.md | Documents installation/usage/dev flow for the CuTeDSL package. |
| flash_sparse_attn/ops/cute/MANIFEST.in | Adds packaging excludes/prunes for build artifacts and egg-info. |
| flash_sparse_attn/ops/cute/LICENSE | Adds BSD-3 license text for this subtree/package. |
| flash_sparse_attn/ops/cute/AUTHORS | Adds a contributor list for this subtree/package. |
| flash_sparse_attn/ops/cute/.flake8 | Adds flake8 configuration for this subtree/package. |
```python
m_block_size: cutlass.Constexpr[int]
head_dim_padded: cutlass.Constexpr[int]
check_hdim_oob: cutlass.Constexpr[bool]
qhead_per_kvhead: cutlass.Constexpr[bool]
```
qhead_per_kvhead is used as an integer divisor/multiplier (e.g., idx // self.qhead_per_kvhead), but it's annotated as cutlass.Constexpr[bool]. If this is treated as a boolean constant, True behaves like 1 and will break Pack-GQA pointer math (e.g., h_idx becomes 0). Change this field to an integer constexpr type (e.g., cutlass.Constexpr[int]) and ensure callers pass an integer replication factor.
```diff
- qhead_per_kvhead: cutlass.Constexpr[bool]
+ qhead_per_kvhead: cutlass.Constexpr[int]
```
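As a plain-Python illustration of the failure mode (the function below is a hypothetical stand-in, not the kernel's actual pointer math): Python's `bool` is a subclass of `int`, so a replication factor collapsed to `True` silently acts as `1` and the GQA head mapping degenerates to the identity.

```python
# Sketch: GQA maps a query head to its KV head by integer division with the
# replication factor (number of query heads sharing one KV head).
def kv_head_index(q_head_idx: int, qhead_per_kvhead: int) -> int:
    return q_head_idx // qhead_per_kvhead

# Intended integer factor: 8 query heads per KV head, so q head 13 -> KV head 1.
assert kv_head_index(13, 8) == 1

# bool is an int subclass: True behaves as 1 and the mapping becomes identity,
# the kind of silent breakage the Constexpr[int] annotation prevents.
assert True == 1
assert kv_head_index(13, True) == 13
```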
```python
num_copy_elems = 128 // self.dtype.width
threads_per_row = gmem_k_block_size // num_copy_elems
self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(
    self.dtype, threads_per_row, self.num_threads, num_copy_elems
```
copy_utils.tiled_copy_2d (as added in this PR) takes (dtype, major_mode_size, num_threads, is_async=False). This call passes threads_per_row as major_mode_size and passes num_copy_elems positionally into is_async, which will silently be treated as truthy and enable async copies unintentionally. Update the call to match the helper’s signature (and pass the actual major-mode size, not threads-per-row), or update tiled_copy_2d to accept an explicit num_copy_elems parameter if that’s the intended API.
```diff
-    self.dtype, threads_per_row, self.num_threads, num_copy_elems
+    self.dtype, gmem_k_block_size, self.num_threads
```
```python
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
) -> cute.TiledCopy:
    num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
    copy_elems = num_copy_bits // dtype.width
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
    gmem_threads_per_row = major_mode_size // copy_elems
    assert num_threads % gmem_threads_per_row == 0
    thr_layout = cute.make_ordered_layout(
        (num_threads // gmem_threads_per_row, gmem_threads_per_row),
        order=(1, 0),
    )
    val_layout = cute.make_layout((1, copy_elems))
    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
```
This helper’s signature doesn’t match how it’s used elsewhere in the PR (call sites pass a 4th positional argument that appears to be num_copy_elems). As written, the helper computes copy_elems internally from major_mode_size, so passing num_copy_elems at call sites cannot work. Either (a) change the helper signature to accept num_copy_elems explicitly and derive num_copy_bits from it, or (b) update all call sites to pass the actual major_mode_size and use the default is_async kwarg intentionally.
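The pitfall can be demonstrated with a tiny stub that mirrors the helper's parameter list (the stub is illustrative, not the real helper): a fourth positional argument binds to `is_async` and is truthy.

```python
import inspect

# Hypothetical stand-in mirroring tiled_copy_2d's parameter order.
def tiled_copy_2d_stub(dtype, major_mode_size, num_threads, is_async=False):
    return bool(is_async)

# A call site that "passes num_copy_elems" positionally actually sets is_async:
sig = inspect.signature(tiled_copy_2d_stub)
bound = sig.bind("fp16", 64, 128, 8)
assert bound.arguments["is_async"] == 8          # 8 silently lands in is_async...
assert tiled_copy_2d_stub("fp16", 64, 128, 8) is True  # ...and is truthy, enabling async
```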
```python
if scores.shape[-2] not in _attention_ref_mask_cache:
    mask = torch.tril(
        torch.ones(scores.shape[-2:], device=scores.device, dtype=torch.bool), diagonal=0
    )
    _attention_ref_mask_cache[scores.shape[-2]] = mask
else:
    mask = _attention_ref_mask_cache[scores.shape[-2]]
scores = scores.masked_fill(mask, float("-inf"))
```
The causal mask logic is inverted: torch.tril(..., dtype=bool) marks the allowed lower-triangular region as True, but masked_fill(mask, -inf) masks out exactly those allowed positions. This makes the reference attention incorrect for causal=True. Also, the cache key only uses scores.shape[-2], but the mask shape depends on both T and S (scores.shape[-2:]), so it can return a wrong-shaped mask when seqlen_q != seqlen_k.
```diff
-if scores.shape[-2] not in _attention_ref_mask_cache:
-    mask = torch.tril(
-        torch.ones(scores.shape[-2:], device=scores.device, dtype=torch.bool), diagonal=0
-    )
-    _attention_ref_mask_cache[scores.shape[-2]] = mask
-else:
-    mask = _attention_ref_mask_cache[scores.shape[-2]]
-scores = scores.masked_fill(mask, float("-inf"))
+mask_key = scores.shape[-2:]
+if mask_key not in _attention_ref_mask_cache:
+    mask = torch.tril(
+        torch.ones(mask_key, device=scores.device, dtype=torch.bool), diagonal=0
+    )
+    _attention_ref_mask_cache[mask_key] = mask
+else:
+    mask = _attention_ref_mask_cache[mask_key]
+scores = scores.masked_fill(~mask, float("-inf"))
```
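The corrected semantics can be checked without torch. This pure-Python sketch (illustrative only) keeps the `tril`-style convention that `True` marks allowed positions and fills the complement, and also shows why the cache key must include both dimensions:

```python
NEG_INF = float("-inf")

def causal_allowed(t, s):
    # True where key j is visible to query i (lower triangle, diagonal included),
    # mirroring torch.tril(..., diagonal=0) on a (t, s) grid.
    return [[j <= i for j in range(s)] for i in range(t)]

def apply_causal(scores):
    t, s = len(scores), len(scores[0])
    allowed = causal_allowed(t, s)
    # Equivalent of masked_fill(~mask, -inf): fill where NOT allowed.
    return [
        [x if allowed[i][j] else NEG_INF for j, x in enumerate(row)]
        for i, row in enumerate(scores)
    ]

out = apply_causal([[1.0, 2.0], [3.0, 4.0]])
assert out == [[1.0, NEG_INF], [3.0, 4.0]]  # future position masked, rest kept

# The mask depends on both T and S, so a cache keyed on T alone can hand back
# a wrong-shaped mask when seqlen_q != seqlen_k:
assert causal_allowed(2, 3) != causal_allowed(2, 2)
```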
```python
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("fa4")
```
The distribution metadata lookup uses `version("fa4")`, but the added pyproject.toml declares `name = "flash-attn-4"`. With this mismatch, `__version__` will fall back to `0.0.0` even when installed. Update the queried distribution name to match the project name (or align the project name if `fa4` is intended).
```diff
-    __version__ = version("fa4")
+    __version__ = version("flash-attn-4")
```
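For reference, the fallback path can be exercised with any distribution name that isn't installed (the name below is deliberately bogus):

```python
from importlib.metadata import PackageNotFoundError, version

def resolve_version(dist_name: str) -> str:
    # Mirrors the __init__.py pattern: query installed metadata, else "0.0.0".
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return "0.0.0"

# A mismatched or uninstalled name silently yields the fallback, which is why
# querying "fa4" while the project is named "flash-attn-4" goes unnoticed:
assert resolve_version("definitely-not-an-installed-distribution") == "0.0.0"
```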
```python
if cubin_path is not None:
    cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
        load_cubin_module_data_patched, filepath=cubin_path
    )
output = cute_compile_og(*args, **kwargs)
if cubin_path is not None:
    cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
    if extract is not None:
        sass = extract(cubin_path, None)
        pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
```
This monkey-patches a global CUTLASS runtime function but doesn’t restore it if cute_compile_og(*args, **kwargs) raises. That can leave the process in a partially patched state and affect subsequent compilations/loads. Wrap the patched region in try/finally so load_cubin_module_data is reliably restored (and consider restoring even if SASS extraction fails).
```diff
-if cubin_path is not None:
-    cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
-        load_cubin_module_data_patched, filepath=cubin_path
-    )
-output = cute_compile_og(*args, **kwargs)
-if cubin_path is not None:
-    cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
-    if extract is not None:
-        sass = extract(cubin_path, None)
-        pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
+# If no cubin path is set, do not monkey-patch; just call the original.
+if cubin_path is None:
+    return cute_compile_og(*args, **kwargs)
+# Monkey-patch the CUTLASS loader so we can dump the cubin to a file,
+# and be sure to restore it even if compilation fails.
+cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
+    load_cubin_module_data_patched, filepath=cubin_path
+)
+try:
+    output = cute_compile_og(*args, **kwargs)
+finally:
+    # Always restore the original loader, even on exceptions.
+    cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
+if extract is not None:
+    sass = extract(cubin_path, None)
+    pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
```
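The try/finally pattern the review asks for can be shown in isolation (the names `target`, `compile_fn`, and `with_patched_loader` below are illustrative stand-ins, not the real CUTLASS symbols):

```python
import types

# A module-like object whose attribute we temporarily monkey-patch.
target = types.SimpleNamespace(load=lambda data: "original")

def with_patched_loader(compile_fn, patched_load):
    original = target.load
    target.load = patched_load
    try:
        return compile_fn()
    finally:
        target.load = original  # restored on success *and* on exceptions

def failing_compile():
    raise RuntimeError("compilation failed")

try:
    with_patched_loader(failing_compile, lambda data: "patched")
except RuntimeError:
    pass

assert target.load(None) == "original"  # global state survived the exception
```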
```python
if intermediate_dtype is not None:
    attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
```
The intermediate_dtype cast is currently a no-op: it converts attention_drop to intermediate_dtype and immediately converts it back to the original dtype. If the intent is to perform the einsum in higher precision (or a specific dtype), keep attention_drop in intermediate_dtype until after the matmul/accumulation step (and then cast the output back if needed).
```diff
-if intermediate_dtype is not None:
-    attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
-output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+orig_attention_dtype = attention_drop.dtype
+if intermediate_dtype is not None:
+    attention_drop = attention_drop.to(intermediate_dtype)
+    v_scaled = (v * dropout_scaling).to(intermediate_dtype)
+    output = torch.einsum("bhts,bshd->bthd", attention_drop, v_scaled)
+    output = output.to(orig_attention_dtype)
+else:
+    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
```
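Why the cast must persist through the accumulation can be seen with a torch-free toy model of precision: quantizing every partial sum to a coarse grid (standing in for a low-precision dtype) loses the contribution of small terms, while accumulating in full precision and casting once at the end does not. This is an illustrative sketch, not the reference code's actual dtypes.

```python
def quantize(x: float) -> float:
    # Coarse quarter-step grid, a stand-in for a low-precision dtype.
    return round(x * 4) / 4

terms = [0.1] * 1000  # exact sum is 100.0

# Round-tripping through "low precision" at every step: each partial sum of
# 0.1 rounds back down to 0, so nothing ever accumulates.
low = 0.0
for t in terms:
    low = quantize(low + t)

# Accumulate in full precision, cast once at the end.
high = quantize(sum(terms))

assert low == 0.0
assert abs(high - 100.0) < 1e-6
```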
```python
def patch():
    """Install system ptxas hook. Call before importing cutlass."""
```
The docstring says to call patch() before importing cutlass, but this module imports cutlass at import time (top of file), so callers cannot follow that instruction. Update the docstring to reflect the actual requirement (e.g., 'call before compiling kernels' / before CudaDialectJitCompiledFunction is used), or restructure so cutlass is imported lazily inside patch().
```diff
-    """Install system ptxas hook. Call before importing cutlass."""
+    """Install system ptxas hook.
+
+    Call this before any CUTLASS DSL kernels are compiled, i.e. before
+    `CudaDialectJitCompiledFunction` is used.
+    """
```
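The lazy-import alternative mentioned in the review looks like this in miniature (module names are illustrative; stdlib `wave` stands in for the real `cutlass` dependency): with the import moved inside `patch()`, nothing is loaded at module-definition time, so "call `patch()` before the dependency is imported" becomes satisfiable.

```python
import sys

sys.modules.pop("wave", None)     # clean slate for the demo
assert "wave" not in sys.modules  # nothing imported at definition time

def patch():
    """Install the hook; the heavy import happens only here."""
    import wave  # deferred import, analogous to importing cutlass lazily
    return wave

patch()
assert "wave" in sys.modules  # imported only once patch() actually ran
```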