Commit 5997daf
Add paged attention decode with Gluon JIT/AOT for AMD CDNA3 (ROCm#1383)
* add pa_decode_triton_fp8_gluon kernel and perf benchmark
* improve code readability and tidy up the implementation
* add the Gluon version of paged attention decode and its tests
- Create __init__.py for the aiter.ops.triton.gluon module
- Add test_paged_attention_decode_gluon.py to test the Gluon FP8 implementation
- The test compares performance between the assembly and Gluon kernels
* add a per_tensor quant test for the Gluon version of PA (see the quantization-granularity sketch after this list)
* fix a per_tensor quant bug in paged_attention_decode_v2_gluon_large_block_fp8
* use separate reduce-kernel code paths for Triton 3.4 and Triton 3.5 to avoid the performance regression introduced by Triton 3.5
* fix a bug when the KV cache uses varlen, fix per_token and per_tensor quantization bugs, and add torch_mha_extend_flashattn_style, a reference implementation consistent with the Triton kernel for diff comparison against it
* remove unused code
* add the pa_gluon AOT compile flow and a unit test
* fix a bug in the pa_gluon AOT compile flow
* remove the xxx_POW2 parameters from the PA Triton kernels and add a separate test for pa_decode_attention_kernel AOT compilation
* modify the PA kernel interface and fix the compute_type parameter handling in AOT mode
* merge pa_decode attention and reduce kernel into one C++ wrapper
* rename some variables, move paged_attention_decode_v2_reduce_kernel_triton34 into pa_decode_triton_gluon_fp8.py, and remove pa_decode_triton_gluon_fp8_triton34.py
* rename file
* change the Triton version comparison logic
* format file
* fix the PA Gluon performance regression in Triton AOT mode
* add assertions to all tests in test_pa_decode_gluon.py so they verify correctness
* fix per-token quant bug
* add support for COMPUTE_TYPE bf16/fp16 and QUANT_Q_AND_KV (False, False)/(False, True)
* support the same COMPUTE_TYPE and QUANT_Q_AND_KV combinations in AOT mode
* add a script to build the PA AOT .so
* support multiprocessing in the PA AOT .so build script
* add a loop version of paged_attention_decode_v2_reduce_kernel to remove MAX_CONTEXT_PARTITION_NUM and reduce the .so count in AOT mode (see the partition-merge sketch after this list)
* reduce the number of kernel templates
* modify the PA AOT .so build options
* support a sample_rate option in the tests to save time
* fix a bug when building the PA AOT .so with multiprocessing
* refactor: generalize the Gluon kernel and implement build-cache cleanup with .so file reporting
* feat: add backward compatibility for Triton builds without Gluon
Gracefully handle missing triton.experimental.gluon imports with try-except blocks and runtime checks, and add a simplified test configuration (see the import-guard sketch after this list).
* optimize paged_attention_decode_v2_reduce_kernel in long context scenarios
* fix a deadlock when building the PA AOT .so with multiprocessing and refine the unit tests
* disable bf16 and fp16 support for the time being due to precision issues
* Add Gluon transpose kernels for query and output in paged attention decode (see the layout round-trip sketch after this list)
- Add transpose_query_gluon_kernel to transpose the query from [batch*seq, num_heads, head_size] to the Gluon format
- Add transpose_output_gluon_kernel to transpose the output back to the standard format
- Support both JIT and AOT compilation modes
- Update the pa_decode_gluon API to handle transposition internally
- Add unit tests and prebuild scripts for the transpose kernels
* ali test
* use gl.int64 to hold kv_cache offsets, avoiding overflow when the kv_cache shape is very large (see the offset-overflow sketch after this list)
* clean up cache files after the AOT compile step and add docs for pa_decode_gluon and pa_decode_gluon_aot
* remove paged_attention_decode_v2_reduce_kernel_triton34; the PA kernel now requires Triton 3.5.0 or higher (see the version-gate sketch after this list)
* rename PA dot kernel
* Add an arch assertion to restrict pa_decode_gluon to gfx942 (CDNA3) (see the arch-guard sketch after this list)
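Quantization-granularity sketch. For context on the per_tensor and per_token modes exercised above, a minimal illustration with plain torch ops; the shapes, tensor names, and the float8_e4m3fn target type are assumptions for the sketch, not taken from the repo.

```python
import torch

# Illustrative only: per-tensor quantization uses one scale for the whole
# tensor; per-token quantization uses one scale per token row. 448.0 is the
# float8_e4m3fn max; shapes and names are assumptions for this sketch.
kv = torch.randn(1024, 8, 128)  # [tokens, kv_heads, head_size]
per_tensor_scale = kv.abs().amax() / 448.0
per_token_scale = kv.abs().amax(dim=(1, 2), keepdim=True) / 448.0  # [tokens, 1, 1]
kv_fp8_per_tensor = (kv / per_tensor_scale).to(torch.float8_e4m3fn)
kv_fp8_per_token = (kv / per_token_scale).to(torch.float8_e4m3fn)
```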
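Partition-merge sketch. The loop-based reduce kernel merges per-partition softmax partials in a runtime loop rather than specializing on MAX_CONTEXT_PARTITION_NUM; below is a scalar Python sketch of the assumed merge semantics (all names hypothetical).

```python
import math

# Assumed semantics: each context partition contributes a running max m_p,
# a sum-of-exponentials s_p, and a partial output o_p; a runtime loop merges
# them, so one compiled kernel covers any number of partitions.
def merge_partitions(ms, ss, outs):
    m, s, o = float("-inf"), 0.0, 0.0
    for m_p, s_p, o_p in zip(ms, ss, outs):  # runtime loop, not a template param
        new_m = max(m, m_p)
        a, b = math.exp(m - new_m), math.exp(m_p - new_m)
        o = (o * s * a + o_p * s_p * b) / (s * a + s_p * b)
        s = s * a + s_p * b
        m = new_m
    return o
```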
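Import-guard sketch. One way to express the backward-compatibility item above; the HAS_GLUON flag, helper name, and error text are illustrative, not the repo's actual code.

```python
# Guard the experimental import so the package still loads on Triton builds
# that ship without Gluon.
try:
    from triton.experimental import gluon  # noqa: F401
    HAS_GLUON = True
except ImportError:
    HAS_GLUON = False

def require_gluon():
    # Runtime check for call sites that need the Gluon kernels.
    if not HAS_GLUON:
        raise RuntimeError(
            "This kernel requires triton.experimental.gluon; "
            "install a Triton build that includes Gluon."
        )
```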
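Layout round-trip sketch. Reference semantics of the transpose kernels with plain torch ops; the exact "Gluon format" axis order is not spelled out in this log, so the permute below is an assumption for illustration only.

```python
import torch

# Hypothetical round trip: the real Gluon kernels do this on-device and the
# pa_decode_gluon API hides it from callers; the axis order is an assumption.
batch_seq, num_heads, head_size = 64, 8, 128
q = torch.randn(batch_seq, num_heads, head_size)
q_gluon = q.permute(1, 0, 2).contiguous()        # query -> assumed Gluon layout
out_std = q_gluon.permute(1, 0, 2).contiguous()  # output -> standard layout
assert out_std.shape == q.shape
```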
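Offset-overflow sketch. Why gl.int64 offsets are needed, as a pure-Python sanity check; the KV-cache shape is hypothetical.

```python
# With a large KV cache, the flat element offset exceeds 2**31 - 1, so 32-bit
# index arithmetic would wrap; hence 64-bit (gl.int64) offsets in the kernel.
num_blocks, block_size, num_heads, head_size = 70_000, 16, 64, 128
max_offset = num_blocks * block_size * num_heads * head_size - 1
assert max_offset > 2**31 - 1  # 9_175_039_999 > 2_147_483_647
```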
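Version-gate sketch. A simple way to enforce the Triton >= 3.5.0 requirement; using packaging here is an assumption, not necessarily how the repo checks it.

```python
import triton
from packaging.version import Version

# Assumed gate: the Gluon PA kernels require Triton 3.5.0 or higher.
if Version(triton.__version__) < Version("3.5.0"):
    raise RuntimeError(
        f"pa_decode_gluon requires triton >= 3.5.0, found {triton.__version__}"
    )
```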
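Arch-guard sketch. Querying gcnArchName through torch is one way to restrict the kernel to CDNA3 on ROCm; the repo's actual check may differ.

```python
import torch

# gcnArchName looks like "gfx942:sramecc+:xnack-" on MI300-class GPUs.
arch = torch.cuda.get_device_properties(0).gcnArchName
assert arch.split(":")[0] == "gfx942", (
    f"pa_decode_gluon only supports gfx942 (CDNA3), got {arch}"
)
```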
---------
Co-authored-by: Xin Huang <[email protected]>