
Commit 5997daf

Add paged attention decode with Gluon JIT/AOT for AMD CDNA3 (ROCm#1383)
* add pa_decode_triton_fp8_gluon kernel and perf benchmark
* enhance code readability, make the code more elegant
* add gluon version of paged attention decode and test
  - Create __init__.py for the aiter.ops.triton.gluon module
  - Add test_paged_attention_decode_gluon.py to test the gluon FP8 implementation
  - The test compares performance between the assembly and gluon kernels
* add per_tensor quant test for the gluon version of PA
* fix per_tensor quant bug in paged_attention_decode_v2_gluon_large_block_fp8
* use different reduce code paths for Triton 3.4 and Triton 3.5 to avoid the performance degradation introduced by Triton 3.5
* fix bug when kv uses varlen, fix per_token && per_tensor bugs, add torch_mha_extend_flashattn_style, which is consistent with the Triton kernel, to compare diffs against the Triton kernel
* remove unused code
* add pa_gluon AOT compile flow and unit test
* fix pa_gluon AOT compile flow bug
* remove xxx_POW2 params of the PA Triton kernels && add a separate test for pa_decode_attention_kernel AOT compile
* modify the PA kernel interface && fix the problem of adding the compute_type param in AOT mode
* merge the pa_decode attention and reduce kernels into one C++ wrapper
* rename some variables && add paged_attention_decode_v2_reduce_kernel_triton34 into pa_decode_triton_gluon_fp8.py and remove pa_decode_triton_gluon_fp8_triton34.py
* rename file
* change the Triton version comparison logic
* format file
* fix the pa_gluon performance decline in Triton AOT mode
* add assertions so that all tests pass in test_pa_decode_gluon.py
* fix per-token quant bug
* add support for COMPUTE_TYPE: bf16, fp16 && QUANT_Q_AND_KV: (False, False), (False, True)
* AOT mode supports COMPUTE_TYPE: bf16, fp16 && QUANT_Q_AND_KV: (False, False), (False, True)
* add a script to build the PA AOT .so
* the script that builds the PA AOT .so supports multiprocessing
* add a loop version of paged_attention_decode_v2_reduce_kernel to remove MAX_CONTEXT_PARTITION_NUM and reduce the .so count in AOT mode
* reduce kernel templates
* modify the PA AOT .so build options
* tests support a sample_rate option to save time
* fix a bug when building the PA AOT .so with multiprocessing
* refactor: generalize the gluon kernel && implement build cache cleanup with .so file reporting
* feat: add backward compatibility for Triton without gluon. Gracefully handle missing triton.experimental.gluon imports with try/except blocks and runtime checks; add a simplified test configuration (a minimal sketch of these guards follows this list)
* optimize paged_attention_decode_v2_reduce_kernel in long-context scenarios
* fix a deadlock when building the PA AOT .so with multiprocessing && refine the unit test
* disable bf16 and fp16 support for the time being due to precision problems
* Add Gluon transpose kernels for query and output in paged attention decode
  - Add transpose_query_gluon_kernel to transpose the query from [batch*seq, num_heads, head_size] to Gluon format
  - Add transpose_output_gluon_kernel to transpose the output back to the standard format
  - Support both JIT and AOT compilation modes
  - Update the pa_decode_gluon API to handle the transposition internally
  - Add unit tests and prebuild scripts for the transpose kernels
* ali test
* use gl.int64 to hold kv_cache offsets to avoid overflow when the kv_cache shape is very large (see the offset sketch after this list)
* clean up cache files after the AOT compile function, add docs for pa_decode_gluon and pa_decode_gluon_aot
* remove paged_attention_decode_v2_reduce_kernel_triton34; the PA kernel now requires Triton 3.5.0 or higher
* rename the PA dot kernel
* Add arch assertion to restrict pa_decode_gluon to gfx942 (CDNA3)

---------

Co-authored-by: Xin Huang <[email protected]>
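A minimal sketch of the two runtime guards described in the list above: the try/except fallback for Triton builds that lack triton.experimental.gluon, and the gfx942 (CDNA3) arch assertion. The flag HAS_GLUON, the helper name _check_pa_decode_gluon_supported, and the use of torch.cuda.get_device_properties(...).gcnArchName are illustrative assumptions, not the exact code added by this commit.

```python
import torch

# Triton builds without the experimental Gluon module should not crash at import time.
try:
    from triton.experimental import gluon  # noqa: F401  (requires Triton >= 3.5.0)
    HAS_GLUON = True
except ImportError:
    gluon = None
    HAS_GLUON = False


def _check_pa_decode_gluon_supported() -> None:
    """Hypothetical runtime checks performed before launching pa_decode_gluon."""
    if not HAS_GLUON:
        raise RuntimeError(
            "triton.experimental.gluon is unavailable; pa_decode_gluon requires Triton >= 3.5.0"
        )
    # gcnArchName is only present on ROCm builds of PyTorch.
    arch = torch.cuda.get_device_properties(torch.cuda.current_device()).gcnArchName
    if "gfx942" not in arch:
        raise RuntimeError(f"pa_decode_gluon only supports gfx942 (CDNA3), got {arch}")
```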
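And a small Triton-language analogue of the int64 offset fix mentioned above: the physical block id is promoted to 64-bit before it is multiplied by the block stride, so very large kv_cache tensors do not overflow a 32-bit offset. The kernel name, its arguments, and the strides are hypothetical; the commit itself applies this with gl.int64 inside the Gluon kernel.

```python
import triton
import triton.language as tl


@triton.jit
def _kv_offset_int64_example(kv_cache_ptr, block_tables_ptr, out_ptr,
                             stride_block, stride_head, head_idx,
                             BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    # Promote the block id to int64 before multiplying by the block stride:
    # for large kv_cache shapes the product can exceed 2**31 - 1.
    block_id = tl.load(block_tables_ptr + pid).to(tl.int64)
    base = block_id * stride_block + head_idx * stride_head
    offs = tl.arange(0, BLOCK_SIZE)
    vals = tl.load(kv_cache_ptr + base + offs)
    tl.store(out_ptr + pid * BLOCK_SIZE + offs, vals)
```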
1 parent 97ac929 commit 5997daf

18 files changed, +11190 -0 lines changed

aiter/ops/triton/gluon/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
