Commit 98cadc2

[Perf] Avoid performing index selection of sin/cos cache every layer (vllm-project#1890)
Optimize number of index selections of sin/cos cache.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@656c24f

Signed-off-by: whx-sjtu <[email protected]>
1 parent 0190b68 commit 98cadc2
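What the change does, in short: before this commit every attention layer gathered the rotary cos/sin tables for the current positions inside its forward pass, so the same index selection ran once per layer per step. Now the MLA metadata builder gathers cos/sin once per batch, stores them on the attention metadata, and every layer reuses them. A minimal sketch of the idea, not the actual vLLM-Ascend code (shapes, `num_layers`, and variable names are illustrative):

```python
import torch

# Stand-ins for the rotary embedding's precomputed tables (illustrative sizes).
cos_cached = torch.randn(4096, 64)        # [max_position, rope_dim]
sin_cached = torch.randn(4096, 64)
positions = torch.randint(0, 4096, (8,))  # token positions for one batch
num_layers = 61                           # illustrative layer count

def per_layer_lookup():
    """Old pattern: every layer repeats the same gather."""
    for _ in range(num_layers):
        cos = cos_cached[positions][:, None, None, :]
        sin = sin_cached[positions][:, None, None, :]
        # ... apply rotary embedding with cos/sin ...

def per_batch_lookup():
    """New pattern: gather once while building attention metadata."""
    cos = cos_cached[positions].unsqueeze(1).unsqueeze(2)
    sin = sin_cached[positions].unsqueeze(1).unsqueeze(2)
    for _ in range(num_layers):
        # each layer just reads the precomputed cos/sin from the metadata
        pass
```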

File tree

3 files changed, +73 -22 lines changed


tests/ut/attention/test_mla_v1.py

Lines changed: 17 additions & 0 deletions
@@ -331,15 +331,30 @@ def test_build_dummy(self, mock_ascend_config):
         runner.chunked_prefill_enabled = False
         runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
         runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
+        runner.dtype = torch.float16

         builder = AscendMLAMetadataBuilder(runner=runner,
                                            metadata_cls=AscendMLAMetadata)
+        builder.rope_dim = 64

         with patch.object(builder,
                           "_get_graph_runner_block_tables",
                           side_effect=lambda x, y: y):
             metadata = builder.build_torchair_graph_dummy(3, 3)

+        sin_golden = torch.ones(3,
+                                1,
+                                1,
+                                64,
+                                dtype=runner.dtype,
+                                device=runner.device)
+        cos_golden = torch.ones(3,
+                                1,
+                                1,
+                                64,
+                                dtype=runner.dtype,
+                                device=runner.device)
+
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_input_tokens, 3)
         self.assertEqual(metadata.num_actual_tokens, 3)
@@ -354,6 +369,8 @@ def test_build_dummy(self, mock_ascend_config):
         self.assertEqual(metadata.seq_lens.shape[0], 3)
         self.assertEqual(metadata.slot_mapping.shape[0], 3)
         self.assertEqual(metadata.query_start_loc.shape[0], 3)
+        assert torch.equal(sin_golden, metadata.decode.sin)
+        assert torch.equal(cos_golden, metadata.decode.cos)


 class TestAscendMLAImpl(TestBase):

vllm_ascend/attention/mla_v1.py

Lines changed: 53 additions & 22 deletions
@@ -80,6 +80,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


 @dataclass
@@ -92,6 +94,8 @@ class AscendMLADecodeMetadata:
     max_seq_lens: int
     seq_lens_list: list[int]
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


 @dataclass
@@ -200,6 +204,9 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -318,13 +325,27 @@ def build_torchair_graph_dummy(
                                  -1,
                                  dtype=torch.int32,
                                  device=device)
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_seq_lens=1,
-            attn_mask=self.runner.spec_attn_mask)
+            attn_mask=self.runner.spec_attn_mask,
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -370,6 +391,16 @@ def build(
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+            self.cos_cache = self.cos_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
+            self.sin_cache = self.sin_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore

         prefill_metadata = None
         chunked_context_metadata = None
@@ -415,18 +446,26 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )

         decode_metadata = None
@@ -467,14 +506,20 @@ def build(
                                         dtype=input_positions.dtype,
                                         device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
+            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
+            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)

             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
                 max_seq_lens=max_seq_lens,
-                attn_mask=self.runner.spec_attn_mask)
+                attn_mask=self.runner.spec_attn_mask,
+                sin=sin,
+                cos=cos)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1069,15 +1114,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 with npu_stream_switch("mla_secondary",
                                        0,
                                        enabled=enable_multistream_mla):
@@ -1124,15 +1162,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin

                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
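The `unsqueeze(1).unsqueeze(2)` in the builder produces a `[num_tokens, 1, 1, rope_dim]` layout, the same shape the decode/prefill paths previously built with `cos[:, None, None, :]`; the singleton dimensions let one gathered cos/sin pair broadcast across all attention heads. A standalone shape check under assumed sizes (tensor names and dimensions are illustrative, not taken from the repository):

```python
import torch

num_tokens, num_heads, rope_dim = 8, 16, 64
cos_cached = torch.randn(4096, rope_dim)
positions = torch.randint(0, 4096, (num_tokens,))

cos = cos_cached[positions].unsqueeze(1).unsqueeze(2)    # [8, 1, 1, 64]
q_pe = torch.randn(num_tokens, num_heads, 1, rope_dim)   # rope slice of the query

# The singleton dims broadcast over the head axis, so one cos serves every head.
assert (q_pe * cos).shape == (num_tokens, num_heads, 1, rope_dim)
```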

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 0 deletions
@@ -1799,6 +1799,9 @@ def _dummy_run(
                         attn_metadata.decode.input_positions)
                     torch._dynamo.mark_static(
                         get_forward_context().mc2_mask)
+                    if hasattr(attn_metadata.decode, "sin"):
+                        torch._dynamo.mark_static(attn_metadata.decode.sin)
+                        torch._dynamo.mark_static(attn_metadata.decode.cos)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
                     for kv in self.kv_caches:
                         assert isinstance(
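Marking the new tensors static follows the pattern already used for `input_positions` and the mc2 mask above: it tells the dynamo capture that these metadata inputs keep fixed shapes across graph replays, so adding `sin`/`cos` does not introduce recompilations. A minimal illustration of the call (the tensor here is a placeholder, not the runner's real metadata):

```python
import torch

# Placeholder metadata tensor whose shape stays fixed across graph replays.
sin = torch.ones(8, 1, 1, 64, dtype=torch.float16)
torch._dynamo.mark_static(sin)  # treat its sizes as static during compilation
```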
