@@ -80,6 +80,8 @@ class ChunkedContextMetadata:
80
80
max_query_len : int
81
81
max_seq_lens : int
82
82
chunked_context : Optional [ChunkedContextMetadata ] = None
83
+ sin : torch .Tensor = None
84
+ cos : torch .Tensor = None
83
85
84
86
85
87
@dataclass
@@ -92,6 +94,8 @@ class AscendMLADecodeMetadata:
92
94
max_seq_lens : int
93
95
seq_lens_list : list [int ]
94
96
attn_mask : Optional [torch .Tensor ] = None
97
+ sin : torch .Tensor = None
98
+ cos : torch .Tensor = None
95
99
96
100
97
101
@dataclass
@@ -200,6 +204,9 @@ def __init__(self,
200
204
)
201
205
ascend_config = get_ascend_config ()
202
206
self .torchair_graph_enabled = ascend_config .torchair_graph_config .enabled
207
+ self .rope_dim = self .runner .model_config .hf_text_config .qk_rope_head_dim
208
+ self .cos_cache = None
209
+ self .sin_cache = None
203
210
204
211
def reorder_batch (self , input_batch : "InputBatch" ,
205
212
scheduler_output : "SchedulerOutput" ) -> bool :
@@ -318,13 +325,27 @@ def build_torchair_graph_dummy(
318
325
- 1 ,
319
326
dtype = torch .int32 ,
320
327
device = device )
328
+ sin = torch .ones (num_reqs ,
329
+ 1 ,
330
+ 1 ,
331
+ self .rope_dim ,
332
+ dtype = self .runner .dtype ,
333
+ device = device )
334
+ cos = torch .ones (num_reqs ,
335
+ 1 ,
336
+ 1 ,
337
+ self .rope_dim ,
338
+ dtype = self .runner .dtype ,
339
+ device = device )
321
340
decode_metadata = AscendMLADecodeMetadata (
322
341
input_positions = input_positions ,
323
342
block_table = block_table ,
324
343
seq_lens = seq_lens ,
325
344
seq_lens_list = seq_lens .tolist (),
326
345
max_seq_lens = 1 ,
327
- attn_mask = self .runner .spec_attn_mask )
346
+ attn_mask = self .runner .spec_attn_mask ,
347
+ sin = sin ,
348
+ cos = cos )
328
349
return self .metadata_cls ( # type: ignore
329
350
num_input_tokens = num_actual_tokens ,
330
351
num_actual_tokens = num_actual_tokens ,
@@ -370,6 +391,16 @@ def build(
370
391
seq_lens = seq_lens_cpu
371
392
max_query_len = query_lens .max ().item ()
372
393
max_seq_lens = seq_lens .max ().item ()
394
+ if self .cos_cache is None :
395
+ self .cos_cache = self .runner .get_model (
396
+ ).model .layers [0 ].self_attn .rotary_emb .cos_cached
397
+ self .sin_cache = self .runner .get_model (
398
+ ).model .layers [0 ].self_attn .rotary_emb .sin_cached
399
+ if self .cos_cache .dtype != self .runner .dtype : # type: ignore
400
+ self .cos_cache = self .cos_cache .to ( # type: ignore
401
+ self .runner .dtype ) # type: ignore
402
+ self .sin_cache = self .sin_cache .to ( # type: ignore
403
+ self .runner .dtype ) # type: ignore
373
404
374
405
prefill_metadata = None
375
406
chunked_context_metadata = None
@@ -415,18 +446,26 @@ def build(
415
446
chunk_seq_lens = chunk_seq_lens ,
416
447
workspace = self .chunked_prefill_workspace ,
417
448
)
418
-
449
+ prefill_input_positions = input_positions [tokens_start :]
450
+ cos = self .cos_cache [
451
+ prefill_input_positions ].unsqueeze ( # type: ignore
452
+ 1 ).unsqueeze (2 )
453
+ sin = self .sin_cache [
454
+ prefill_input_positions ].unsqueeze ( # type: ignore
455
+ 1 ).unsqueeze (2 )
419
456
prefill_metadata = AscendMLAPrefillMetadata (
420
457
attn_mask = self .runner .attn_mask ,
421
458
query_lens = query_lens [tokens_start :],
422
459
seq_lens = seq_lens ,
423
460
context_lens = seq_lens [tokens_start :],
424
- input_positions = input_positions [ tokens_start :] ,
461
+ input_positions = prefill_input_positions ,
425
462
block_table = block_table [reqs_start :, ...],
426
463
max_query_len = max_query_len ,
427
464
max_seq_lens = max_seq_lens ,
428
465
query_start_loc = prefill_query_start_loc ,
429
466
chunked_context = chunked_context_metadata ,
467
+ sin = sin ,
468
+ cos = cos ,
430
469
)
431
470
432
471
decode_metadata = None
@@ -467,14 +506,20 @@ def build(
467
506
dtype = input_positions .dtype ,
468
507
device = input_positions .device )
469
508
input_positions = torch .cat ([input_positions , padding_0 ])
509
+ cos = self .cos_cache [input_positions ].unsqueeze ( # type: ignore
510
+ 1 ).unsqueeze (2 )
511
+ sin = self .sin_cache [input_positions ].unsqueeze ( # type: ignore
512
+ 1 ).unsqueeze (2 )
470
513
471
514
decode_metadata = AscendMLADecodeMetadata (
472
515
input_positions = input_positions ,
473
516
block_table = block_table ,
474
517
seq_lens = seq_lens ,
475
518
seq_lens_list = seq_lens .tolist (),
476
519
max_seq_lens = max_seq_lens ,
477
- attn_mask = self .runner .spec_attn_mask )
520
+ attn_mask = self .runner .spec_attn_mask ,
521
+ sin = sin ,
522
+ cos = cos )
478
523
479
524
return self .metadata_cls ( # type: ignore
480
525
num_actual_tokens = num_actual_tokens ,
@@ -1069,15 +1114,8 @@ def forward(
1069
1114
decode_k_nope = None
1070
1115
assert attn_metadata .decode is not None
1071
1116
if self .running_in_graph :
1072
- seq_len = self .rotary_emb .max_position_embeddings * self .rotary_emb .scaling_factor
1073
- cos = self .rotary_emb .cos_cached [:seq_len ].to (
1074
- dtype = decode_hs_or_q_c .dtype )
1075
- sin = self .rotary_emb .sin_cached [:seq_len ].to (
1076
- dtype = decode_hs_or_q_c .dtype )
1077
- cos = cos [attn_metadata .decode .input_positions ]
1078
- sin = sin [attn_metadata .decode .input_positions ]
1079
- cos = cos [:, None , None , :]
1080
- sin = sin [:, None , None , :]
1117
+ cos = attn_metadata .decode .cos
1118
+ sin = attn_metadata .decode .sin
1081
1119
with npu_stream_switch ("mla_secondary" ,
1082
1120
0 ,
1083
1121
enabled = enable_multistream_mla ):
@@ -1124,15 +1162,8 @@ def forward(
1124
1162
prefill_q_nope = prefill_q [..., :self .qk_nope_head_dim ]
1125
1163
if self .torchair_graph_enabled :
1126
1164
num_tokens = prefill_hs_or_q_c .shape [0 ]
1127
- seq_len = self .rotary_emb .max_position_embeddings * self .rotary_emb .scaling_factor
1128
- cos = self .rotary_emb .cos_cached [:seq_len ].to (
1129
- dtype = prefill_q_pe .dtype )
1130
- sin = self .rotary_emb .sin_cached [:seq_len ].to (
1131
- dtype = prefill_q_pe .dtype )
1132
- cos = cos [attn_metadata .prefill .input_positions ]
1133
- sin = sin [attn_metadata .prefill .input_positions ]
1134
- cos = cos [:, None , None , :]
1135
- sin = sin [:, None , None , :]
1165
+ cos = attn_metadata .prefill .cos
1166
+ sin = attn_metadata .prefill .sin
1136
1167
1137
1168
prefill_q_pe = self .rope_single (prefill_q_pe , cos , sin )
1138
1169
prefill_k_pe , prefill_k_nope = self .exec_kv_prefill (
0 commit comments