Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
412 commits
Select commit Hold shift + click to select a range
12fecd5
[Cute,Fwd,Sm100] Fix interface w score mod to get it to run
tridao Oct 24, 2025
09a3791
[Cute,Sm100] In gemm ptx, add to base smem_address instead
tridao Oct 24, 2025
c8e8766
[Cute,Bwd,Sm100] Make postprocessing work, add interface
tridao Oct 25, 2025
5ac18f7
[Cute,Bwd,Sm100] Simplify layouts in compute_loop
tridao Oct 25, 2025
583af0d
[Cute,Bwd,Sm100] Causal mask
tridao Oct 25, 2025
43f9b54
[Cute] Add store_shared_remote_fp32x4 util function
tridao Oct 26, 2025
8e0ebb4
[Cute,Bwd,Sm100] Tune registers
tridao Oct 26, 2025
543d873
[Cute,Sm100] acc_tmem_addr is Int32 instead of constexpr
tridao Oct 26, 2025
64033ad
[Cute,Bwd,Sm100] Reduce sync
tridao Oct 26, 2025
617c4c0
[Cute] Change utils.view_transpose back
tridao Oct 26, 2025
2aa711d
[Cute,Bwd,Sm100] Remove delay_tma_store option
tridao Oct 26, 2025
4f7c9bb
[Cute,Bwd,Sm100] Implement cluster
tridao Oct 26, 2025
8d26537
[Cute] Copy benchmark util functions to cute directory
tridao Oct 27, 2025
889e0de
[Cute,Bwd,Sm100] Use pipeline class for LSE and dPsum
tridao Oct 28, 2025
d6e3ab8
[Cute,Bwd,Sm100] Remove stage from sK, sV, tP, sdS
tridao Oct 28, 2025
2bedd5c
[Cute,Bwd,Sm100] Fix wrong LSE and dPsum indexing in load
tridao Oct 28, 2025
59ac9c3
[Cute] Blocks tweaks (#1964)
drisspg Oct 28, 2025
9198080
[Cute,Bwd,Sm100] Use TS MMA for dK
tridao Oct 28, 2025
56b9610
[Cute,Blocksparse] Group block sparse input torch tensors
tridao Oct 28, 2025
bea05b8
[Cute,Bwd,Sm100] Separate mma_S and mma_dP
tridao Oct 29, 2025
5117976
[Cute,Bwd,Sm100] Try LPTBwdScheduler
tridao Oct 29, 2025
e4e439d
[Cute,Bwd,Sm100] Try separating warps loading Q and dO
tridao Oct 29, 2025
003f236
BlockSparse Tweaks (#1970)
drisspg Oct 31, 2025
283913b
[Cute] Fix main (#1982)
drisspg Nov 3, 2025
f526b19
[Cute,Fwd,Sm100] Implement SplitKV (#1940)
timmy-feng Nov 5, 2025
46b6491
[Cute] Extract block-sparse utilities from SM80/90 (#1984)
drisspg Nov 5, 2025
e6fea4b
Enable python-3.10+ (#1998)
drisspg Nov 9, 2025
be8a89b
[Cute, Bwd, Sm100] Add GQA support (#2004)
jayhshah Nov 12, 2025
b5c1a60
[Cute,Fwd,Sm100] fix major regression with split kv (#2006)
jayhshah Nov 12, 2025
e27bd40
[CuTe DSL] Block sparsity computation kernel (#1983)
reubenconducts Nov 12, 2025
6e007f4
[Cute,Fwd,Sm100] Support paged attention (#1999)
timmy-feng Nov 14, 2025
16bd62c
[Cute] Add block-sparsity support to SM100 (#1985)
drisspg Nov 18, 2025
3ef9d30
[Cute,Sm100,Fwd] use correction warps for epi when not using TMA (#2014)
jayhshah Nov 19, 2025
34c2c7b
add fastdivmod for oob reads in mask_mods (#2020)
drisspg Nov 21, 2025
0a3931b
don't pass mask_fn to softmax_step generically (#2026)
jayhshah Nov 22, 2025
a41ac00
swap order of decorators (#2029)
anakinxc Nov 24, 2025
5f02da5
[Cute,Bwd,Sm100] enable deterministic mode for sm100 bwd and fix race…
jayhshah Nov 25, 2025
c00296a
Add LICENSE and AUTHORS to flash_attn/cute (#2032)
jduprat Nov 25, 2025
aa67af1
[Cute] Add authors
tridao Nov 25, 2025
1d2a4d0
[Cute,Fwd] enable mask mod without blocksparsity (#2031)
reubenconducts Nov 25, 2025
e872bf4
Bump pin (#2025)
drisspg Nov 25, 2025
5542be5
ruff all the smaller files (#2040)
drisspg Dec 2, 2025
b8d9d25
[Flash] Fix head dim 64 bwd (#2035)
drisspg Dec 2, 2025
8cba251
[Cute,Bwd,Sm100] Add local for sm100 bwd (#2046)
jayhshah Dec 6, 2025
290df71
Add hash attr to shortcut expensive check (#2048)
drisspg Dec 7, 2025
26b3a68
fixing cute bwd func def (#2056)
liangel-02 Dec 9, 2025
6aa559a
[CUTE] Allow grads to be preallocated (#2065)
drisspg Dec 15, 2025
6353ec0
[Cute,Fwd] Extend score_mod to variable sequence length (#2043)
reubenconducts Dec 15, 2025
0cf70d3
[CUTE] Seeing if tvvm reduces cpu overhead (#2042)
drisspg Dec 15, 2025
591d191
[FIRST] Fix softcap scoremod kwargs typo. (#2072)
LeoZDong Dec 16, 2025
6de5219
basics working (#2070)
drisspg Dec 16, 2025
1ad3be1
Blocksparse impl (#2085)
drisspg Dec 18, 2025
139a2a5
Fix IMA in fwd on m boundary (#2091)
drisspg Dec 20, 2025
5b45940
Update to dsl 3.4.3 (#2092)
drisspg Dec 22, 2025
2928a8e
fix shuffle sync for pack gqa epilogue (#2097)
jayhshah Dec 24, 2025
3cf926e
improve paged cpasync
v0i0 Dec 24, 2025
768888d
Enable Thor (#2108)
johnnynunez Dec 29, 2025
8ad861d
[Cute] Add quack as dependency
tridao Dec 31, 2025
37470a6
[Cute,Fwd,Sm90] Change PipelineTMAAsync sublass to signal per warp
tridao Jan 1, 2026
05c93b5
Add pack-gqa support for blocksparse impl w/ broadcasted H dim (#2098)
drisspg Jan 4, 2026
83905fe
[Cute,Fwd] improved block sparsity (#2100)
reubenconducts Jan 5, 2026
0d707e4
[Cute] Fix minor lint issue in shuffle_sync
tridao Jan 5, 2026
c3899cb
[Cute,Fwd,Sm100] Support `q_stage=1` for inference (#1993)
timmy-feng Jan 8, 2026
bcd0c77
[Cute] Fix two tests that were failing (#2149)
henrylhtsang Jan 8, 2026
95c07cc
[Cute, Bwd, Sm100] Add varlen for sm100 bwd (#2150)
jayhshah Jan 9, 2026
0a0e27e
block-sparse backward SM90 (#2136)
drisspg Jan 10, 2026
4456c1e
score-mod backward SM90 (#2137)
drisspg Jan 10, 2026
647d2aa
[Cute] Clarify and fix subtle cachekey bug (#2143)
drisspg Jan 10, 2026
12bd2ec
[CUTE][SM90]Enable pack-gqa with broadcasted maskmods (#2145)
drisspg Jan 10, 2026
7ba04f4
[CUTE][SM90] GQA backward non deterministic (#2158)
drisspg Jan 10, 2026
0e13567
[Cute,Bwd,Sm100] fix seqused in varlen bwd (#2167)
jayhshah Jan 10, 2026
c87546c
[CUTE] Bump cutedsl to 4.3.5 (#2170)
drisspg Jan 12, 2026
03e9810
Merge pull request #2156 from v0i0/v0i0/improve-paged-ldgsts
v0i0 Jan 12, 2026
672ad62
[Cute,Flex] Add option to create and cache __cute_hash__ (#2171)
reubenconducts Jan 12, 2026
df5cf43
[Cute][Flex] Remove no longer needed contig (#2172)
drisspg Jan 12, 2026
84f9c6c
[Cute] update row_max before safe overwrite for online_softmax (#2174)
jayhshah Jan 13, 2026
331676d
[Cute][Flex] add back in contig (#2177)
drisspg Jan 15, 2026
35b1b01
[Cute][Flex]Add pack-gqa divmod (#2180)
drisspg Jan 15, 2026
e77e390
[Cute,Fwd,Sm100] distributed offset calculation for paged KV (#2104)
timmy-feng Jan 15, 2026
8bf0592
Add R2P dual bound masking for local attention
henrylhtsang Jan 15, 2026
195bbd3
Add R2P dual bound masking for local attention
henrylhtsang Jan 15, 2026
1e65d93
switch from xor to mask_right & ~ mask_left
henrylhtsang Jan 16, 2026
8e418f8
flip in_bound to out_bound
henrylhtsang Jan 16, 2026
c20b65f
remove zero logic for right_s and left_s
henrylhtsang Jan 16, 2026
cb73292
remove 24 clamp
henrylhtsang Jan 16, 2026
7613e22
doc
henrylhtsang Jan 16, 2026
13e0cbd
lint
henrylhtsang Jan 16, 2026
91d3e44
added back clamp to avoid "OverflowError: Python int too large to con…
henrylhtsang Jan 16, 2026
5e266c0
add comment
henrylhtsang Jan 16, 2026
789abcd
Merge pull request #2185 from henrylhtsang/test_local_r2p
v0i0 Jan 17, 2026
bfd054b
[Cute][Flex] Fix expanded tensor bug (#2189)
drisspg Jan 17, 2026
e58d8c6
[Cute, SM90] fix fwd varlen Cute implementation bug for H100 (#2194)
KareemMusleh Jan 20, 2026
0386656
[Cute][Flex] Allow q_offset 1 and add block-sizes to disambiguate edg…
drisspg Jan 22, 2026
f3452a1
[Flex][SM100] Replay expand fix on sm100 (#2209)
drisspg Jan 26, 2026
f6a70e2
[DSL] Optionally patch cute-dsl to use system's ptxas
tridao Jan 27, 2026
007aa7a
Fix shared-memory race (#2229)
drisspg Feb 4, 2026
edf9205
short readme for flex flash (#2231)
v0i0 Feb 5, 2026
a874650
hdim 192 smem fix (#2235)
jayhshah Feb 5, 2026
e65910a
[CUTE]Bump to Cutedsl (#2216)
drisspg Feb 8, 2026
35fd3bb
[DSL] Replace old fence with cute.arch.fence_view_async_shared()
tridao Feb 8, 2026
860b552
[DSL]Replace utils.{fma,mul,add}_packed_f32x2 with cute.arch version
tridao Feb 8, 2026
e0d67de
[DSL] Remove coord_offset_i64, domain_offset_i64, elem_pointer_i64
tridao Feb 8, 2026
842de4d
[Sm90] Use functions from quack.sm90_utils
tridao Feb 8, 2026
54339ae
[DSL] Use cute.arch.warp_reduction_{max,sum}
tridao Feb 8, 2026
3cd2e74
[Layout] Use reshape_acc_to_mn and reshape_acc_to_frgA from quack
tridao Feb 8, 2026
bebc065
[Layout] Use quack.layout_utils.mma_partition_C_vec
tridao Feb 8, 2026
41ebfc7
[DSL] Use cute.math.{exp2,log2,log}
tridao Feb 8, 2026
4182921
[Layout] Use layout_utils.transpose_view and select from quack
tridao Feb 8, 2026
2237ac3
[Bwd,Sm90] Use quack.copy_utils
tridao Feb 8, 2026
7e42b38
[Bwd,Sm100] Shorten PipelineTmaUmma create
tridao Feb 8, 2026
044e510
[Bwd,Sm90] Have score_mod and score_mod_bwd as partial functions
tridao Feb 8, 2026
215188d
[DSL] warpgroup_reg_alloc -> setmaxregister_increase
tridao Feb 8, 2026
c79095e
Fix Hopper tests (#2242)
drisspg Feb 8, 2026
b346efc
[Bwd,Sm90] For dQ, move wait_group before TMA atomic add
tridao Feb 11, 2026
37ee026
[Cute,Flex,Fwd] Allow vectorized score_mod definitions (#2236)
reubenconducts Feb 11, 2026
244dd1a
[Bwd,Sm90] Simplify dK/dV R2S copy
tridao Feb 14, 2026
b8c4c36
[DSL] Use quack.cute_dsl_utils.ParamsBase
tridao Feb 14, 2026
06cefb6
[Cute][Flex] Fix kernel hang w/ multiple empty tiles (#2258)
drisspg Feb 16, 2026
2606352
Bump to 4.4.0 cute dsl pin (#2262)
drisspg Feb 18, 2026
631c83a
BWD sm100 2cta (#2202)
tzadouri Feb 20, 2026
6d0054b
[Bwd,Sm100] Fix num reg variables
tridao Feb 20, 2026
287af25
[Cute] Change compute_capability to arch
tridao Feb 20, 2026
bc21012
[Bwd,Postprocess] Update api to cute.arch.fence_view_async_shared
tridao Feb 21, 2026
52a6a61
[Fwd,Sm100] Disable ex2 emulation for Sm103
tridao Feb 21, 2026
4cc99de
[Dep] Update quack dependency to 0.2.10
tridao Feb 21, 2026
2de12f2
[Fwd,Sm100] Use arch from BaseDSL._get_dsl().get_arch_enum()
tridao Feb 21, 2026
3c8f1a9
[Fwd,Sm100] Clean up
tridao Feb 22, 2026
9d21523
[Bwd,Sm100] Put 2CTA asserts under if const_expr
tridao Feb 22, 2026
7830384
[Fwd,Sm100] Refactor _store_O_to_gemm into a separate method
tridao Feb 22, 2026
7c5a65b
[Fwd,Sm100] Simplify tensor layouts
tridao Feb 22, 2026
fab75ff
[Fwd,Sm100] Use pipeline_kv in load_KV instead of raw mbarrier
tridao Feb 23, 2026
1d3ff0f
[DSL] Don't need to parse swizzle from str anymore
tridao Feb 23, 2026
bf39167
[Fwd,Sm100] Use position_independent for sO, more clean up
tridao Feb 23, 2026
f3227f1
[Fwd,Sm100] Use pipeline abstraction for loading Q and KV
tridao Feb 23, 2026
8e9fa0f
[Cute] Handle window_size=(-1, -1) for non-local attention (#2251)
henrylhtsang Feb 23, 2026
ba363aa
[Cute,Sm100,Bwd] Add hdim 192 hdimv 128 backward for sm100 (#2270)
jayhshah Feb 25, 2026
cad06ac
[Fwd,Sm100] Only 1 thread per warp signals mbar_P_full_2
tridao Feb 23, 2026
90cb5fa
[Fwd,Sm100] Use pipeline abstraction for S_full & P_full_O_rescaled
tridao Feb 25, 2026
bd23f88
[Fwd,Sm100] Use pipeline abstraction for softmax-correction mbarrier
tridao Feb 25, 2026
e87d281
[Fwd,Sm100] Use pipeline abstraction for correction-epilogue
tridao Feb 25, 2026
53fddb7
[Fwd,Sm100] Tune registers
tridao Feb 25, 2026
5b7393c
guard use_2cta_instrs on sm90 (#2274)
reubenconducts Feb 25, 2026
6f30f07
[cute] Add return_lse (#2271)
erikwijmans Feb 26, 2026
1d3d290
[Fwd,Sm100] Use pipeline abstraction for O_full
tridao Feb 25, 2026
52ad947
[Fwd,Sm100] Use pipeline abstraction for mbar_P_full_2
tridao Feb 26, 2026
02fbbd3
[Fwd,Sm100] Use TmemAllocator
tridao Feb 26, 2026
b2edc9d
[Fwd,Sm100] Set split_P_arrive as a tunable parameter
tridao Feb 26, 2026
caac1f2
[Fwd,Sm100] Use pipeline abstraction for s0_s1_sequence
tridao Feb 26, 2026
048ae57
[Fwd,Sm100] Fix tScS partitioning for score_mod
tridao Feb 26, 2026
59f1380
fix mask mod bugs (#2276)
reubenconducts Feb 26, 2026
1920b13
[Cute,Sm100,Bwd] Fix and enable 2CTA path for hdim 128 backward (#2280)
jayhshah Feb 28, 2026
9dd8194
[Fwd,Sm100] Change layout of gQ and gO to have q_stage
tridao Feb 27, 2026
a9d62f8
[Fwd,Sm100] Pass cta_layout_vmnk to pipelines
tridao Feb 27, 2026
1c5259f
[Fwd,Sm100] Gate mma with is_leader_cta
tridao Feb 27, 2026
ce6ade2
[Fwd,Sm100] Take into account mma_tile_coord_v when reading/writing
tridao Feb 27, 2026
ceac28c
[Fwd,Sm100] Add pipeline.producer_tail
tridao Feb 27, 2026
112bc15
[Fwd,Sm100] Enable 2CTA for hdim128 noncausal
tridao Feb 28, 2026
436b941
Bump to 4.4.1 to avoid segfault (#2291)
drisspg Feb 28, 2026
22e4ba1
Fix sm100 fwd missing tSrQs init regression (#2293)
drisspg Mar 1, 2026
1b6750a
[Scheduler] Revert SingleTileScheduler to get block_idx
tridao Mar 1, 2026
dd507a6
[CuTe] Include broadcast dims in backward compile cache keys (#2298)
bonpyt Mar 3, 2026
23850ad
[Fwd,Sm100] Use NamedBarrier to signal softmax -> corr warps
tridao Mar 3, 2026
2acd159
[Fwd,Sm100] Add polynomials degree 1 - 5
tridao Mar 3, 2026
7cf1b92
[Fwd,Sm100] Switch back to poly degree 3
tridao Mar 3, 2026
35a2ad1
[Fwd,Sm100] Compute kv_stage based on hdim instead of hard-coding
tridao Mar 3, 2026
fac979d
[Cute][Testing] Add fake tensor mode support for compile-only test pa…
Alkaid-Benetnash Mar 3, 2026
82b466f
Enable hdim=96 bwd (#2302)
v0i0 Mar 3, 2026
266e29c
Fix GQA crash in cute FLASH backend: init load_Q before conditional (…
platers Mar 3, 2026
d4331ee
[Fwd,Sm100] Be more explicit when loading Q
tridao Mar 3, 2026
87aa3ec
[Fwd,Sm100] Tune ex2_emu_freq
tridao Mar 3, 2026
472d91f
[Fwd,Sm100] Tweak ptx for gemm
tridao Mar 3, 2026
90f46a9
[Bench] Enable benchmarking bwd with headdim != headdim_v
tridao Mar 3, 2026
adba69b
fix paged kv (#2303)
jayhshah Mar 3, 2026
2fc03d8
Add FA4 publishing strategy (#2282)
drisspg Mar 3, 2026
8daecf3
[Cute][Testing] Add persistent compile cache for cutedsl AOT compilat…
Alkaid-Benetnash Mar 4, 2026
3780a4b
[Bench] Add reference attn implementation
tridao Mar 5, 2026
bea3263
[Bwd,Sm100] Use TmemAllocator
tridao Mar 5, 2026
8f60d23
Change PyPI name to flash-attn4
tridao Mar 5, 2026
f3f6ee5
Try again
tridao Mar 5, 2026
cb323bb
Change PyPI package name to fa4
tridao Mar 5, 2026
4d5336e
[Bwd,Sm100] Add fence_view_async_shared before LSE release
tridao Mar 5, 2026
8cce7a7
Change PyPI name back to flash-attn-4
tridao Mar 5, 2026
8191505
[Bwd,Sm103] Fix postprocess for 2cta_instrs
tridao Mar 7, 2026
cb6022f
[Sm100] Fix tmem dealloc: sync before dealloc
tridao Mar 8, 2026
99c96d1
[Test] Skip non-files in cache_utils.py
tridao Mar 8, 2026
e7467fe
Add more code authors
tridao Mar 8, 2026
0552751
Nicer headdim error message (#2227)
drisspg Mar 9, 2026
0862177
[Fwd,Sm100] Extract named barriers (#2309)
drisspg Mar 9, 2026
b002ba5
Change 2cta opt in to have min seqlen > 2*m_block_size (#2320)
drisspg Mar 9, 2026
a2ef2a4
[CuteDSL][SM90] varlen bwd works (#2275)
KareemMusleh Mar 10, 2026
b61657e
[Fwd,Sm90] Move FwdSm90 to a separate file
tridao Mar 8, 2026
f60628c
[GQA] Refactor pack_gqa_layout into a helper function
tridao Mar 8, 2026
db58018
[Fwd] Refactor compute_softmax_scale_log2 and comptue_fastdiv_mods
tridao Mar 8, 2026
845a6de
[GQA] Add unpack_gqa_layout
tridao Mar 8, 2026
9422ff1
Add Logging helper (#2327)
drisspg Mar 11, 2026
31e65e3
[Sm80] basic fix for new api (#2297)
zhuochenKIDD Mar 11, 2026
0c184b3
fix: duplicate softmax_scale param (#2328)
NanoCode012 Mar 11, 2026
47d8f85
[Bwd] Compile bwd_preprocess with cute fake tensors
tridao Mar 10, 2026
ec18b53
[Bwd] Clean up bwd_preprocess kernel
tridao Mar 10, 2026
0d4f85e
[Fwd] Port SeqlenInfoQKNewK from C++ to cute-dsl
tridao Mar 11, 2026
e7837fb
[Fwd] Clean up fwd_combine kernel, compile w cute fake tensors
tridao Mar 11, 2026
8adabb4
[Fwd,Sm80] Fix import of BlockSparseTensors
tridao Mar 11, 2026
3c2448f
[Fwd,Sm90] Tune tile size for hdim 64, 96, 128
tridao Mar 11, 2026
2f9c6ec
[Bwd,Sm90] Implement deterministic
tridao Mar 11, 2026
72e4907
[Cute,Sm100] Introduce a flexible lambda-based R2P masking (#2313)
Alkaid-Benetnash Mar 12, 2026
69f6b17
[Bwd] Compile bwd_postprocess with cute fake tensors
tridao Mar 11, 2026
4f08f35
SM120 forward pass (Blackwell GeForce / DGX Spark) (#2329)
blake-snc Mar 12, 2026
5d7504c
rename logging module (#2335)
Luosuu Mar 12, 2026
325022e
Add tile parameter to SeqlenInfo creation (#2337)
risan-raja Mar 12, 2026
d1429b1
Fix (#2338)
MatthewBonanni Mar 12, 2026
a6c0517
[Bwd,Sm90] Pass tile_m to bwd_preprocess, enable varlen tests
tridao Mar 12, 2026
3f1eecd
[Fwd,Sm90] Use mask_r2p_lamba
tridao Mar 12, 2026
a5cf610
[Bwd,Sm90] Fix varlen scheduler
tridao Mar 12, 2026
a55b9cb
[Bwd,Sm90] Enable varlen tests with seqused_k
tridao Mar 12, 2026
e638a22
[Bwd,Sm120] Add SM120 backward pass support (#2330)
blake-snc Mar 12, 2026
e0e8a43
[Bwd,Sm90] Enable local
tridao Mar 12, 2026
25466ed
fix tdKrdS typo (#2341)
henrylhtsang Mar 13, 2026
13327e4
[Bwd,Sm90] Implement ShuffleLSE
tridao Mar 13, 2026
c911921
[Bwd] Support gradient wrt LSE
tridao Mar 13, 2026
e09ed3b
Add SM120 varlen attention support (#2333)
blake-snc Mar 13, 2026
39e4dc4
[Bwd] Use ragged tensor for TMA dKV when varlen
tridao Mar 13, 2026
f7755eb
[Bwd,Sm90] Set bwd configs, make hdim64 bwd work
tridao Mar 13, 2026
f6129aa
fix the create_ragged_tensor_for_tma issue (#2345)
rainj-me Mar 14, 2026
9fe2595
[Fwd,Sm90] Implement rescale_O_before_gemm, enable hdim 192 & 256
tridao Mar 14, 2026
0cabec3
[Fwd,Sm90] Add hdim 192 and 256 to _validate_head_dims
tridao Mar 14, 2026
08c5d74
Support CPU-only compilation and overriding arch
tridao Mar 14, 2026
43661f7
[Bwd,Sm90] Implement PDL between bwd_preprocess and bwd
tridao Mar 15, 2026
ca70809
[Bench] Refactor benchmark script to take args from cmdline
tridao Mar 15, 2026
df94d15
[Bwd,Sm90] Implement dQ_single_wg
tridao Mar 15, 2026
c44aa4a
[Bwd,Sm90] Make hdim 96 work
tridao Mar 15, 2026
e4261ed
[Sm90] Add script to search fwd bwd configs
tridao Mar 15, 2026
9e93505
[Sm90] Clean up sm90_config_search.py
tridao Mar 15, 2026
28c5782
[Bwd,Sm90] Implement hdim 192-128 and hdim 192
tridao Mar 15, 2026
d2c80ef
[Sm90] Fix test_mask_mod and bwd block-sparse kwarg mismatch (#2365)
henrylhtsang Mar 17, 2026
2c1ff34
[Cute, Testing] Move stream parameter to end of kernel __call__ signa…
Alkaid-Benetnash Mar 18, 2026
e707f7b
[Cute] Bump cutedsl to 4.4.2 and remove prior aot workarounds (#2370)
Alkaid-Benetnash Mar 18, 2026
e918ba0
[Cute] fix: FA4 paged attention kv load for DeepSeek (192,128) on SM1…
Luosuu Mar 18, 2026
d221956
[Fwd,Sm90] Add paged KV attention support (tma and cp.async) (#2360)
henrylhtsang Mar 18, 2026
e083e1e
[CuTe,Flex] limit vec_size to 2 for score mod when not on Sm100 (#2371)
reubenconducts Mar 18, 2026
d1e12c3
[Fwd,Sm90] Use pipeline_q instead of raw mbarrier
tridao Mar 19, 2026
cf63ded
[Fwd,Sm90] Use TMA for O when PackGQA, keep no. TMA dim when PackGQA
tridao Mar 19, 2026
7c52f15
Support 2CTA for sliding window hdim 192 (#2347)
Inodayy Mar 19, 2026
4a4c1aa
[Fwd,Sm90] Use TMA for O when varlen
tridao Mar 19, 2026
522c0f6
support irregular q to kv head ratio (#2186)
timmy-feng Mar 20, 2026
0a136e6
[Pipeline] Refactor
tridao Mar 20, 2026
82ccf65
[Fwd,Sm90] Use producer instead of mma warps to load Q when !TMA_Q
tridao Mar 20, 2026
b3bd27d
[Fwd,Sm90] Implement PipelineAsync with elect_one for commit/release
tridao Mar 20, 2026
16fc364
[DSL] Remove ArgumentsBase
tridao Mar 22, 2026
532293a
Add 'flash_sparse_attn/ops/cute/' from commit '16fc364769fa6329b73a55…
LoserCheems Mar 24, 2026
52b40d9
Add sync_cute_subtree scripts for managing upstream repository integr…
LoserCheems Mar 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flash_sparse_attn/ops/cute/.flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[flake8]
max-line-length = 100
# E731: lambda assignment, E741: ambiguous variable name,
# F841: unused local variable, W503: line break before binary operator
ignore = E731, E741, F841, W503
8 changes: 8 additions & 0 deletions flash_sparse_attn/ops/cute/AUTHORS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Tri Dao
Jay Shah
Ted Zadouri
Markus Hoehnerbach
Vijay Thakkar
Timmy Liu
Driss Guessous
Reuben Stern
29 changes: 29 additions & 0 deletions flash_sparse_attn/ops/cute/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
5 changes: 5 additions & 0 deletions flash_sparse_attn/ops/cute/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
global-exclude *.egg-info/*
prune flash_attn_4.egg-info
prune flash_attn.egg-info
prune build
prune dist
26 changes: 26 additions & 0 deletions flash_sparse_attn/ops/cute/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# FlashAttention-4 (CuTeDSL)

FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.

## Installation

```sh
pip install flash-attn-4
```

## Usage

```python
from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

out = flash_attn_func(q, k, v, causal=True)
```

## Development

```sh
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install -e "flash_attn/cute[dev]"
pytest tests/cute/
```
26 changes: 26 additions & 0 deletions flash_sparse_attn/ops/cute/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Flash Attention CUTE (CUDA Template Engine) implementation."""

from importlib.metadata import PackageNotFoundError, version

try:
__version__ = version("fa4")
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The distribution metadata lookup uses version(\"fa4\"), but the added pyproject.toml declares name = \"flash-attn-4\". With this mismatch, __version__ will fall back to 0.0.0 even when installed. Update the queried distribution name to match the project name (or align the project name if fa4 is intended).

Suggested change
__version__ = version("fa4")
__version__ = version("flash-attn-4")

Copilot uses AI. Check for mistakes.
except PackageNotFoundError:
__version__ = "0.0.0"

import cutlass.cute as cute

from .interface import (
flash_attn_func,
flash_attn_varlen_func,
)

from flash_attn.cute.cute_dsl_utils import cute_compile_patched

# Patch cute.compile to optionally dump SASS
cute.compile = cute_compile_patched


__all__ = [
"flash_attn_func",
"flash_attn_varlen_func",
]
103 changes: 103 additions & 0 deletions flash_sparse_attn/ops/cute/ampere_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2025, Tri Dao.
from typing import Type, Callable, Optional

import cutlass
import cutlass.cute as cute


def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
    """Build a swizzled shared-memory layout atom for rows of `k_dim` elements of `dtype`.

    Chooses the largest aligned byte chunk (128/64/32/16 B) that evenly divides a
    row, then pairs it with the matching swizzle so the atom can be tiled over
    shared memory without bank conflicts.
    """
    # Element size in bytes (e.g. 2 for fp16/bf16, 4 for fp32).
    dtype_byte = cutlass.const_expr(dtype.width // 8)
    bytes_per_row = cutlass.const_expr(k_dim * dtype_byte)
    # Largest power-of-two chunk (128/64/32/16 bytes) dividing the row,
    # converted from bytes back to element count.
    smem_k_block_size = (
        cutlass.const_expr(
            128
            if bytes_per_row % 128 == 0
            else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
        )
        // dtype_byte
    )
    # Swizzle bit count scales with the chunk size: 128 elems -> 4 bits,
    # 64 -> 3, 32 -> 2, otherwise 1.
    swizzle_bits = (
        4
        if smem_k_block_size == 128
        else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1))
    )
    # Swizzle base is tied to the element width (4-byte -> 2, 2-byte -> 3, else 4),
    # selecting the swizzle family that matches the element granularity.
    swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4)
    return cute.make_composed_layout(
        cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),
        0,
        # 8 rows when k_dim is a multiple of 32, else 16; order=(1, 0) makes the
        # k dimension contiguous. NOTE(review): presumably chosen to match the
        # Ampere ldmatrix/cp.async tiling — confirm against callers.
        cute.make_ordered_layout(
            (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)
        ),
    )


@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsA: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_A: cute.TiledCopy,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
    A_in_regs: cutlass.Constexpr[bool] = False,
    B_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Accumulate acc += A @ B, pipelining smem->register copies against the MMAs.

    For each k-block, the copies of the (k+1)-th A/B fragments from shared
    memory are issued before the k-th MMA, overlapping loads with math.
    Operands already resident in registers (A_in_regs / B_in_regs) skip their
    copies. If `hook_fn` is provided, it is invoked once, right after the
    first k-block's MMA.
    """
    if cutlass.const_expr(swap_AB):
        # Swap the roles of the A and B operands and recurse once with
        # swap_AB disabled; the in-regs flags swap along with the operands.
        gemm(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            tCsB,
            tCsA,
            smem_thr_copy_B,
            smem_thr_copy_A,
            hook_fn,
            A_in_regs=B_in_regs,
            B_in_regs=A_in_regs,
            swap_AB=False,
        )
    else:
        # Retile the register fragments to the copy atoms' value layouts.
        tCrA_copy_view = smem_thr_copy_A.retile(tCrA)
        tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
        # Prologue: load the k=0 fragments so the loop can prefetch k+1.
        if cutlass.const_expr(not A_in_regs):
            cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])
        if cutlass.const_expr(not B_in_regs):
            cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
        for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])):
            # Prefetch the next k-block before issuing this one's MMA.
            # NOTE(review): `k` is a trace-time constant from range_constexpr, so
            # this guard is presumably resolved at trace time even without the
            # cutlass.const_expr wrapper that gemm_rs uses — confirm.
            if k < cute.size(tCsA.shape[2]) - 1:
                if cutlass.const_expr(not A_in_regs):
                    cute.copy(
                        smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]
                    )
                if cutlass.const_expr(not B_in_regs):
                    cute.copy(
                        smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]
                    )
            cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            if cutlass.const_expr(k == 0 and hook_fn is not None):
                hook_fn()


@cute.jit
def gemm_rs(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
) -> None:
    """Accumulate acc += A @ B with A already in registers and B copied from smem.

    Register/smem ("RS") variant of `gemm`: only B is staged from shared
    memory, with the (k+1)-th B fragment prefetched before the k-th MMA. If
    `hook_fn` is provided, it is invoked once after the first k-block's MMA.
    """
    # Retile the B register fragment to the copy atom's value layout.
    tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
    # Prologue: load the k=0 B fragment so the loop can prefetch k+1.
    cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # Prefetch the next B fragment before issuing this k-block's MMA.
        if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1):
            cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1])
        cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
        if cutlass.const_expr(k == 0 and hook_fn is not None):
            hook_fn()
71 changes: 71 additions & 0 deletions flash_sparse_attn/ops/cute/barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import cutlass
import cutlass.cute as cute
from cutlass import Int32
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import llvm


@dsl_user_op
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
    """Load a 32-bit value from global memory with acquire ordering at GPU scope.

    Emits `ld.global.acquire.gpu.b32`, so writes released by the producer of
    the flag (e.g. via `red_release`) are visible after this load returns.
    Returns the loaded value as an Int32.
    """
    # Pointer as a 64-bit integer operand for the inline PTX ("l" constraint).
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    state = llvm.inline_asm(
        T.i32(),
        [lock_ptr_i64],
        "ld.global.acquire.gpu.b32 $0, [$1];",
        "=r,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(state)


@dsl_user_op
def red_relaxed(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Atomically add `val` to a global 32-bit counter with relaxed ordering.

    Emits `red.relaxed.gpu.global.add.s32` — a one-way reduction (no value
    returned) that provides no memory-ordering guarantees; use `red_release`
    when the increment must publish prior writes.
    """
    # Pointer as a 64-bit integer operand for the inline PTX ("l" constraint).
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.relaxed.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def red_release(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Atomically add `val` to a global 32-bit counter with release ordering.

    Emits `red.release.gpu.global.add.s32` at GPU scope: prior writes by this
    thread become visible to any thread that subsequently acquires the flag
    (e.g. via `ld_acquire`). No value is returned.
    """
    # Pointer as a 64-bit integer operand for the inline PTX ("l" constraint).
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.release.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@cute.jit
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
    """Spin until the flag at `lock_ptr + flag_offset` equals `val`.

    Only thread 0 polls (with acquire loads); all other threads return
    immediately, so callers needing every thread to observe the condition
    must follow this with their own synchronization (e.g. a block barrier).
    """
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        read_val = Int32(0)
        # Acquire load each iteration: once the value matches, writes released
        # by the signaling thread are visible here.
        while read_val != val:
            read_val = ld_acquire(flag_ptr)


@cute.jit
def arrive_inc(
    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
    """Signal arrival by atomically adding `val` to the flag at `lock_ptr + flag_offset`.

    Only thread 0 performs the add, using release ordering so prior writes are
    published to threads that later observe the flag via `ld_acquire`/`wait_eq`.
    """
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        red_release(flag_ptr, val)
        # red_relaxed(flag_ptr, val)
Loading
Loading