Init cute version #261
Merged

412 commits
12fecd5
[Cute,Fwd,Sm100] Fix interface w score mod to get it to run
tridao 09a3791
[Cute,Sm100] In gemm ptx, add to base smem_address instead
tridao c8e8766
[Cute,Bwd,Sm100] Make postprocessing work, add interface
tridao 5ac18f7
[Cute,Bwd,Sm100] Simplify layouts in compute_loop
tridao 583af0d
[Cute,Bwd,Sm100] Causal mask
tridao 43f9b54
[Cute] Add store_shared_remote_fp32x4 util function
tridao 8e0ebb4
[Cute,Bwd,Sm100] Tune registers
tridao 543d873
[Cute,Sm100] acc_tmem_addr is Int32 instead of constexpr
tridao 64033ad
[Cute,Bwd,Sm100] Reduce sync
tridao 617c4c0
[Cute] Change utils.view_transpose back
tridao 2aa711d
[Cute,Bwd,Sm100] Remove delay_tma_store option
tridao 4f7c9bb
[Cute,Bwd,Sm100] Implement cluster
tridao 8d26537
[Cute] Copy benchmark util functions to cute directory
tridao 889e0de
[Cute,Bwd,Sm100] Use pipeline class for LSE and dPsum
tridao d6e3ab8
[Cute,Bwd,Sm100] Remove stage from sK, sV, tP, sdS
tridao 2bedd5c
[Cute,Bwd,Sm100] Fix wrong LSE and dPsum indexing in load
tridao 59ac9c3
[Cute] Blocks tweaks (#1964)
drisspg 9198080
[Cute,Bwd,Sm100] Use TS MMA for dK
tridao 56b9610
[Cute,Blocksparse] Group block sparse input torch tensors
tridao bea05b8
[Cute,Bwd,Sm100] Separate mma_S and mma_dP
tridao 5117976
[Cute,Bwd,Sm100] Try LPTBwdScheduler
tridao e4e439d
[Cute,Bwd,Sm100] Try separating warps loading Q and dO
tridao 003f236
BlockSparse Tweaks (#1970)
drisspg 283913b
[Cute] Fix main (#1982)
drisspg f526b19
[Cute,Fwd,Sm100] Implement SplitKV (#1940)
timmy-feng 46b6491
[Cute] Extract block-sparse utilities from SM80/90 (#1984)
drisspg e6fea4b
Enable python-3.10+ (#1998)
drisspg be8a89b
[Cute, Bwd, Sm100] Add GQA support (#2004)
jayhshah b5c1a60
[Cute,Fwd,Sm100] fix major regression with split kv (#2006)
jayhshah e27bd40
[CuTe DSL] Block sparsity computation kernel (#1983)
reubenconducts 6e007f4
[Cute,Fwd,Sm100] Support paged attention (#1999)
timmy-feng 16bd62c
[Cute] Add block-sparsity support to SM100 (#1985)
drisspg 3ef9d30
[Cute,Sm100,Fwd] use correction warps for epi when not using TMA (#2014)
jayhshah 34c2c7b
add fastdivmod for oob reads in mask_mods (#2020)
drisspg 0a3931b
don't pass mask_fn to softmax_step generically (#2026)
jayhshah a41ac00
swap order of decorators (#2029)
anakinxc 5f02da5
[Cute,Bwd,Sm100] enable deterministic mode for sm100 bwd and fix race…
jayhshah c00296a
Add LICENSE and AUTHORS to flash_attn/cute (#2032)
jduprat aa67af1
[Cute] Add authors
tridao 1d2a4d0
[Cute,Fwd] enable mask mod without blocksparsity (#2031)
reubenconducts e872bf4
Bump pin (#2025)
drisspg 5542be5
ruff all the smaller files (#2040)
drisspg b8d9d25
[Flash] Fix head dim 64 bwd (#2035)
drisspg 8cba251
[Cute,Bwd,Sm100] Add local for sm100 bwd (#2046)
jayhshah 290df71
Add hash attr to shortcut expensive check (#2048)
drisspg 26b3a68
fixing cute bwd func def (#2056)
liangel-02 6aa559a
[CUTE] Allow grads to be preallocated (#2065)
drisspg 6353ec0
[Cute,Fwd] Extend score_mod to variable sequence length (#2043)
reubenconducts 0cf70d3
[CUTE] Seeing if tvvm reduces cpu overhead (#2042)
drisspg 591d191
[FIRST] Fix softcap scoremod kwargs typo. (#2072)
LeoZDong 6de5219
basics working (#2070)
drisspg 1ad3be1
Blocksparse impl (#2085)
drisspg 139a2a5
Fix IMA in fwd on m boundary (#2091)
drisspg 5b45940
Update to dsl 3.4.3 (#2092)
drisspg 2928a8e
fix shuffle sync for pack gqa epilogue (#2097)
jayhshah 3cf926e
improve paged cpasync
v0i0 768888d
Enable Thor (#2108)
johnnynunez 8ad861d
[Cute] Add quack as dependency
tridao 37470a6
[Cute,Fwd,Sm90] Change PipelineTMAAsync sublass to signal per warp
tridao 05c93b5
Add pack-gqa support for blcoksparse impl w/ braodcasted H dim (#2098)
drisspg 83905fe
[Cute,Fwd] improved block sparsity (#2100)
reubenconducts 0d707e4
[Cute] Fix minor lint issue in shuffle_sync
tridao c3899cb
[Cute,Fwd,Sm100] Support `q_stage=1` for inference (#1993)
timmy-feng bcd0c77
[Cute] Fix two tests that were failing (#2149)
henrylhtsang 95c07cc
[Cute, Bwd, Sm100] Add varlen for sm100 bwd (#2150)
jayhshah 0a0e27e
block-sparse backward SM90 (#2136)
drisspg 4456c1e
score-mod backward SM90 (#2137)
drisspg 647d2aa
[Cute] Clarify and fix subtle cachekey bug (#2143)
drisspg 12bd2ec
[CUTE][SM90]Enable pack-gqa with broadcasted maskmods (#2145)
drisspg 7ba04f4
[CUTE][SM90] GQA backward non deterministic (#2158)
drisspg 0e13567
[Cute,Bwd,Sm100] fix seqused in varlen bwd (#2167)
jayhshah c87546c
[CUTE] Bump cutedsl to 4.3.5 (#2170)
drisspg 03e9810
Merge pull request #2156 from v0i0/v0i0/improve-paged-ldgsts
v0i0 672ad62
[Cute,Flex] Add option to create and cache __cute_hash__ (#2171)
reubenconducts df5cf43
[Cute][Flex] Remove no longer needed contig (#2172)
drisspg 84f9c6c
[Cute] update row_max before safe overwrite for online_softmax (#2174)
jayhshah 331676d
[Cute][Flex] add back in contig (#2177)
drisspg 35b1b01
[Cute][Flex]Add pack-gqa divmod (#2180)
drisspg e77e390
[Cute,Fwd,Sm100] distributed offset calculation for paged KV (#2104)
timmy-feng 8bf0592
Add R2P dual bound masking for local attention
henrylhtsang 195bbd3
Add R2P dual bound masking for local attention
henrylhtsang 1e65d93
switch from xor to mask_right & ~ mask_left
henrylhtsang 8e418f8
flip in_bound to out_bound
henrylhtsang c20b65f
remove zero logic for right_s and left_s
henrylhtsang cb73292
remove 24 clamp
henrylhtsang 7613e22
doc
henrylhtsang 13e0cbd
lint
henrylhtsang 91d3e44
added back clamp to avoid "OverflowError: Python int too large to con…
henrylhtsang 5e266c0
add comment
henrylhtsang 789abcd
Merge pull request #2185 from henrylhtsang/test_local_r2p
v0i0 bfd054b
[Cute][Flex] Fix expanded tensor bug (#2189)
drisspg e58d8c6
[Cute, SM90] fix fwd varlen Cute implementation bug for H100 (#2194)
KareemMusleh 0386656
[Cute][Flex] Allow q_offset 1 and add block-sizes to disambiguate edg…
drisspg f3452a1
[Flex][SM100] Replay expand fix on sm100 (#2209)
drisspg f6a70e2
[DSL] Optionally patch cute-dsl to use system's ptxas
tridao 007aa7a
Fix shared-memory race (#2229)
drisspg edf9205
short readme for flex flash (#2231)
v0i0 a874650
hdim 192 smem fix (#2235)
jayhshah e65910a
[CUTE]Bump to Cutedsl (#2216)
drisspg 35fd3bb
[DSL] Replace old fence with cute.arch.fence_view_async_shared()
tridao 860b552
[DSL]Replace utils.{fma,mul,add}_packed_f32x2 with cute.arch version
tridao e0d67de
[DSL] Remove coord_offset_i64, domain_offset_i64, elem_pointer_i64
tridao 842de4d
[Sm90] Use functions from quack.sm90_utils
tridao 54339ae
[DSL] Use cute.arch.warp_reduction_{max,sum}
tridao 3cd2e74
[Layout] Use reshape_acc_to_mn and reshape_acc_to_frgA from quack
tridao bebc065
[Layout] Use quack.layout_utils.mma_partition_C_vec
tridao 41ebfc7
[DSL] Use cute.math.{exp2,log2,log}
tridao 4182921
[Layout] Use layout_utils.transpose_view and select from quack
tridao 2237ac3
[Bwd,Sm90] Use quack.copy_utils
tridao 7e42b38
[Bwd,Sm100] Shorten PipelineTmaUmma create
tridao 044e510
[Bwd,Sm90] Have score_mod and score_mod_bwd as partial functions
tridao 215188d
[DSL] warpgroup_reg_alloc -> setmaxregister_increase
tridao c79095e
Fix Hopper tests (#2242)
drisspg b346efc
[Bwd,Sm90] For dQ, move wait_group before TMA atomic add
tridao 37ee026
[Cute,Flex,Fwd] Allow vectorized score_mod definitions (#2236)
reubenconducts 244dd1a
[Bwd,Sm90] Simplify dK/dV R2S copy
tridao b8c4c36
[DSL] Use quack.cute_dsl_utils.ParamsBase
tridao 06cefb6
[Cute][Flex] Fix kernel hang w/ multiple empty tiles (#2258)
drisspg 2606352
Bump to 4.4.0 cute dsl pin (#2262)
drisspg 631c83a
BWD sm100 2cta (#2202)
tzadouri 6d0054b
[Bwd,Sm100] Fix num reg variables
tridao 287af25
[Cute] Change compute_capability to arch
tridao bc21012
[Bwd,Postprocess] Update api to cute.arch.fence_view_async_shared
tridao 52a6a61
[Fwd,Sm100] Disable ex2 emulation for Sm103
tridao 4cc99de
[Dep] Update quack dependency to 0.2.10
tridao 2de12f2
[Fwd,Sm100] Use arch from BaseDSL._get_dsl().get_arch_enum()
tridao 3c8f1a9
[Fwd,Sm100] Clean up
tridao 9d21523
[Bwd,Sm100] Put 2CTA asserts under if const_expr
tridao 7830384
[Fwd,Sm100] Refactor _store_O_to_gemm into a separate method
tridao 7c5a65b
[Fwd,Sm100] Simplify tensor layouts
tridao fab75ff
[Fwd,Sm100] Use pipeline_kv in load_KV instead of raw mbarrier
tridao 1d3ff0f
[DSL] Don't need to parse swizzle from str anymore
tridao bf39167
[Fwd,Sm100] Use position_independent for sO, more clean up
tridao f3227f1
[Fwd,Sm100] Use pipeline abstraction for loading Q and KV
tridao 8e9fa0f
[Cute] Handle window_size=(-1, -1) for non-local attention (#2251)
henrylhtsang ba363aa
[Cute,Sm100,Bwd] Add hdim 192 hdimv 128 backward for sm100 (#2270)
jayhshah cad06ac
[Fwd,Sm100] Only 1 thread per warp signals mbar_P_full_2
tridao 90cb5fa
[Fwd,Sm100] Use pipeline abstraction for S_full & P_full_O_rescaled
tridao bd23f88
[Fwd,Sm100] Use pipeline abstraction for softmax-correction mbarrier
tridao e87d281
[Fwd,Sm100] Use pipeline abstraction for correction-epilogue
tridao 53fddb7
[Fwd,Sm100] Tune registers
tridao 5b7393c
guard use_2cta_instrs on sm90 (#2274)
reubenconducts 6f30f07
[cute] Add return_lse (#2271)
erikwijmans 1d3d290
[Fwd,Sm100] Use pipeline abstraction for O_full
tridao 52ad947
[Fwd,Sm100] Use pipeline abstraction for mbar_P_full_2
tridao 02fbbd3
[Fwd,Sm100] Use TmemAllocator
tridao b2edc9d
[Fwd,Sm100] Set split_P_arrive as a tunable parameter
tridao caac1f2
[Fwd,Sm100] Use pipeline abstraction for s0_s1_sequence
tridao 048ae57
[Fwd,Sm100] Fix tScS partitioning for score_mod
tridao 59f1380
fix mask mod bugs (#2276)
reubenconducts 1920b13
[Cute,Sm100,Bwd] Fix and enable 2CTA path for hdim 128 backward (#2280)
jayhshah 9dd8194
[Fwd,Sm100] Change layout of gQ and gO to have q_stage
tridao a9d62f8
[Fwd,Sm100] Pass cta_layout_vmnk to pipelines
tridao 1c5259f
[Fwd,Sm100] Gate mma with is_leader_cta
tridao ce6ade2
[Fwd,Sm100] Take into account mma_tile_coord_v when reading/writing
tridao ceac28c
[Fwd,Sm100] Add pipeline.producer_tail
tridao 112bc15
[Fwd,Sm100] Enable 2CTA for hdim128 noncausal
tridao 436b941
Bump to 4.4.1 to avoid segfault (#2291)
drisspg 22e4ba1
Fix sm100 fwd missing tSrQs init regression (#2293)
drisspg 1b6750a
[Scheduler] Revert SingleTileScheduler to get block_idx
tridao dd507a6
[CuTe] Include broadcast dims in backward compile cache keys (#2298)
bonpyt 23850ad
[Fwd,Sm100] Use NamedBarrier to signal softmax -> corr warps
tridao 2acd159
[Fwd,Sm100] Add polynomials degree 1 - 5
tridao 7cf1b92
[Fwd,Sm100] Switch back to poly degree 3
tridao 35a2ad1
[Fwd,Sm100] Compute kv_stage based on hdim instead of hard-coding
tridao fac979d
[Cute][Testing] Add fake tensor mode support for compile-only test pa…
Alkaid-Benetnash 82b466f
Enable hdim=96 bwd (#2302)
v0i0 266e29c
Fix GQA crash in cute FLASH backend: init load_Q before conditional (…
platers d4331ee
[Fwd,Sm100] Be more explicit when loading Q
tridao 87aa3ec
[Fwd,Sm100] Tune ex2_emu_freq
tridao 472d91f
[Fwd,Sm100] Tweak ptx for gemm
tridao 90f46a9
[Bench] Enable benchmarking bwd with headdim != headdim_v
tridao adba69b
fix paged kv (#2303)
jayhshah 2fc03d8
Add FA4 publishing strategy (#2282)
drisspg 8daecf3
[Cute][Testing] Add persistent compile cache for cutedsl AOT compilat…
Alkaid-Benetnash 3780a4b
[Bench] Add reference attn implementation
tridao bea3263
[Bwd,Sm100] Use TmemAllocator
tridao 8f60d23
Change PyPI name to flash-attn4
tridao f3f6ee5
Try again
tridao cb323bb
Change PyPI package name to fa4
tridao 4d5336e
[Bwd,Sm100] Add fence_view_async_shared before LSE release
tridao 8cce7a7
Change PyPI name back to flash-attn-4
tridao 8191505
[Bwd,Sm103] Fix postprocess for 2cta_instrs
tridao cb6022f
[Sm100] Fix tmem delloc: sync before dealloc
tridao 99c96d1
[Test] Skip non-files in cache_utils.py
tridao e7467fe
Add more code authors
tridao 0552751
Nicer headdim error message (#2227)
drisspg 0862177
[Fwd,Sm100] Extract named barriers (#2309)
drisspg b002ba5
Change 2cta opt in to have min seqlen > 2*m_block_size (#2320)
drisspg a2ef2a4
[CuteDSL][SM90] varlen bwd works (#2275)
KareemMusleh b61657e
[Fwd,Sm90] Move FwdSm90 to a separate file
tridao f60628c
[GQA] Refactor pack_gqa_layout into a helper function
tridao db58018
[Fwd] Refactor compute_softmax_scale_log2 and comptue_fastdiv_mods
tridao 845a6de
[GQA] Add unpack_gqa_layout
tridao 9422ff1
Add Logging helper (#2327)
drisspg 31e65e3
[Sm80] basic fix for new api (#2297)
zhuochenKIDD 0c184b3
fix: duplicate softmax_scale param (#2328)
NanoCode012 47d8f85
[Bwd] Compile bwd_preprocess with cute fake tensors
tridao ec18b53
[Bwd] Clean up bwd_preprocess kernel
tridao 0d4f85e
[Fwd] Port SeqlenInfoQKNewK from C++ to cute-dsl
tridao e7837fb
[Fwd] Clean up fwd_combine kernel, compile w cute fake tensors
tridao 8adabb4
[Fwd,Sm80] Fix import of BlockSparseTensors
tridao 3c2448f
[Fwd,Sm90] Tune tile size for hdim 64, 96, 128
tridao 2f9c6ec
[Bwd,Sm90] Implement deterministic
tridao 72e4907
[Cute,Sm100] Introduce a flexible lambda-based R2P masking (#2313)
Alkaid-Benetnash 69f6b17
[Bwd] Compile bwd_postprocess with cute fake tensors
tridao 4f08f35
SM120 forward pass (Blackwell GeForce / DGX Spark) (#2329)
blake-snc 5d7504c
rename logging module (#2335)
Luosuu 325022e
Add tile parameter to SeqlenInfo creation (#2337)
risan-raja d1429b1
Fix (#2338)
MatthewBonanni a6c0517
[Bwd,Sm90] Pass tile_m to bwd_preprocess, enable varlen tests
tridao 3f1eecd
[Fwd,Sm90] Use mask_r2p_lamba
tridao a5cf610
[Bwd,Sm90] Fix varlen scheduler
tridao a55b9cb
[Bwd,Sm90] Enable varlen tests with seqused_k
tridao e638a22
[Bwd,Sm120] Add SM120 backward pass support (#2330)
blake-snc e0e8a43
[Bwd,Sm90] Enable local
tridao 25466ed
fix tdKrdS typo (#2341)
henrylhtsang 13327e4
[Bwd,Sm90] Implement ShuffleLSE
tridao c911921
[Bwd] Support gradient wrt LSE
tridao e09ed3b
Add SM120 varlen attention support (#2333)
blake-snc 39e4dc4
[Bwd] Use ragged tensor for TMA dKV when varlen
tridao f7755eb
[Bwd,Sm90] Set bwd configs, make hdim64 bwd work
tridao f6129aa
fix the create_ragged_tensor_for_tma issue (#2345)
rainj-me 9fe2595
[Fwd,Sm90] Implement rescale_O_before_gemm, enable hdim 192 & 256
tridao 0cabec3
[Fwd,Sm90] Add hdim 192 and 256 to _validate_head_dims
tridao 08c5d74
Support CPU-only compilation and overriding arch
tridao 43661f7
[Bwd,Sm90] Implement PDL between bwd_preprocess and bwd
tridao ca70809
[Bench] Refactor benchmark script to take args from cmdline
tridao df94d15
[Bwd,Sm90] Implement dQ_single_wg
tridao c44aa4a
[Bwd,Sm90] Make hdim 96 work
tridao e4261ed
[Sm90] Add script to search fwd bwd configs
tridao 9e93505
[Sm90] Clean up sm90_config_search.py
tridao 28c5782
[Bwd,Sm90] Implement hdim 192-128 and hdim 192
tridao d2c80ef
[Sm90] Fix test_mask_mod and bwd block-sparse kwarg mismatch (#2365)
henrylhtsang 2c1ff34
[Cute, Testing] Move stream parameter to end of kernel __call__ signa…
Alkaid-Benetnash e707f7b
[Cute] Bump cutedsl to 4.4.2 and remove prior aot workarounds (#2370)
Alkaid-Benetnash e918ba0
[Cute] fix: FA4 paged attention kv load for DeepSeek (192,128) on SM1…
Luosuu d221956
[Fwd,Sm90] Add paged KV attention support (tma and cp.async) (#2360)
henrylhtsang e083e1e
[CuTe,Flex] limit vec_size to 2 for score mod when not on Sm100 (#2371)
reubenconducts d1e12c3
[Fwd,Sm90] Use pipeline_q instead of raw mbarrier
tridao cf63ded
[Fwd,Sm90] Use TMA for O when PackGQA, keep no. TMA dim when PackGQA
tridao 7c52f15
Support 2CTA for sliding window hdim 192 (#2347)
Inodayy 4a4c1aa
[Fwd,Sm90] Use TMA for O when varlen
tridao 522c0f6
support irregular q to kv head ratio (#2186)
timmy-feng 0a136e6
[Pipeline] Refactor
tridao 82ccf65
[Fwd,Sm90] Use producer instead of mma warps to load Q when !TMA_Q
tridao b3bd27d
[Fwd,Sm90] Implement PipelineAsync with elect_one for commit/release
tridao 16fc364
[DSL] Remove ArgumentsBase
tridao 532293a
Add 'flash_sparse_attn/ops/cute/' from commit '16fc364769fa6329b73a55…
LoserCheems 52b40d9
Add sync_cute_subtree scripts for managing upstream repository integr…
LoserCheems File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
New file (flake8 configuration, `@@ -0,0 +1,4 @@`):

```ini
[flake8]
max-line-length = 100
# W503: line break before binary operator
ignore = E731, E741, F841, W503
```
New file (authors list, `@@ -0,0 +1,8 @@`):

```
Tri Dao
Jay Shah
Ted Zadouri
Markus Hoehnerbach
Vijay Thakkar
Timmy Liu
Driss Guessous
Reuben Stern
```
New file (BSD 3-Clause license, `@@ -0,0 +1,29 @@`):

```
BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
New file (packaging manifest, `@@ -0,0 +1,5 @@`):

```
global-exclude *.egg-info/*
prune flash_attn_4.egg-info
prune flash_attn.egg-info
prune build
prune dist
```
New file (README, `@@ -0,0 +1,26 @@`):

````markdown
# FlashAttention-4 (CuTeDSL)

FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.

## Installation

```sh
pip install flash-attn-4
```

## Usage

```python
from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

out = flash_attn_func(q, k, v, causal=True)
```

## Development

```sh
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install -e "flash_attn/cute[dev]"
pytest tests/cute/
```
````
New file (package `__init__`, `@@ -0,0 +1,26 @@`):

```python
"""Flash Attention CUTE (CUDA Template Engine) implementation."""

from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("fa4")
except PackageNotFoundError:
    __version__ = "0.0.0"

import cutlass.cute as cute

from .interface import (
    flash_attn_func,
    flash_attn_varlen_func,
)

from flash_attn.cute.cute_dsl_utils import cute_compile_patched

# Patch cute.compile to optionally dump SASS
cute.compile = cute_compile_patched


__all__ = [
    "flash_attn_func",
    "flash_attn_varlen_func",
]
```
New file (shared-memory layout and gemm helpers, `@@ -0,0 +1,103 @@`):

```python
# Copyright (c) 2025, Tri Dao.
from typing import Type, Callable, Optional

import cutlass
import cutlass.cute as cute


def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
    dtype_byte = cutlass.const_expr(dtype.width // 8)
    bytes_per_row = cutlass.const_expr(k_dim * dtype_byte)
    smem_k_block_size = (
        cutlass.const_expr(
            128
            if bytes_per_row % 128 == 0
            else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
        )
        // dtype_byte
    )
    swizzle_bits = (
        4
        if smem_k_block_size == 128
        else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1))
    )
    swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4)
    return cute.make_composed_layout(
        cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),
        0,
        cute.make_ordered_layout(
            (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)
        ),
    )


@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsA: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_A: cute.TiledCopy,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
    A_in_regs: cutlass.Constexpr[bool] = False,
    B_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    if cutlass.const_expr(swap_AB):
        gemm(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            tCsB,
            tCsA,
            smem_thr_copy_B,
            smem_thr_copy_A,
            hook_fn,
            A_in_regs=B_in_regs,
            B_in_regs=A_in_regs,
            swap_AB=False,
        )
    else:
        tCrA_copy_view = smem_thr_copy_A.retile(tCrA)
        tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
        if cutlass.const_expr(not A_in_regs):
            cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])
        if cutlass.const_expr(not B_in_regs):
            cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
        for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])):
            if k < cute.size(tCsA.shape[2]) - 1:
                if cutlass.const_expr(not A_in_regs):
                    cute.copy(
                        smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]
                    )
                if cutlass.const_expr(not B_in_regs):
                    cute.copy(
                        smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]
                    )
            cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            if cutlass.const_expr(k == 0 and hook_fn is not None):
                hook_fn()


@cute.jit
def gemm_rs(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
) -> None:
    tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
    cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1):
            cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1])
        cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
        if cutlass.const_expr(k == 0 and hook_fn is not None):
            hook_fn()
```
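The block-size and swizzle selection in `get_smem_layout_atom` can be hard to read through the nested `const_expr` ternaries. The following is a standalone sketch in plain Python (no CUTLASS dependency; the helper name `smem_layout_params` is ours, not from the diff) that mirrors the same arithmetic: pick the largest shared-memory block width in bytes that divides a row, convert it to elements, then derive the swizzle bits.

```python
def smem_layout_params(dtype_bytes: int, k_dim: int) -> tuple[int, int]:
    """Mirror of the selection logic above: returns (smem_k_block_size, swizzle_bits)."""
    bytes_per_row = k_dim * dtype_bytes
    # Largest power-of-two block width (in bytes) that evenly divides a row.
    if bytes_per_row % 128 == 0:
        block_bytes = 128
    elif bytes_per_row % 64 == 0:
        block_bytes = 64
    elif bytes_per_row % 32 == 0:
        block_bytes = 32
    else:
        block_bytes = 16
    smem_k_block_size = block_bytes // dtype_bytes  # block width in elements
    # Swizzle bits shrink with the block size: 128 -> 4, 64 -> 3, 32 -> 2, else 1.
    swizzle_bits = {128: 4, 64: 3, 32: 2}.get(smem_k_block_size, 1)
    return smem_k_block_size, swizzle_bits

# fp16 (2 bytes/elem), head dim 128: 256 bytes/row -> 128-byte blocks, 64 elems
print(smem_layout_params(2, 128))  # (64, 3)
```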
New file (global-memory lock/flag utilities, `@@ -0,0 +1,71 @@`):

```python
import cutlass
import cutlass.cute as cute
from cutlass import Int32
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import llvm


@dsl_user_op
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    state = llvm.inline_asm(
        T.i32(),
        [lock_ptr_i64],
        "ld.global.acquire.gpu.b32 $0, [$1];",
        "=r,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(state)


@dsl_user_op
def red_relaxed(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.relaxed.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def red_release(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.release.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@cute.jit
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        read_val = Int32(0)
        while read_val != val:
            read_val = ld_acquire(flag_ptr)


@cute.jit
def arrive_inc(
    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        red_release(flag_ptr, val)
        # red_relaxed(flag_ptr, val)
```
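The `wait_eq`/`arrive_inc` pair above implements a producer-consumer handshake on a global-memory flag: the producer does a release-ordered atomic add, and the consumer spins on acquire-ordered loads until the flag reaches the expected value. As a hedged CPU analogy (plain Python threading, not the GPU primitives; the `Flag` class is ours for illustration), the same pattern looks like this:

```python
import threading


class Flag:
    """CPU sketch of the arrive/wait flag protocol used above."""

    def __init__(self) -> None:
        self._val = 0
        self._lock = threading.Lock()  # stands in for acquire/release ordering

    def arrive_inc(self, val: int) -> None:
        # Producer side: release-style increment (analogue of red.release add).
        with self._lock:
            self._val += val

    def wait_eq(self, val: int) -> None:
        # Consumer side: spin with acquire-style reads (analogue of ld.acquire).
        while True:
            with self._lock:
                if self._val == val:
                    return


flag = Flag()
producer = threading.Thread(target=lambda: flag.arrive_inc(1))
producer.start()
flag.wait_eq(1)  # returns once the producer has arrived
producer.join()
```

On the GPU side only thread 0 of the block participates, which is why both device functions guard on `thread_idx == 0`.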
Review comment: The distribution metadata lookup uses `version("fa4")`, but the added `pyproject.toml` declares `name = "flash-attn-4"`. With this mismatch, `__version__` will fall back to `0.0.0` even when the package is installed. Update the queried distribution name to match the project name (or align the project name if `fa4` is intended).
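The failure mode the reviewer describes is easy to reproduce: `importlib.metadata.version` raises `PackageNotFoundError` for any distribution name that is not installed, so the `try`/`except` in `__init__.py` silently yields the `"0.0.0"` sentinel. A minimal sketch (the helper name `lookup_version` is ours):

```python
from importlib.metadata import PackageNotFoundError, version


def lookup_version(dist_name: str) -> str:
    """Mimic the __init__.py pattern: real version if installed, else sentinel."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return "0.0.0"


# A name that is not installed falls back to the sentinel, which is exactly
# what happens when the queried name mismatches the pyproject.toml name.
print(lookup_version("definitely-not-an-installed-dist"))  # 0.0.0
```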