12 changes: 7 additions & 5 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -81,7 +81,7 @@ void trtllm_paged_attention_launcher(
int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
cudaStream_t stream) {
bool enable_pdl, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -128,6 +128,7 @@ void trtllm_paged_attention_launcher(
runner_params.mMaxSeqLenQ = max_q_len;
runner_params.mSumOfSeqLensQ = sum_seq_q;
runner_params.ptrAttentionSinks = attention_sinks;
runner_params.enable_pdl = enable_pdl;
if (mode == TllmPagedAttentionMode::Context) {
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
runner_params.mKernelType = FmhaKernelType::Context;
@@ -191,7 +192,7 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> out
at::Tensor seq_lens, int64_t max_kv_len, double bmm1_scale,
double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sm_count,
std::optional<at::Tensor> attention_sinks) {
bool enable_pdl, std::optional<at::Tensor> attention_sinks) {
auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type());
auto kv_data_type = torch_dtype_to_tllm_data_type(key_cache.scalar_type());
TORCH_CHECK_EQ(key_cache.dim(), value_cache.dim());
@@ -249,7 +250,7 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> out
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
stream);
enable_pdl, stream);
}

void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> out_scale_factor,
@@ -260,7 +261,8 @@ void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> ou
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t batch_size, int64_t window_left,
at::Tensor cum_seq_lens_q, at::Tensor cum_seq_lens_kv,
int64_t sm_count, std::optional<at::Tensor> attention_sinks) {
int64_t sm_count, bool enable_pdl,
std::optional<at::Tensor> attention_sinks) {
auto q_data_type = torch_dtype_to_tllm_data_type(query.scalar_type());
auto kv_data_type = torch_dtype_to_tllm_data_type(key_cache.scalar_type());
auto o_data_type = torch_dtype_to_tllm_data_type(out.scalar_type());
@@ -308,7 +310,7 @@ void trtllm_paged_attention_context(at::Tensor out, std::optional<at::Tensor> ou
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale, bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index,
window_left, sum_seq_q, sm_count, stream);
window_left, sum_seq_q, sm_count, enable_pdl, stream);
}

namespace trtllm_cubin_loader {
15 changes: 15 additions & 0 deletions flashinfer/decode.py
@@ -1820,13 +1820,15 @@ def _paged_run(
bmm1_scale: float, # todo(Yingyi): add dynamic scale tensor later
bmm2_scale: float,
window_left: int = -1,
enable_pdl: Optional[bool] = None,
out: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty_like(query)
if self._sm_count is None:
self._sm_count = get_device_sm_count(query.device)

self._op.trtllm_paged_attention_decode(
out,
None, # fp4 output not supported in wrapper api yet.
@@ -1846,6 +1848,7 @@ def _paged_run(
0, # o_sf_start_index
window_left,
self._sm_count,
enable_pdl,
sinks,
)
return out
@@ -1916,6 +1919,7 @@ def paged_run(
assert kv_lens_buffer is not None
assert page_size is not None
assert max_kv_len is not None
assert enable_pdl is not None
o = module._paged_run(
q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect
paged_k_cache,
@@ -1927,6 +1931,7 @@ def paged_run(
sm_scale,
1.0, # NOTE(Siyuan): update this to expose bmm2 scale
window_left,
enable_pdl,
out=o,
sinks=sinks,
)
@@ -1996,6 +2001,7 @@ def trtllm_batch_decode_with_kv_cache(
o_sf_scale: Optional[float] = None,
o_sf_vec_size: Optional[int] = None,
sinks: Optional[List[torch.Tensor]] = None,
enable_pdl: Optional[bool] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
@@ -2044,11 +2050,16 @@ def trtllm_batch_decode_with_kv_cache(
sinks : Optional[List[torch.Tensor]] = None
additional value per head in the denominator of the softmax.

enable_pdl : Optional[bool]
Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
Only supported on SM90 and newer GPUs, and currently only for FA2, CUDA core, and trtllm-gen decode.
If None, PDL is enabled whenever the device supports it.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
output torch.Tensor or FP4Tensor.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

if isinstance(kv_cache, tuple):
k_cache, v_cache = kv_cache
@@ -2132,6 +2143,7 @@ def trtllm_batch_decode_with_kv_cache(
o_sf_start_index,
window_left,
sm_count,
enable_pdl,
sinks,
)

@@ -2200,6 +2212,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
bmm2_scale_tensor: Optional[torch.Tensor] = None,
sinks: Optional[List[torch.Tensor]] = None,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
"""
Parameters:
@@ -2237,6 +2250,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
sm_count = get_device_sm_count(query.device)

@@ -2293,6 +2307,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
0, # o_sf_start_index
-1, # window_left
sm_count,
enable_pdl,
sinks,
)
return out
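
For reference, a minimal standalone sketch of the enable_pdl=None default-resolution pattern used by the decode wrappers above. The helper name resolve_enable_pdl is hypothetical, and the compute-capability check stands in for flashinfer's device_support_pdl, which is assumed to perform an equivalent SM90 query.

from typing import Optional

import torch


def resolve_enable_pdl(enable_pdl: Optional[bool], device: torch.device) -> bool:
    """Hypothetical helper mirroring:
    enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
    """
    if enable_pdl is not None:
        return enable_pdl
    # PDL (Programmatic Dependent Launch) requires compute capability >= 9.0 (SM90).
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 9


# Callers can leave enable_pdl=None and let the wrapper decide, or force it off:
#   out = trtllm_batch_decode_with_kv_cache(..., enable_pdl=None)
#   out = trtllm_batch_decode_with_kv_cache(..., enable_pdl=False)
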
9 changes: 9 additions & 0 deletions flashinfer/prefill.py
@@ -186,6 +186,7 @@ def _paged_run(
batch_size: int,
cum_seq_lens_q: torch.Tensor,
cum_seq_lens_kv: torch.Tensor,
enable_pdl: bool,
window_left: int = -1,
out: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
@@ -214,6 +215,7 @@ def _paged_run(
cum_seq_lens_q,
cum_seq_lens_kv,
sm_count,
enable_pdl,
sinks,
)
return out
@@ -546,6 +548,7 @@ def paged_run(
assert batch_size is not None
assert cum_seq_lens_q is not None
assert cum_seq_lens_kv is not None
assert enable_pdl is not None
o = paged_run_func(
q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect
paged_k_cache,
@@ -560,6 +563,7 @@ def paged_run(
batch_size,
cum_seq_lens_q,
cum_seq_lens_kv,
enable_pdl,
window_left,
out=o,
sinks=sinks,
@@ -3134,6 +3138,7 @@ def trtllm_batch_context_with_kv_cache(
out_dtype: Optional[Union[torch.dtype, str]] = None,
o_sf_scale: Optional[float] = None,
o_sf_vec_size: Optional[int] = None,
enable_pdl: Optional[bool] = None,
sinks: Optional[List[torch.Tensor]] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
@@ -3184,6 +3189,9 @@ def trtllm_batch_context_with_kv_cache(
output torch.Tensor or FP4Tensor.
"""

if enable_pdl is None:
enable_pdl = device_support_pdl(query.device)

if isinstance(kv_cache, tuple):
k_cache, v_cache = kv_cache
else:
@@ -3270,6 +3278,7 @@ def trtllm_batch_context_with_kv_cache(
cum_seq_lens_q,
cum_seq_lens_kv,
sm_count,
enable_pdl,
sinks,
)
return (
8 changes: 8 additions & 0 deletions tests/test_trtllm_gen_attention.py
@@ -239,6 +239,7 @@ def unpack_compare_nvfp4(
("fp8", "fp8", "nvfp4"),
],
)
@pytest.mark.parametrize("enable_pdl", [True, False, None])
def test_trtllm_batch_prefill(
kv_layout,
batch_size,
@@ -249,6 +250,7 @@ def test_trtllm_batch_prefill(
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
):
# Set up test parameters
torch.manual_seed(0)
@@ -340,6 +342,7 @@ def test_trtllm_batch_prefill(
out_dtype=DTYPE_MAP[o_dtype],
o_sf_scale=o_sf_scale,
o_sf_vec_size=o_sf_vec_size,
enable_pdl=enable_pdl,
)

if o_dtype == "nvfp4":
@@ -372,6 +375,7 @@ def test_trtllm_batch_prefill(
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale / o_scale,
enable_pdl=enable_pdl,
)
# v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel.
if v_scale == o_scale == 1.0:
@@ -399,6 +403,7 @@ def test_trtllm_batch_decode(
("fp8", "fp8", "nvfp4"),
],
)
@pytest.mark.parametrize("enable_pdl", [True, False, None])
def test_trtllm_batch_decode(
kv_layout,
batch_size,
@@ -409,6 +414,7 @@ def test_trtllm_batch_decode(
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
):
# Set up test parameters
torch.manual_seed(0)
@@ -493,6 +499,7 @@ def test_trtllm_batch_decode(
out_dtype=DTYPE_MAP[o_dtype],
o_sf_scale=o_sf_scale,
o_sf_vec_size=o_sf_vec_size,
enable_pdl=enable_pdl,
)

if o_dtype == "nvfp4":
@@ -525,6 +532,7 @@ def test_trtllm_batch_decode(
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale / o_scale,
enable_pdl=enable_pdl,
)
# v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel.
if v_scale == o_scale == 1.0:
3 changes: 3 additions & 0 deletions tests/test_trtllm_gen_mla.py
@@ -17,13 +17,15 @@
@pytest.mark.parametrize("page_size", [32, 64])
@pytest.mark.parametrize("q_len_per_request", [1, 2])
@pytest.mark.parametrize("dynamic_scale", [False])
@pytest.mark.parametrize("enable_pdl", [True, False, None])
def test_trtllm_batch_decode_mla(
batch_size: int,
scale: float,
dtype: torch.dtype,
page_size: int,
q_len_per_request: int,
dynamic_scale: bool,
enable_pdl: bool,
):
if dynamic_scale and dtype != torch.float8_e4m3fn:
pytest.skip("Dynamic scale is not supported for non-fp8 dtype")
@@ -128,6 +130,7 @@ def test_trtllm_batch_decode_mla(
bmm2_scale=1.0,
bmm1_scale_log2_tensor=bmm1_log2_scale_tensor,
bmm2_scale_tensor=bmm2_scale_tensor,
enable_pdl=enable_pdl,
)

# Run reference attention and align output