Skip to content

Commit dc38601

Browse files
heyuhhh authored and Superjomn committed
[None] [feat] Optimize the algorithm part of RocketKV (NVIDIA#9333)
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
1 parent 5bca1f6 commit dc38601

File tree

10 files changed

+321
-227
lines changed

10 files changed

+321
-227
lines changed

cpp/tensorrt_llm/thop/IndexerTopKOp.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ void indexer_topk_decode(
5757
TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
5858

5959
TORCH_CHECK(next_n > 0, "next_n must be greater than 0");
60-
TORCH_CHECK(index_topk == 2048, "index_topk must be 2048 for now");
6160

6261
int32_t num_rows = static_cast<int32_t>(numRows64);
6362
int32_t num_columns = static_cast<int32_t>(numColumns64);
@@ -95,7 +94,6 @@ void indexer_topk_prefill(th::Tensor const& logits, th::Tensor const& row_starts
9594

9695
TORCH_CHECK(indices.dim() == 2, "indices must be a 2D Tensor");
9796
TORCH_CHECK(logits.dim() == 2, "logits must be a 2D Tensor");
98-
TORCH_CHECK(index_topk == 2048, "index_topk must be 2048 for now");
9997

10098
auto const inputSize = logits.sizes();
10199
auto const numRows64 = inputSize[0];

examples/llm-api/llm_sparse_attention.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def parse_arguments():
4444
type=str,
4545
default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
4646
)
47+
4748
# Build config
4849
parser.add_argument('--algo',
4950
type=str,
@@ -53,6 +54,8 @@ def parse_arguments():
5354
type=str,
5455
default='TRTLLM',
5556
choices=['VANILLA', 'TRTLLM'])
57+
58+
# RocketKV config
5659
parser.add_argument('--window_size',
5760
type=int,
5861
default=32,
@@ -65,6 +68,14 @@ def parse_arguments():
6568
type=int,
6669
default=2048,
6770
help="The prompt budget for RocketKV.")
71+
parser.add_argument('--topk',
72+
type=int,
73+
default=64,
74+
help='Top-k for RocketKV')
75+
parser.add_argument('--kt_cache_dtype',
76+
type=str,
77+
default='float8_e5m2',
78+
choices=['bfloat16', 'float8_e5m2'])
6879
parser.add_argument('--index_max_chunk_size',
6980
type=int,
7081
default=32768,
@@ -106,6 +117,7 @@ def parse_arguments():
106117
# KV cache
107118
parser.add_argument('--kv_cache_dtype', type=str, default='auto')
108119
parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
120+
parser.add_argument('--tokens_per_block', type=int, default=32)
109121
parser.add_argument('--num_samples', type=int, default=10)
110122

111123
# Runtime
@@ -139,8 +151,8 @@ def run_llm(args, sparse_attention_config):
139151
enable_block_reuse=
140152
False, # sparse attention does not support kv cache reuse now
141153
free_gpu_memory_fraction=args.kv_cache_fraction,
154+
tokens_per_block=args.tokens_per_block,
142155
dtype=args.kv_cache_dtype,
143-
tokens_per_block=64,
144156
)
145157

146158
cuda_graph_config = CudaGraphConfig(
@@ -191,6 +203,8 @@ def run_RocketKV(args):
191203
window_size=args.window_size,
192204
kernel_size=args.kernel_size,
193205
prompt_budget=args.prompt_budget,
206+
topk=args.topk,
207+
kt_cache_dtype=args.kt_cache_dtype,
194208
)
195209
run_llm(args, sparse_attention_config)
196210

examples/longbench/eval_longbench_v1.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,22 @@ def parse_arguments() -> argparse.Namespace:
150150
type=int,
151151
default=63,
152152
help='Kernel size for RocketKV')
153-
parser.add_argument('--topr',
153+
parser.add_argument('--topk',
154154
type=int,
155-
default=90,
156-
help='Top-r for RocketKV')
155+
default=64,
156+
help='Top-k for RocketKV')
157+
parser.add_argument('--kt_cache_dtype',
158+
type=str,
159+
default='float8_e5m2',
160+
choices=['bfloat16', 'float8_e5m2'],
161+
help='KT cache data type')
157162

158163
# KV cache configuration
159164
parser.add_argument('--kv_cache_dtype',
160165
type=str,
161166
default='auto',
162167
help='KV cache data type')
168+
parser.add_argument('--tokens_per_block', type=int, default=32)
163169
parser.add_argument('--kv_cache_fraction',
164170
type=float,
165171
default=0.7,
@@ -320,6 +326,7 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
320326
# sparse attention doesn't support KV cache reuse
321327
enable_block_reuse=False,
322328
free_gpu_memory_fraction=args.kv_cache_fraction,
329+
tokens_per_block=args.tokens_per_block,
323330
)
324331

325332
# Configure CUDA graph
@@ -335,7 +342,8 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
335342
window_size=args.window_size,
336343
kernel_size=args.kernel_size,
337344
prompt_budget=args.token_budget,
338-
topr=args.topr,
345+
topk=args.topk,
346+
kt_cache_dtype=args.kt_cache_dtype,
339347
)
340348
logger.info(f"Using RocketKV sparse attention")
341349
else:
@@ -427,6 +435,14 @@ def evaluate_single_dataset(
427435
formatted_prompt = format_prompt_style(sample, prompt_format,
428436
chat_template, dataset,
429437
tokenizer)
438+
# Truncate prompt if it's too long
439+
token_ids = tokenizer.encode(formatted_prompt, truncation=False)
440+
if len(token_ids) > args.max_seq_len:
441+
half = (args.max_seq_len - max_new_tokens) // 2
442+
formatted_prompt = tokenizer.decode(
443+
token_ids[:half], skip_special_tokens=True) + tokenizer.decode(
444+
token_ids[-half:], skip_special_tokens=True)
445+
430446
prompts.append(formatted_prompt)
431447

432448
if len(prompts) == 0:

examples/longbench/eval_longbench_v2.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,15 @@ def parse_arguments() -> argparse.Namespace:
121121
type=int,
122122
default=63,
123123
help='Kernel size for RocketKV')
124-
parser.add_argument('--topr',
124+
parser.add_argument('--topk',
125125
type=int,
126-
default=90,
127-
help='Top-r for RocketKV')
126+
default=64,
127+
help='Top-k for RocketKV')
128+
parser.add_argument('--kt_cache_dtype',
129+
type=str,
130+
default='float8_e5m2',
131+
choices=['bfloat16', 'float8_e5m2'],
132+
help='KT cache data type')
128133

129134
# KV cache configuration
130135
parser.add_argument('--kv_cache_dtype',
@@ -356,7 +361,8 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
356361
window_size=args.window_size,
357362
kernel_size=args.kernel_size,
358363
prompt_budget=args.token_budget,
359-
topr=args.topr,
364+
topk=args.topk,
365+
kt_cache_dtype=args.kt_cache_dtype,
360366
)
361367
logger.info(f"Using RocketKV sparse attention")
362368
else:

0 commit comments

Comments (0)