
Commit ea3e0ee
Authored by qgai
[TRTLLM-7954][feat] Target model KV cache reallocation (#8421)
Signed-off-by: qgai <qgai@nvidia.com>
1 parent: 8a3b870

9 files changed: +495 −22 lines

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 12 additions & 4 deletions
@@ -3,12 +3,15 @@
 from collections import namedtuple
 from dataclasses import dataclass, field
 from enum import Enum, IntEnum
-from typing import (Dict, Generic, List, Optional, Protocol, Tuple, Type,
-                    TypeVar, Union)
+from typing import (TYPE_CHECKING, Dict, Generic, List, Optional, Protocol,
+                    Tuple, Type, TypeVar, Union)

 import torch
 from typing_extensions import Self

+if TYPE_CHECKING:
+    from ..speculative.utils import SpecDecodingTensor
+
 from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
                                      RotaryScalingType)
 from tensorrt_llm.mapping import Mapping
@@ -330,8 +333,13 @@ def restore_from_spec_dec(self) -> None:
             setattr(self, f, v)
         self._saved_tensors.clear()

-    def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
-                              is_spec_dec_dynamic_tree, max_draft_tokens):
+    def update_spec_dec_param(
+            self,
+            is_spec_decoding_enabled,
+            is_spec_dec_tree,
+            is_spec_dec_dynamic_tree,
+            max_draft_tokens,
+            spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
         """
         Hook to be called when using TRTLLM attention backend in spec-dec mode.
         """

tensorrt_llm/_torch/attention_backend/sparse/rocket.py

Lines changed: 5 additions & 1 deletion
@@ -8,6 +8,7 @@

 import tensorrt_llm
 import tensorrt_llm.bindings
+from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
 from tensorrt_llm._torch.attention_backend.trtllm import (
     TrtllmAttention, TrtllmAttentionMetadata)
 from tensorrt_llm._torch.attention_backend.vanilla import (
@@ -949,7 +950,10 @@ def prepare_resources(self, scheduled_batch):
             kt_token_num = math.ceil(req.max_beam_num_tokens / self.page_size)
             self.add_kt_tokens(request_id, kt_token_num)

-    def update_resources(self, scheduled_batch):
+    def update_resources(self,
+                         scheduled_batch,
+                         attn_metadata: AttentionMetadata = None,
+                         kv_cache_dtype_byte_size: float = None):
         for request in scheduled_batch.context_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 seq_len = request.get_num_tokens(0)
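
Both new parameters default to None, so resource managers that have no use for attention metadata, like this sparse-attention one, keep working with old and new call sites alike. A small illustrative sketch of that widened contract, using a hypothetical manager name:

# Hypothetical illustration of the widened update_resources() contract.
class DummyResourceManager:

    def update_resources(self,
                         scheduled_batch,
                         attn_metadata=None,
                         kv_cache_dtype_byte_size=None):
        # Managers that do not need the new arguments simply ignore them.
        for request in scheduled_batch.context_requests:
            pass  # existing per-request bookkeeping goes here

# Old call sites still work:   manager.update_resources(scheduled_batch)
# New call sites can pass:     manager.update_resources(scheduled_batch,
#                                                       attn_metadata,
#                                                       kv_cache_dtype_byte_size)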

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 39 additions & 7 deletions
@@ -2,10 +2,13 @@
 import os
 import weakref
 from dataclasses import dataclass, field
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch

+if TYPE_CHECKING:
+    from ..speculative.utils import SpecDecodingTensor
+
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.internal import thop
 from tensorrt_llm.functional import AttentionMaskType
@@ -1045,12 +1048,32 @@ def prepare_context_mla_with_cached_kv(self,
         self.ctx_kv_indptr[:self.num_contexts + 1].copy_(
             self.host_ctx_kv_indptr[:self.num_contexts + 1], non_blocking=True)

-    def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
-                              is_spec_dec_dynamic_tree, max_draft_tokens):
+    def update_spec_dec_param(
+        self,
+        is_spec_decoding_enabled,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
+        max_draft_tokens,
+        spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
+    ):
+
+        if spec_decoding_tensor is not None:
+            spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
+            spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
+            spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths
+        else:
+            spec_decoding_position_offsets = None
+            spec_decoding_packed_mask = None
+            spec_decoding_generation_lengths = None
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
         ) < 100

+        if get_sm_version() >= 100:
+            if is_spec_dec_tree or is_spec_dec_dynamic_tree:
+                assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
+                assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
+
         # use_spec_decoding is default to true by default, change in runtime by layers / requests
         self.use_spec_decoding = self.is_spec_decoding_enabled

@@ -1068,7 +1091,7 @@ def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
             self.spec_decoding_packed_mask = torch.empty(
                 [
                     self.max_num_requests, max_draft_tokens + 1,
-                    math.ceil(max_draft_tokens / 32)
+                    math.ceil((max_draft_tokens + 1) / 32)
                 ],
                 dtype=torch.int,
                 device='cuda',
@@ -1081,7 +1104,18 @@ def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
             )

             if self.is_spec_dec_dynamic_tree:
-                assert False, "currently dynamic tree is not supported"
+                assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
+                assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
+                self.spec_decoding_position_offsets.copy_(
+                    spec_decoding_position_offsets, non_blocking=True)
+                self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask,
+                                                     non_blocking=True)
+                if spec_decoding_generation_lengths is not None:
+                    self.spec_decoding_generation_lengths.copy_(
+                        spec_decoding_generation_lengths, non_blocking=True)
+                else:
+                    self.generate_spec_decoding_generation_length(
+                        max_draft_tokens=max_draft_tokens)
             else:
                 # Populate the mask that won't change during inference phase.
                 self.generate_spec_decoding_position_offsets(
@@ -1092,7 +1126,6 @@ def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
                     max_draft_tokens=max_draft_tokens)

     def generate_spec_decoding_position_offsets(self, max_draft_tokens):
-        assert not self.is_spec_dec_tree, "only chained/linear tree is supported now"
         position_offset = torch.arange(max_draft_tokens + 1,
                                        dtype=torch.int,
                                        device='cpu',
@@ -1103,7 +1136,6 @@ def generate_spec_decoding_position_offsets(self, max_draft_tokens):
                                        non_blocking=True)

     def generate_spec_decoding_packed_mask(self, max_draft_tokens):
-        assert not self.is_spec_dec_tree, "only chained/linear tree is supported now"
         dummy_idx = torch.arange(max_draft_tokens + 1)
         spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
         self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
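
The packed-mask shape fix above is a capacity correction: each mask row has to cover max_draft_tokens + 1 positions (the draft tokens plus the newly sampled token), and each int32 element stores 32 mask bits, so the last dimension must be ceil((max_draft_tokens + 1) / 32). The old ceil(max_draft_tokens / 32) came up one word short exactly when max_draft_tokens is a multiple of 32. A quick arithmetic sketch (the sample values are illustrative only):

import math

# Each mask row must cover max_draft_tokens + 1 positions, one bit each,
# packed into 32-bit integers.
for max_draft_tokens in (31, 32, 63):
    bits_needed = max_draft_tokens + 1
    old_words = math.ceil(max_draft_tokens / 32)        # pre-fix width
    new_words = math.ceil((max_draft_tokens + 1) / 32)  # post-fix width
    print(max_draft_tokens, bits_needed, old_words * 32, new_words * 32)

# max_draft_tokens = 32 needs 33 bits: the old layout allocated one int32
# (32 bits), the new layout allocates two.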

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions
@@ -482,6 +482,8 @@ def __init__(
         self.py_target_probs = None
         self.py_last_draft_tokens = None
         self.py_num_accepted_draft_tokens = 0
+        self.py_num_accepted_draft_tokens_indices = []
+        self.py_rewind_draft_token_separate_adjustment = 0
         self.py_decoding_iter = 0
         self.is_attention_dp_dummy = False
         self.is_cuda_graph_dummy = False

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 18 additions & 1 deletion
@@ -49,6 +49,7 @@
 from ..speculative.drafting_loops import ChainDrafter
 from ..speculative.eagle3 import Eagle3ResourceManager
 from ..speculative.mtp import SampleStateTensorsMTP
+from ..speculative.utils import SpecDecodingTensor
 from ..utils import (get_model_extra_attrs,
                      set_per_request_piecewise_cuda_graph_flag,
                      set_torch_compiling, with_model_extra_attrs)
@@ -370,9 +371,24 @@ def __init__(
         else:
             self.cache_indirection_attention = None

+        self.kv_cache_dtype_byte_size = self.get_kv_cache_dtype_byte_size()
+
     def register_forward_pass_callable(self, callable: Callable):
         self.forward_pass_callable = callable

+    def get_kv_cache_dtype_byte_size(self) -> float:
+        """
+        Returns the size (in bytes) occupied by kv cache type.
+        """
+        layer_quant_mode = self.model.model_config.quant_config.layer_quant_mode
+        if layer_quant_mode.has_fp4_kv_cache():
+            return 1 / 2
+        elif layer_quant_mode.has_fp8_kv_cache(
+        ) or layer_quant_mode.has_int8_kv_cache():
+            return 1
+        else:
+            return 2
+
     @property
     def runtime_draft_len(self):
         return self.max_total_draft_tokens if self.enable_spec_decode else 0
@@ -2249,6 +2265,7 @@ def forward(
         new_tensors_device: Optional[SampleStateTensors] = None,
         gather_context_logits: bool = False,
         cache_indirection_buffer: Optional[torch.Tensor] = None,
+        spec_decoding_tensor: Optional[SpecDecodingTensor] = None,
     ):
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)
@@ -2267,7 +2284,7 @@ def forward(
             attn_metadata.update_spec_dec_param(
                 is_spec_dec_mode, spec_metadata.is_spec_dec_tree,
                 spec_metadata.is_spec_dec_dynamic_tree,
-                self.original_max_draft_len)
+                self.original_max_draft_len, spec_decoding_tensor)
         else:
             spec_resource_manager = None
             spec_metadata = None
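
get_kv_cache_dtype_byte_size() maps the KV-cache quantization mode to bytes per element (0.5 for FP4, 1 for FP8/INT8, 2 otherwise, i.e. FP16/BF16) and the result is cached on the engine so the executor can hand it to the resource managers. A rough sketch of how such a factor can be used to estimate the per-token KV-cache footprint; the layer/head/dimension names and numbers are illustrative assumptions, not values from this commit:

def estimate_kv_cache_bytes_per_token(num_layers: int,
                                      num_kv_heads: int,
                                      head_dim: int,
                                      kv_cache_dtype_byte_size: float) -> float:
    """Rough per-token KV-cache footprint: K and V, per layer, per head, per dim."""
    return 2 * num_layers * num_kv_heads * head_dim * kv_cache_dtype_byte_size


# Example: a 32-layer model with 8 KV heads of head_dim 128.
# FP16 cache (2 bytes/elem):  2 * 32 * 8 * 128 * 2   = 131072 bytes/token
# FP8 cache  (1 byte/elem):   2 * 32 * 8 * 128 * 1   =  65536 bytes/token
# FP4 cache  (0.5 byte/elem): 2 * 32 * 8 * 128 * 0.5 =  32768 bytes/token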

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 19 additions & 3 deletions
@@ -989,8 +989,13 @@ def _executor_loop_pp(self):

                 finished_requests = self._handle_responses()
                 previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
+                attn_metadata = getattr(self.model_engine,
+                                        'attn_metadata', None)
+                kv_cache_dtype_byte_size = getattr(
+                    self.model_engine, 'kv_cache_dtype_byte_size', None)
                 self.resource_manager.update_resources(
-                    previous_scheduled_batch)
+                    previous_scheduled_batch, attn_metadata,
+                    kv_cache_dtype_byte_size)
                 self._remove_inflight_ids(previous_scheduled_batch)

             self.wait_on_pp_send_handles(prev_microbatch_id)
@@ -1200,7 +1205,13 @@ def _executor_loop(self):

                 self._handle_canceled_requests()
                 finished_requests = self._handle_responses()
-                self.resource_manager.update_resources(scheduled_batch)
+                attn_metadata = getattr(self.model_engine, 'attn_metadata',
+                                        None)
+                kv_cache_dtype_byte_size = getattr(
+                    self.model_engine, 'kv_cache_dtype_byte_size', None)
+                self.resource_manager.update_resources(
+                    scheduled_batch, attn_metadata,
+                    kv_cache_dtype_byte_size)
                 if self.enable_kv_cache_events:
                     self._add_kv_cache_events()

@@ -1403,7 +1414,12 @@ def _process_previous_batch(self):
         self._handle_canceled_requests()
         finished_requests = self._handle_responses()
         scheduled_requests = self.previous_batch.sample_state.scheduled_requests
-        self.resource_manager.update_resources(scheduled_requests)
+        attn_metadata = getattr(self.model_engine, 'attn_metadata', None)
+        kv_cache_dtype_byte_size = getattr(self.model_engine,
+                                           'kv_cache_dtype_byte_size', None)
+        self.resource_manager.update_resources(scheduled_requests,
+                                               attn_metadata,
+                                               kv_cache_dtype_byte_size)
         if self.enable_kv_cache_events:
             self._add_kv_cache_events()
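
All three executor loops now look up the engine's attention metadata and KV-cache byte size defensively via getattr, falling back to None when an engine does not expose them, and forward both into resource_manager.update_resources. A minimal sketch of that call pattern in isolation, with a hypothetical helper name:

# Hypothetical helper mirroring the call pattern added in this commit.
def update_all_resources(resource_manager, model_engine, scheduled_batch):
    # getattr keeps engines without these attributes working: missing values become None.
    attn_metadata = getattr(model_engine, 'attn_metadata', None)
    kv_cache_dtype_byte_size = getattr(model_engine, 'kv_cache_dtype_byte_size',
                                       None)
    resource_manager.update_resources(scheduled_batch, attn_metadata,
                                      kv_cache_dtype_byte_size)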
