
Commit e6b4953

heyuhhh and lfr-0531 authored and committed
fix rebase breaks
Signed-off-by: yuhangh <[email protected]>

fix rebase bug.

Signed-off-by: Fanrong Li <[email protected]>

fix rebase bug.

Signed-off-by: Fanrong Li <[email protected]>
1 parent 58f9538 commit e6b4953

6 files changed: +35 -85 lines changed

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 5 additions & 4 deletions

@@ -494,10 +494,11 @@ class AttentionOp
         mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
         mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
         mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
-        mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
-        mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
-        mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
-        mSkipAttn, mFuseFp4Quant, mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
+        mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mMLAParams.data(), mCpSize, mCpRank,
+        mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize,
+        mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA,
+        mUseKVCache, mSkipAttn, mFuseFp4Quant, mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores,
+        mAttentionChunkSize.value_or(-1));
     };

 private:

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 8 additions & 8 deletions

@@ -45,13 +45,13 @@ void initBindings(nb::module_& m)
         nb::arg("q_scaling"), nb::arg("rotary_embedding_int_params"), nb::arg("rotary_embedding_base"),
         nb::arg("rotary_embedding_scales"), nb::arg("rotary_embedding_max_position_info"),
         nb::arg("use_paged_context_fmha"), nb::arg("attention_input_type") = std::nullopt, nb::arg("is_mla_enable"),
-        nb::arg("chunked_prefill_buffer_batch_size") = std::nullopt,
-        nb::arg("q_lora_rank") = std::nullopt, nb::arg("kv_lora_rank") = std::nullopt,
-        nb::arg("qk_nope_head_dim") = std::nullopt, nb::arg("qk_rope_head_dim") = std::nullopt,
-        nb::arg("v_head_dim") = std::nullopt, nb::arg("mrope_rotary_cos_sin") = std::nullopt,
-        nb::arg("mrope_position_deltas") = std::nullopt, nb::arg("attention_chunk_size") = std::nullopt,
-        nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
-        nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_attention_params") = std::nullopt,
-        "Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());
+        nb::arg("chunked_prefill_buffer_batch_size") = std::nullopt, nb::arg("q_lora_rank") = std::nullopt,
+        nb::arg("kv_lora_rank") = std::nullopt, nb::arg("qk_nope_head_dim") = std::nullopt,
+        nb::arg("qk_rope_head_dim") = std::nullopt, nb::arg("v_head_dim") = std::nullopt,
+        nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
+        nb::arg("attention_chunk_size") = std::nullopt, nb::arg("softmax_stats_tensor") = std::nullopt,
+        nb::arg("spec_decoding_bool_params"), nb::arg("spec_decoding_tensor_params"),
+        nb::arg("sparse_attention_params") = std::nullopt, "Multi-head attention operation",
+        nb::call_guard<nb::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::nanobind::thop

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 8 additions & 8 deletions

@@ -45,13 +45,13 @@ void initBindings(pybind11::module_& m)
         py::arg("q_scaling"), py::arg("rotary_embedding_int_params"), py::arg("rotary_embedding_base"),
         py::arg("rotary_embedding_scales"), py::arg("rotary_embedding_max_position_info"),
         py::arg("use_paged_context_fmha"), py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"),
-        py::arg("chunked_prefill_buffer_batch_size") = std::nullopt,
-        py::arg("q_lora_rank") = std::nullopt, py::arg("kv_lora_rank") = std::nullopt,
-        py::arg("qk_nope_head_dim") = std::nullopt, py::arg("qk_rope_head_dim") = std::nullopt,
-        py::arg("v_head_dim") = std::nullopt, py::arg("mrope_rotary_cos_sin") = std::nullopt,
-        py::arg("mrope_position_deltas") = std::nullopt, py::arg("attention_chunk_size") = std::nullopt,
-        py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
-        py::arg("spec_decoding_tensor_params"), py::arg("sparse_attention_params") = std::nullopt,
-        "Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
+        py::arg("chunked_prefill_buffer_batch_size") = std::nullopt, py::arg("q_lora_rank") = std::nullopt,
+        py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt,
+        py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt,
+        py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
+        py::arg("attention_chunk_size") = std::nullopt, py::arg("softmax_stats_tensor") = std::nullopt,
+        py::arg("spec_decoding_bool_params"), py::arg("spec_decoding_tensor_params"),
+        py::arg("sparse_attention_params") = std::nullopt, "Multi-head attention operation",
+        py::call_guard<py::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::pybind::thop

tensorrt_llm/_torch/attention_backend/sparse/rocket.py

Lines changed: 3 additions & 5 deletions

@@ -15,7 +15,7 @@
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm._utils import get_size_in_bytes, next_power_of_two
 from tensorrt_llm.bindings import DataType
-from tensorrt_llm.bindings.executor import ExecutorConfig, KvCacheConfig
+from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.bindings.internal.batch_manager import \
     CacheType as CacheTypeCpp
 from tensorrt_llm.mapping import Mapping
@@ -843,9 +843,7 @@ def compute_page_count(self, token_count: int, tokens_per_page: int) -> int:

     @staticmethod
     def get_cache_size_per_token(model_config: ModelConfig,
-                                 executor_config: ExecutorConfig,
-                                 mapping: Mapping):
-        sparse_attn_config = executor_config.sparse_attention_config
+                                 tokens_per_block: int, mapping: Mapping):
         # get kv cache dtype bytes
         mem_per_token = 2
         quant_config = model_config.quant_config
@@ -875,7 +873,7 @@ def get_cache_size_per_token(model_config: ModelConfig,

         # K and V
         # 2 for K and V, 2 * kt_tokens_per_block / tokens_per_block for KT cache
-        tokens_per_block = executor_config.tokens_per_block
+        sparse_attn_config = model_config.sparse_attention_config
         kt_tokens_per_block = next_power_of_two(
             math.ceil(tokens_per_block / sparse_attn_config.page_size))
         kv_factor = 2 + 2 * kt_tokens_per_block / tokens_per_block
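
The kv_factor arithmetic in the hunk above is easy to check by hand. Below is a minimal standalone sketch; the concrete tokens_per_block and page_size values are assumptions chosen for illustration (they do not appear in the commit), and next_power_of_two is re-implemented locally under the assumption that it returns the smallest power of two not below its argument:

import math


def next_power_of_two(n: int) -> int:
    # Local stand-in for tensorrt_llm._utils.next_power_of_two (assumed to
    # return the smallest power of two >= n).
    return 1 if n <= 1 else 2**math.ceil(math.log2(n))


# Assumed illustrative values; after this commit the real ones come from the
# tokens_per_block argument and model_config.sparse_attention_config.page_size.
tokens_per_block = 64
page_size = 16

kt_tokens_per_block = next_power_of_two(math.ceil(tokens_per_block / page_size))
kv_factor = 2 + 2 * kt_tokens_per_block / tokens_per_block
print(kt_tokens_per_block, kv_factor)  # 4 2.125

With these assumed sizes the KT cache adds 0.125 of a K/V entry per token on top of the usual factor of 2.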

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 7 additions & 46 deletions

@@ -10,7 +10,8 @@
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (PeftCacheConfig, SamplerType,
-                                          SpeculativeConfig, SparseAttentionConfig)
+                                          SparseAttentionConfig,
+                                          SpeculativeConfig)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules)
@@ -40,10 +41,9 @@
 GB = 1 << 30


-def get_kv_cache_manager_cls(model_config: ModelConfig,
-                             executor_config: ExecutorConfig):
+def get_kv_cache_manager_cls(model_config: ModelConfig):
     config = model_config.pretrained_config
-    sparse_attn_config = executor_config.sparse_attention_config
+    sparse_attn_config = model_config.sparse_attention_config
     if is_mla(config):
         return KVCacheManager
     elif is_nemotron_hybrid(config):
@@ -93,46 +93,7 @@ def __init__(
         self._max_seq_len = max_seq_len
         self._max_batch_size = max_batch_size
         self._kv_cache_manager_cls = get_kv_cache_manager_cls(
-            model_engine.model.model_config, executor_config)
-
-    @staticmethod
-    def _get_cache_size_per_token(model_config: ModelConfig,
-                                  mapping: Mapping) -> int:
-        mem_per_token = 2
-        quant_config = model_config.quant_config
-        if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
-        ):
-            mem_per_token = 1
-
-        config = model_config.pretrained_config
-
-        num_key_value_heads = getattr(config, 'num_key_value_heads',
-                                      config.num_attention_heads)
-        if isinstance(num_key_value_heads, Iterable):
-            num_key_value_heads = sum(num_key_value_heads) / len(
-                num_key_value_heads)
-
-        mla = is_mla(config)
-        tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size
-
-        kv_factor = 2
-        if mla:
-            # MLA has kv_lora_rank and qk_rope_head_dim
-            head_dim = config.kv_lora_rank + config.qk_rope_head_dim
-            kv_factor = 1
-        else:
-            _head_dim = getattr(config, 'head_dim', None)
-            if not isinstance(_head_dim, int):
-                _head_dim = config.hidden_size // config.num_attention_heads
-            head_dim = _head_dim * num_key_value_heads // tp_size
-
-        # provide at least 1 layer to prevent division by zero cache size
-        num_attention_layers = max(
-            len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
-        mem_per_token *= num_attention_layers * head_dim
-        # K and V
-        mem_per_token *= kv_factor
-        return mem_per_token
+            model_engine.model.model_config)

     def _get_free_gpu_memory_fraction(self) -> float:
         fraction = self._kv_cache_config.free_gpu_memory_fraction
@@ -144,11 +105,11 @@ def _get_kv_size_per_token(self):
         model_config = self._model_engine.model.model_config
         mapping = self._mapping
         kv_size_per_token = self._kv_cache_manager_cls.get_cache_size_per_token(
-            model_config, self._executor_config, mapping)
+            model_config, self._tokens_per_block, mapping)
         if self._draft_model_engine is not None:
             draft_model_config = self._draft_model_engine.model.model_config
             kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
-                draft_model_config, self._executor_config, mapping)
+                draft_model_config, self._tokens_per_block, mapping)
         return kv_size_per_token

     def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
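
For reference, the non-MLA arithmetic performed by the deleted _get_cache_size_per_token helper (its role now falls to the KV cache managers' get_cache_size_per_token) can be traced with a standalone sketch. Every concrete number below is an assumption chosen for illustration, not a value from the commit:

# Sketch of the deleted helper's non-MLA path; all values are assumed examples.
mem_per_token = 2          # KV cache dtype bytes (1 when an FP8 KV cache is used)
num_key_value_heads = 8    # pretrained_config.num_key_value_heads
head_dim_per_head = 128    # pretrained_config.head_dim, or hidden_size // num_attention_heads
tp_size = 4                # 1 when mapping.enable_attention_dp is set
num_attention_layers = 32  # attention layers on this pipeline rank, at least 1
kv_factor = 2              # K and V; the MLA path uses kv_lora_rank + qk_rope_head_dim with kv_factor = 1

head_dim = head_dim_per_head * num_key_value_heads // tp_size
mem_per_token *= num_attention_layers * head_dim
mem_per_token *= kv_factor
print(mem_per_token)       # 32768 bytes per token under these assumptions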

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 4 additions & 14 deletions

@@ -10,7 +10,6 @@
 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.bindings.executor import ExecutorConfig
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams
@@ -279,11 +278,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
             # Standard case: use original Python implementation
             self.blocks_in_primary_pool, self.blocks_in_secondary_pool = self.calculate_max_num_blocks(
                 kv_cache_config=kv_cache_config,
-                head_dim=head_dim,
-                tokens_per_block=tokens_per_block,
                 mapping=mapping,
-                dtype=dtype,
-                kv_factor=self.kv_factor,
             )
             blocks_per_window = {
                 self.max_attention_window_vec[0]:
@@ -549,8 +544,7 @@ def calculate_scaling_factor_size_bytes(

     @staticmethod
     def get_cache_size_per_token(model_config: ModelConfig,
-                                 executor_config: ExecutorConfig,
-                                 mapping: Mapping):
+                                 tokens_per_block: int, mapping: Mapping):
         # get kv cache dtype bytes
         mem_per_token = 2
         quant_config = model_config.quant_config
@@ -605,13 +599,9 @@ def get_cache_bytes_per_token(self):
             scaling_factor_dtype=DataType.FP8)
         return cache_size_bytes_per_token

-    def calculate_max_num_blocks(self,
-                                 kv_cache_config: KvCacheConfigCpp,
-                                 head_dim: int,
-                                 tokens_per_block: int,
-                                 mapping: Mapping,
-                                 dtype: DataType,
-                                 kv_factor: int = 2):
+    def calculate_max_num_blocks(self, kv_cache_config: KvCacheConfigCpp,
+                                 mapping: Mapping):
+        tokens_per_block = self.tokens_per_block
         free_mem_fraction = (kv_cache_config.free_gpu_memory_fraction
                              if kv_cache_config.free_gpu_memory_fraction
                              is not None else 0.9)
