@@ -431,7 +431,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
431431 # num_kv_splits_indptr = None
432432
433433 if forward_batch .forward_mode .is_decode_or_idle ():
434- if spec_info is None or forward_batch . forward_mode . is_idle () :
434+ if spec_info is None :
435435 kv_indptr [1 : bs + 1 ] = torch .cumsum (forward_batch .seq_lens , dim = 0 )
436436 kv_indptr = kv_indptr [: bs + 1 ]
437437 kv_indices = torch .empty (
@@ -1074,17 +1074,6 @@ def init_forward_metadata_replay_cuda_graph(
10741074 seq_lens_cpu : Optional [torch .Tensor ],
10751075 ):
10761076
1077- num_kv_splits = None
1078- # num_kv_splits_indptr = None
1079-
1080- work_metadata = None
1081- work_info_set = None
1082- work_indptr = None
1083-
1084- reduce_indptr = None
1085- reduce_final_map = None
1086- reduce_partial_map = None
1087-
10881077 if forward_mode .is_decode_or_idle ():
10891078 kv_indptr = self .kv_indptr
10901079 kv_indices = self .cuda_graph_kv_indices
@@ -1104,58 +1093,6 @@ def init_forward_metadata_replay_cuda_graph(
11041093 kv_indptr [: spec_info .kv_indptr .shape [0 ]] = spec_info .kv_indptr
11051094 kv_indices [: spec_info .kv_indices .shape [0 ]] = spec_info .kv_indices
11061095
1107- if self .use_mla :
1108- qo_indptr = self .qo_indptr_ [: bs + 1 ]
1109- qo_indptr [1 : bs + 1 ] = torch .cumsum (
1110- self .cuda_graph_kv_last_page_len [:bs ], dim = 0
1111- )
1112- kv_last_page_len = self .cuda_graph_kv_last_page_len [:bs ]
1113- max_q_len = 1
1114-
1115- if _use_mla_ps_kernel :
1116- num_kv_splits = self .max_split_per_batch
1117-
1118- self .make_mla_meta_data (
1119- qo_indptr ,
1120- kv_indptr ,
1121- kv_last_page_len ,
1122- self .work_metadata ,
1123- self .work_info_set ,
1124- self .work_indptr ,
1125- self .reduce_indptr ,
1126- self .reduce_final_map ,
1127- self .reduce_partial_map ,
1128- max_q_len ,
1129- fast_mode = fast_mode ,
1130- max_split_per_batch = num_kv_splits ,
1131- intra_batch_mode = intra_batch_mode ,
1132- )
1133-
1134- work_metadata = self .work_metadata
1135- work_info_set = self .work_info_set
1136- work_indptr = self .work_indptr
1137-
1138- reduce_indptr = self .reduce_indptr
1139- reduce_final_map = self .reduce_final_map
1140- reduce_partial_map = self .reduce_partial_map
1141-
1142- self .forward_metadata = ForwardMetadata (
1143- kv_indptr ,
1144- kv_indices ,
1145- qo_indptr ,
1146- kv_last_page_len ,
1147- max_q_len ,
1148- kv_indptr [- 1 ].item (),
1149- work_metadata = work_metadata ,
1150- work_info_set = work_info_set ,
1151- work_indptr = work_indptr ,
1152- reduce_indptr = reduce_indptr ,
1153- reduce_final_map = reduce_final_map ,
1154- reduce_partial_map = reduce_partial_map ,
1155- num_kv_splits = num_kv_splits ,
1156- # num_kv_splits_indptr=num_kv_splits_indptr,
1157- )
1158-
11591096 elif forward_mode .is_target_verify ():
11601097 bs = len (req_pool_indices )
11611098 qo_indptr = self .qo_indptr [: bs + 1 ]
@@ -1180,57 +1117,7 @@ def init_forward_metadata_replay_cuda_graph(
11801117 self .req_to_token .stride (0 ),
11811118 )
11821119
1183- kv_last_page_len = self .cuda_graph_kv_last_page_len [:bs ]
1184- max_q_len = self .num_draft_tokens
1185-
1186- # if self.kv_cache_dtype == fp8_dtype:
1187- if _use_mla_ps_kernel :
1188-
1189- num_kv_splits = self .max_split_per_batch
1190-
1191- self .make_mla_meta_data (
1192- qo_indptr ,
1193- kv_indptr ,
1194- kv_last_page_len ,
1195- self .work_metadata ,
1196- self .work_info_set ,
1197- self .work_indptr ,
1198- self .reduce_indptr ,
1199- self .reduce_final_map ,
1200- self .reduce_partial_map ,
1201- max_q_len ,
1202- fast_mode = fast_mode ,
1203- max_split_per_batch = num_kv_splits ,
1204- intra_batch_mode = intra_batch_mode ,
1205- )
1206-
1207- work_metadata = self .work_metadata
1208- work_info_set = self .work_info_set
1209- work_indptr = self .work_indptr
1210-
1211- reduce_indptr = self .reduce_indptr
1212- reduce_final_map = self .reduce_final_map
1213- reduce_partial_map = self .reduce_partial_map
1214-
1215- self .forward_metadata = ForwardMetadata (
1216- kv_indptr ,
1217- kv_indices ,
1218- qo_indptr ,
1219- kv_last_page_len ,
1220- max_q_len ,
1221- kv_indptr [- 1 ].item (),
1222- work_metadata = work_metadata ,
1223- work_info_set = work_info_set ,
1224- work_indptr = work_indptr ,
1225- reduce_indptr = reduce_indptr ,
1226- reduce_final_map = reduce_final_map ,
1227- reduce_partial_map = reduce_partial_map ,
1228- num_kv_splits = num_kv_splits ,
1229- # num_kv_splits_indptr=num_kv_splits_indptr,
1230- )
1231-
12321120 elif forward_mode .is_draft_extend ():
1233- num_tokens_per_bs = self .speculative_num_steps + 1
12341121 seq_lens = seq_lens [:bs ]
12351122 accept_lens = spec_info .accept_length [:bs ]
12361123 qo_indptr = self .qo_indptr [: bs + 1 ]
@@ -1248,54 +1135,6 @@ def init_forward_metadata_replay_cuda_graph(
12481135 self .req_to_token .stride (0 ),
12491136 )
12501137
1251- kv_last_page_len = self .cuda_graph_kv_last_page_len [:bs ]
1252- max_q_len = num_tokens_per_bs
1253-
1254- if _use_mla_ps_kernel :
1255-
1256- num_kv_splits = self .max_split_per_batch
1257-
1258- self .make_mla_meta_data (
1259- qo_indptr ,
1260- kv_indptr ,
1261- kv_last_page_len ,
1262- self .work_metadata ,
1263- self .work_info_set ,
1264- self .work_indptr ,
1265- self .reduce_indptr ,
1266- self .reduce_final_map ,
1267- self .reduce_partial_map ,
1268- max_q_len ,
1269- fast_mode = fast_mode ,
1270- max_split_per_batch = num_kv_splits ,
1271- intra_batch_mode = intra_batch_mode ,
1272- )
1273-
1274- work_metadata = self .work_metadata
1275- work_info_set = self .work_info_set
1276- work_indptr = self .work_indptr
1277-
1278- reduce_indptr = self .reduce_indptr
1279- reduce_final_map = self .reduce_final_map
1280- reduce_partial_map = self .reduce_partial_map
1281-
1282- self .forward_metadata = ForwardMetadata (
1283- kv_indptr ,
1284- kv_indices ,
1285- qo_indptr ,
1286- kv_last_page_len ,
1287- max_q_len ,
1288- kv_indptr [- 1 ].item (),
1289- work_metadata = work_metadata ,
1290- work_info_set = work_info_set ,
1291- work_indptr = work_indptr ,
1292- reduce_indptr = reduce_indptr ,
1293- reduce_final_map = reduce_final_map ,
1294- reduce_partial_map = reduce_partial_map ,
1295- num_kv_splits = num_kv_splits ,
1296- # num_kv_splits_indptr=num_kv_splits_indptr,
1297- )
1298-
12991138 else :
13001139 raise ValueError ("Invalid forward mode" )
13011140
@@ -1527,6 +1366,23 @@ def forward_extend(
15271366
15281367 num_kv_splits = self .forward_metadata .num_kv_splits
15291368
1369+ if layer .layer_id == 0 and _use_mla_ps_kernel :
1370+ self .make_mla_meta_data (
1371+ self .forward_metadata .qo_indptr ,
1372+ self .forward_metadata .kv_indptr ,
1373+ self .forward_metadata .kv_last_page_len ,
1374+ work_metadata ,
1375+ work_info_set ,
1376+ work_indptr ,
1377+ reduce_indptr ,
1378+ reduce_final_map ,
1379+ reduce_partial_map ,
1380+ self .forward_metadata .max_q_len ,
1381+ fast_mode = fast_mode ,
1382+ max_split_per_batch = num_kv_splits ,
1383+ intra_batch_mode = intra_batch_mode ,
1384+ )
1385+
15301386 mla_decode_fwd (
15311387 q ,
15321388 K_Buffer .view (- 1 , 1 , 1 , layer .qk_head_dim ),
@@ -1562,6 +1418,23 @@ def forward_extend(
15621418
15631419 num_kv_splits = self .forward_metadata .num_kv_splits
15641420
1421+ if layer .layer_id == 0 and _use_mla_ps_kernel :
1422+ self .make_mla_meta_data (
1423+ self .forward_metadata .qo_indptr ,
1424+ self .forward_metadata .kv_indptr ,
1425+ self .forward_metadata .kv_last_page_len ,
1426+ work_metadata ,
1427+ work_info_set ,
1428+ work_indptr ,
1429+ reduce_indptr ,
1430+ reduce_final_map ,
1431+ reduce_partial_map ,
1432+ self .forward_metadata .max_q_len ,
1433+ fast_mode = fast_mode ,
1434+ max_split_per_batch = num_kv_splits ,
1435+ intra_batch_mode = intra_batch_mode ,
1436+ )
1437+
15651438 if self .forward_metadata .run_graph is not True :
15661439
15671440 bs , q_pad , q_mask = pad_sequence_with_mask (
@@ -1704,6 +1577,23 @@ def forward_decode(
17041577
17051578 num_kv_splits = self .forward_metadata .num_kv_splits
17061579
1580+ if layer .layer_id == 0 and _use_mla_ps_kernel :
1581+ self .make_mla_meta_data (
1582+ self .forward_metadata .qo_indptr ,
1583+ self .forward_metadata .kv_indptr ,
1584+ self .forward_metadata .kv_last_page_len ,
1585+ work_metadata ,
1586+ work_info_set ,
1587+ work_indptr ,
1588+ reduce_indptr ,
1589+ reduce_final_map ,
1590+ reduce_partial_map ,
1591+ self .forward_metadata .max_q_len ,
1592+ fast_mode = fast_mode ,
1593+ max_split_per_batch = num_kv_splits ,
1594+ intra_batch_mode = intra_batch_mode ,
1595+ )
1596+
17071597 mla_decode_fwd (
17081598 q .view (- 1 , layer .tp_q_head_num , layer .qk_head_dim ),
17091599 k_buffer .view (- 1 , 1 , 1 , layer .qk_head_dim ),
0 commit comments