Commit 77082cd

[https://nvbugspro.nvidia.com/bug/5329655] [feat] Pytorch path add spec dec param to attention op (NVIDIA#5146)
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent 4cd8543 commit 77082cd

File tree: 10 files changed, +262 -57 lines changed

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 48 additions & 15 deletions

@@ -78,7 +78,8 @@ class RunnerBase
         torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
         torch::optional<torch::Tensor> mla_context_paged_kv,
         torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
-        torch::optional<torch::Tensor> softmax_stats_tensor) const
+        torch::optional<torch::Tensor> softmax_stats_tensor,
+        c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params) const
         = 0;
 };
 
@@ -129,7 +130,8 @@ class Runner : public RunnerBase
         torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
         torch::optional<torch::Tensor> mla_context_paged_kv,
         torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
-        torch::optional<torch::Tensor> softmax_stats_tensor) const override
+        torch::optional<torch::Tensor> softmax_stats_tensor,
+        c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
         T* attention_input = static_cast<T*>(qkv.slice(0, token_offset).data_ptr());
@@ -322,6 +324,27 @@ class Runner : public RunnerBase
         {
             enqueue_params.mrope_position_deltas = mrope_position_deltas.value().data_ptr<int32_t>();
         }
+        if (op.mIsSpecDecodingEnabled && op.mUseSpecDecoding)
+        {
+            TORCH_CHECK(spec_decoding_tensor_params.size() == 3,
+                "Expecting 3 tensors for spec-dec mode, spec_decoding_generation_lengths, "
+                "spec_decoding_position_offsets and spec_decoding_packed_mask.");
+            TORCH_CHECK(spec_decoding_tensor_params[0].has_value(),
+                "Expecting spec_decoding_generation_lengths spec-dec mode.");
+            TORCH_CHECK(spec_decoding_tensor_params[1].has_value(),
+                "Expecting spec_decoding_position_offsets spec-dec mode.");
+            TORCH_CHECK(
+                spec_decoding_tensor_params[2].has_value(), "Expecting spec_decoding_packed_mask spec-dec mode.");
+
+            enqueue_params.spec_decoding_generation_lengths
+                = spec_decoding_tensor_params[0].value().data_ptr<int32_t>();
+            enqueue_params.spec_decoding_position_offsets
+                = spec_decoding_tensor_params[1].value().data_ptr<int32_t>();
+            enqueue_params.spec_decoding_packed_mask = spec_decoding_tensor_params[2].value().data_ptr<int32_t>();
+            enqueue_params.spec_decoding_is_generation_length_variable = true;
+            enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1;
+        }
+
         // Current mlaGeneration will using fmha to do attention, so we don't go into enqueueGeneration
         if (op.isMLAEnabled())
         {
@@ -384,15 +407,14 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
     int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
     int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
     int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
-    double const rotary_embedding_scale, double const rotary_embedding_short_m_scale,
-    double const rotary_embedding_long_m_scale, int64_t const rotary_embedding_max_positions,
-    int64_t const rotary_embedding_original_max_positions, bool const use_paged_context_fmha,
-    std::optional<int64_t> attention_input_type, bool is_mla_enable, std::optional<int64_t> q_lora_rank,
-    std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
+    c10::ArrayRef<double> rotary_embedding_scales, c10::ArrayRef<int64_t> rotary_embedding_max_position_info,
+    bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
+    std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
     std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
     torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
     std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
-    std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor)
+    std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
+    c10::List<bool> spec_decoding_bool_params, c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
@@ -462,6 +484,12 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
     runner->attention_window_size = attention_window_size;
     runner->sink_token_length = sink_token_length;
 
+    double const rotary_embedding_scale = rotary_embedding_scales[0];
+    double const rotary_embedding_short_m_scale = rotary_embedding_scales[1];
+    double const rotary_embedding_long_m_scale = rotary_embedding_scales[2];
+    int64_t const rotary_embedding_max_positions = rotary_embedding_max_position_info[0];
+    int64_t const rotary_embedding_original_max_positions = rotary_embedding_max_position_info[1];
+
     auto op = std::make_shared<AttentionOp>();
     op->mType = dtype;
     op->mFMHAForceFP32Acc = dtype == nvinfer1::DataType::kBF16;
@@ -494,6 +522,12 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
 
     op->mAttentionChunkSize = attention_chunk_size;
 
+    TORCH_CHECK(spec_decoding_bool_params.size() == 2,
+        "Expecting 2 bools for spec-dec mode, is_spec_decoding_enabled and use_spec_decoding.");
+    op->mIsSpecDecodingEnabled = spec_decoding_bool_params[0]; // is_spec_decoding_enabled
+    op->mUseSpecDecoding = spec_decoding_bool_params[1];       // use_spec_decoding
+    op->mMultiBlockMode = op->mIsSpecDecodingEnabled ? false : true;
+
     if (is_mla_enable)
     {
         // MLA does not support NVFP4 output yet.
@@ -610,7 +644,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
             host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
             kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
             block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
-            mla_context_kv_cache_block_offsets, softmax_stats_tensor);
+            mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params);
     }
 
     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -626,7 +660,7 @@ void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch:
             host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
             kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
             block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
-            mla_context_kv_cache_block_offsets, softmax_stats_tensor);
+            mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params);
     }
 
     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
@@ -731,11 +765,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
         ", int rotary_embedding_dim"
         ", float rotary_embedding_base"
         ", int rotary_embedding_scale_type"
-        ", float rotary_embedding_scale"
-        ", float rotary_embedding_short_m_scale"
-        ", float rotary_embedding_long_m_scale"
-        ", int rotary_embedding_max_positions"
-        ", int rotary_embedding_original_max_positions"
+        ", float[] rotary_embedding_scales"
+        ", int[] rotary_embedding_max_position_info"
         ", bool use_paged_context_fmha"
         ", int? attention_input_type"
         ", bool is_mla_enable"
@@ -750,6 +781,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
         ", Tensor? mla_context_kv_cache_block_offsets"
         ", int? attention_chunk_size"
        ", Tensor? softmax_stats_tensor"
+        ", bool[] spec_decoding_bool_params"
+        ", Tensor?[] spec_decoding_tensor_params"
         ") -> ()");
 
     m.def("attention_supports_nvfp4_output", &torch_ext::attention_supports_nvfp4_output);
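
The schema change above folds the five rotary-embedding scalars into two arrays and appends two packed spec-dec arguments, which the Python side motivates as "packing parameters to avoid maxing out 64 arguments". Below is a minimal sketch, not part of this commit, of what a caller would assemble for the packed arguments; it assumes a CUDA device and uses placeholder values, while the real packing happens in TrtllmAttentionWrapper.run() further down.

import math

import torch

batch_size, max_draft_tokens = 4, 3

# float[] rotary_embedding_scales: scale, short_m_scale, long_m_scale
rotary_embedding_scales = [1.0, 1.0, 1.0]
# int[] rotary_embedding_max_position_info: max_positions, original_max_positions
rotary_embedding_max_position_info = [4096, 4096]

# bool[] spec_decoding_bool_params: is_spec_decoding_enabled, use_spec_decoding
spec_decoding_bool_params = [True, True]
# Tensor?[] spec_decoding_tensor_params: generation lengths, position offsets, packed mask
spec_decoding_tensor_params = [
    torch.full((batch_size, ), max_draft_tokens + 1, dtype=torch.int, device='cuda'),
    torch.arange(max_draft_tokens + 1, dtype=torch.int, device='cuda').repeat(batch_size, 1),
    torch.zeros(batch_size, max_draft_tokens + 1, math.ceil(max_draft_tokens / 32),
                dtype=torch.int, device='cuda'),
]
# These lists become the trailing arguments of torch.ops.trtllm.attention_inplace(...);
# the C++ side checks their sizes (2 bools, 3 tensors) before unpacking them.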

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 6 additions & 0 deletions

@@ -316,6 +316,12 @@ def create_cuda_graph_metadata(self,
         cuda_graph_metadata.__post_init__()
         return cuda_graph_metadata
 
+    def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
+                              is_spec_dec_dynamic_tree, max_draft_tokens):
+        """
+        Hook to be called when using TRTLLM attention backend in spec-dec mode.
+        """
+
 
 class PositionalEmbedder(Protocol):
     """

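The new update_spec_dec_param hook is a no-op on the base metadata class; TrtllmAttentionMetadata overrides it (see trtllm.py below) to precompute the spec-dec tensors. A hedged sketch of how a speculative-decoding setup path might call it; spec_config is a hypothetical stand-in object, not an API from this commit.

def enable_spec_dec(attn_metadata, spec_config):
    # Hypothetical caller, for illustration only: switch the attention metadata into
    # spec-dec mode for a linear draft chain before running forward passes.
    attn_metadata.update_spec_dec_param(
        is_spec_decoding_enabled=True,
        is_spec_dec_tree=False,          # linear chain of draft tokens
        is_spec_dec_dynamic_tree=False,  # static mask, precomputed once
        max_draft_tokens=spec_config.max_draft_tokens,
    )
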
tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 131 additions & 7 deletions

@@ -1,10 +1,12 @@
+import math
 import os
 import weakref
 from dataclasses import dataclass, field
 from typing import Optional
 
 import torch
 
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import AttentionMaskType
 from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -64,6 +66,10 @@ class TrtllmAttentionWrapper:
     qk_nope_head_dim: Optional[int]
     v_head_dim: Optional[int]
     attention_chunk_size: Optional[int]
+    use_spec_decoding: bool
+    spec_decoding_position_offsets: Optional[torch.Tensor]
+    spec_decoding_packed_mask: Optional[torch.Tensor]
+    spec_decoding_generation_lengths: Optional[torch.Tensor]
     kwargs: dict
 
     def __init__(
@@ -169,6 +175,11 @@ def plan(
         mla_context_paged_kv: Optional[torch.Tensor] = None,
         mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None,
         softmax_stats_tensor: Optional[torch.Tensor] = None,
+        is_spec_decoding_enabled: bool = False,
+        use_spec_decoding: bool = False,
+        spec_decoding_position_offsets: Optional[torch.Tensor] = None,
+        spec_decoding_packed_mask: Optional[torch.Tensor] = None,
+        spec_decoding_generation_lengths: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         """
@@ -245,7 +256,11 @@ def plan(
         self.rope_params.max_positions = max_sequence_length
         self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
         )
-
+        self.is_spec_decoding_enabled = is_spec_decoding_enabled
+        self.use_spec_decoding = use_spec_decoding
+        self.spec_decoding_position_offsets = spec_decoding_position_offsets
+        self.spec_decoding_packed_mask = spec_decoding_packed_mask
+        self.spec_decoding_generation_lengths = spec_decoding_generation_lengths
         self.kwargs.update(kwargs)
 
     def run(
@@ -374,6 +389,23 @@ def run(
         # output is provided, expect output_sf be provided as well if has NVFP4 output.
         assert out_dtype is None or out_dtype != torch.uint8 or output_sf is not None
 
+        # packing parameters to avoid maxing out 64 arguments
+        rotary_embedding_scales = [
+            self.rotary_embedding_scale, self.rotary_embedding_short_m_scale,
+            self.rotary_embedding_long_m_scale
+        ]
+        rotary_embedding_max_position_info = [
+            self.rotary_embedding_max_positions,
+            self.rotary_embedding_original_max_positions
+        ]
+        spec_decoding_bool_params = [
+            self.is_spec_decoding_enabled, self.use_spec_decoding
+        ]
+        spec_decoding_tensor_params = [
+            self.spec_decoding_generation_lengths,
+            self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
+        ]
+
         torch.ops.trtllm.attention_inplace(
             q,
             k,
@@ -420,11 +452,8 @@
             self.rotary_embedding_dim,
             self.rotary_embedding_base,
             self.rotary_embedding_scale_type,
-            self.rotary_embedding_scale,
-            self.rotary_embedding_short_m_scale,
-            self.rotary_embedding_long_m_scale,
-            self.rotary_embedding_max_positions,
-            self.rotary_embedding_original_max_positions,
+            rotary_embedding_scales,
+            rotary_embedding_max_position_info,
             self.use_paged_context_fmha,
             self.attention_input_type,
             self.is_mla_enable,
@@ -439,6 +468,8 @@ def run(
             self.mla_context_kv_cache_block_offsets,
             self.attention_chunk_size,
             self.softmax_stats_tensor,
+            spec_decoding_bool_params,
+            spec_decoding_tensor_params,
         )
 
         # reset the planned states (especially tensors) to avoid memory leak
@@ -495,6 +526,23 @@ class TrtllmAttentionMetadata(AttentionMetadata):
         init=True,
         repr=False)
 
+    # Flags to enable spec-dec mode (multi-query mode) in TRTLLM XQA Kernels
+    # spec decoding mode can be enabled for non-TRTLLM-gen kernels (pre-Blackwell XQA kernels)
+    # is_spec_decoding_enabled specifies if spec-dec mode is supported for the entire runtime.
+    is_spec_decoding_enabled: bool = False
+    # use_spec_decoding determines if the attention layer should be run in spec-dec mode at the specific step / layer.
+    use_spec_decoding: bool = False
+
+    # if spec-dec tree is a tree or a chain (linear tree)
+    is_spec_dec_tree: bool = False
+    # if spec-dec tree wouldn't be changed at all, the mask won't be computed every step.
+    is_spec_dec_dynamic_tree: bool = False
+
+    # parameters required for spec-dec mode
+    spec_decoding_position_offsets: Optional[torch.Tensor] = None
+    spec_decoding_packed_mask: Optional[torch.Tensor] = None
+    spec_decoding_generation_lengths: Optional[torch.Tensor] = None
+
     @property
     def max_seq_len(self) -> int:
         """
@@ -849,6 +897,76 @@ def prepare_paged_context_mla(self, cached_token_lens: torch.Tensor,
         self.ctx_kv_indptr[:self.num_contexts + 1].copy_(
             self.host_ctx_kv_indptr[:self.num_contexts + 1], non_blocking=True)
 
+    def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
+                              is_spec_dec_dynamic_tree, max_draft_tokens):
+        # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
+        self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
+        ) < 100
+
+        # use_spec_decoding is default to true by default, change in runtime by layers / requests
+        self.use_spec_decoding = self.is_spec_decoding_enabled
+
+        self.is_spec_dec_tree = is_spec_dec_tree
+        self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
+
+        # Parameters can be fixed and not changed during runtime if the
+        if self.is_spec_decoding_enabled:
+            self.spec_decoding_position_offsets = torch.empty(
+                [self.max_num_requests, max_draft_tokens + 1],
+                dtype=torch.int,
+                device='cuda',
+            )
+
+            self.spec_decoding_packed_mask = torch.empty(
+                [
+                    self.max_num_requests, max_draft_tokens + 1,
+                    math.ceil(max_draft_tokens / 32)
+                ],
+                dtype=torch.int,
+                device='cuda',
+            )
+
+            self.spec_decoding_generation_lengths = torch.empty(
+                [self.max_num_requests],
+                dtype=torch.int,
+                device='cuda',
+            )
+
+            if self.is_spec_dec_dynamic_tree:
+                assert False, "currently dynamic tree is not supported"
+            else:
+                # Populate the mask that won't change during inference phase.
+                self.generate_spec_decoding_position_offsets(
+                    max_draft_tokens=max_draft_tokens)
+                self.generate_spec_decoding_packed_mask(
+                    max_draft_tokens=max_draft_tokens)
+                self.generate_spec_decoding_generation_length(
+                    max_draft_tokens=max_draft_tokens)
+
+    def generate_spec_decoding_position_offsets(self, max_draft_tokens):
+        assert not self.is_spec_dec_tree, "only chained/linear tree is supported now"
+        position_offset = torch.arange(max_draft_tokens + 1,
+                                       dtype=torch.int,
+                                       device='cpu',
+                                       pin_memory=True)
+
+        # fill all the batches with same position offset
+        self.spec_decoding_position_offsets.copy_(position_offset,
+                                                  non_blocking=True)
+
+    def generate_spec_decoding_packed_mask(self, max_draft_tokens):
+        assert not self.is_spec_dec_tree, "only chained/linear tree is supported now"
+        dummy_idx = torch.arange(max_draft_tokens + 1)
+        spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
+        self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
+                                                      non_blocking=True)
+
+    def generate_spec_decoding_generation_length(self, max_draft_tokens):
+        spec_decoding_generation_length = torch.full((self.max_num_requests, ),
+                                                     max_draft_tokens + 1)
+        self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
+            spec_decoding_generation_length, non_blocking=True)
+
 
 class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
 
@@ -984,7 +1102,6 @@ def forward(
             use_paged_context_fmha=use_paged_context_fmha,
             is_mla_enable=self.is_mla_enable,
         )
-
         self.wrapper.plan(
             layer_idx=self.get_local_layer_idx(metadata),
             tokens_per_block=metadata.tokens_per_block,
@@ -1021,6 +1138,13 @@ def forward(
             mla_context_kv_cache_block_offsets=
             mla_context_kv_cache_block_offsets,
             softmax_stats_tensor=softmax_stats_tensor,
+            is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
+            use_spec_decoding=metadata.use_spec_decoding,
+            spec_decoding_position_offsets=metadata.
+            spec_decoding_position_offsets,
+            spec_decoding_packed_mask=metadata.spec_decoding_packed_mask,
+            spec_decoding_generation_lengths=metadata.
+            spec_decoding_generation_lengths,
         )
         out_dtype = None
         if out_scale is not None:
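
For a linear draft chain, generate_spec_decoding_packed_mask above sets row i of the mask to 2^(i+1) - 1, i.e. the low i + 1 bits, so draft token i attends to itself and every earlier token in the chain. A standalone illustration, not part of the commit, of the values precomputed for max_draft_tokens = 3 (CPU tensors here; the real buffers are CUDA int tensors sized per request):

import math

import torch

max_draft_tokens = 3
idx = torch.arange(max_draft_tokens + 1)

position_offsets = idx                             # [0, 1, 2, 3], replicated for every request
packed_mask_word0 = torch.pow(2, idx + 1) - 1      # [1, 3, 7, 15]: row i keeps bits 0..i
generation_length = max_draft_tokens + 1           # each request handles 4 tokens per step
num_mask_words = math.ceil(max_draft_tokens / 32)  # one 32-bit word per mask row here

print(position_offsets.tolist(), packed_mask_word0.tolist(), generation_length, num_mask_words)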
