Skip to content

Commit 43f8352

Browse files
authored
Revert "[AMD] support two batch overlapping for mori ep sgl-project#17953" (sgl-project#19161)
1 parent 45095ba commit 43f8352

File tree

11 files changed

+150
-718
lines changed

11 files changed

+150
-718
lines changed

docs/advanced_features/server_arguments.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -311,7 +311,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
311311
| Argument | Description | Defaults | Options |
312312
| --- | --- | --- | --- |
313313
| `--expert-parallel-size`<br>`--ep-size`<br>`--ep` | The expert parallelism size. | `1` | Type: int |
314-
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `ascend_fuseep`|
314+
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep`|
315315
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
316316
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
317317
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |

python/sglang/srt/batch_overlap/operations_strategy.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,6 @@
77
from sglang.srt.batch_overlap.operations import Operation
88
from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
99
from sglang.srt.model_executor.forward_batch_info import ForwardMode
10-
from sglang.srt.utils import is_hip
11-
12-
_is_hip = is_hip()
1310

1411

1512
@dataclass
@@ -94,9 +91,7 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
9491
def _compute_moe_deepseek_blog_prefill(layer):
9592
device_properties = torch.cuda.get_device_properties(device="cuda")
9693
total_num_sms = device_properties.multi_processor_count
97-
deep_gemm_num_sms = None
98-
if not _is_hip:
99-
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
94+
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
10095

10196
return OperationsStrategy(
10297
deep_gemm_num_sms=deep_gemm_num_sms,
@@ -173,9 +168,7 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
173168
def _compute_moe_qwen3_prefill(layer):
174169
device_properties = torch.cuda.get_device_properties(device="cuda")
175170
total_num_sms = device_properties.multi_processor_count
176-
deep_gemm_num_sms = None
177-
if not _is_hip:
178-
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
171+
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
179172

180173
return OperationsStrategy(
181174
deep_gemm_num_sms=deep_gemm_num_sms,

python/sglang/srt/batch_overlap/two_batch_overlap.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@
3030
from sglang.srt.layers.moe.token_dispatcher import (
3131
DeepEPDispatcher,
3232
MooncakeEPDispatcher,
33-
MoriEPDispatcher,
3433
)
3534
from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
3635
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -1028,10 +1027,6 @@ def __init__(self, **kwargs):
10281027
self._inners = [
10291028
MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
10301029
]
1031-
elif get_moe_a2a_backend().is_mori():
1032-
self._inners = [
1033-
MoriEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
1034-
]
10351030

10361031
def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
10371032
return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)

python/sglang/srt/layers/attention/aiter_backend.py

Lines changed: 52 additions & 162 deletions
Original file line number | Diff line number | Diff line change
@@ -431,7 +431,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
431431
# num_kv_splits_indptr = None
432432

433433
if forward_batch.forward_mode.is_decode_or_idle():
434-
if spec_info is None or forward_batch.forward_mode.is_idle():
434+
if spec_info is None:
435435
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
436436
kv_indptr = kv_indptr[: bs + 1]
437437
kv_indices = torch.empty(
@@ -1074,17 +1074,6 @@ def init_forward_metadata_replay_cuda_graph(
10741074
seq_lens_cpu: Optional[torch.Tensor],
10751075
):
10761076

1077-
num_kv_splits = None
1078-
# num_kv_splits_indptr = None
1079-
1080-
work_metadata = None
1081-
work_info_set = None
1082-
work_indptr = None
1083-
1084-
reduce_indptr = None
1085-
reduce_final_map = None
1086-
reduce_partial_map = None
1087-
10881077
if forward_mode.is_decode_or_idle():
10891078
kv_indptr = self.kv_indptr
10901079
kv_indices = self.cuda_graph_kv_indices
@@ -1104,58 +1093,6 @@ def init_forward_metadata_replay_cuda_graph(
11041093
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
11051094
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
11061095

1107-
if self.use_mla:
1108-
qo_indptr = self.qo_indptr_[: bs + 1]
1109-
qo_indptr[1 : bs + 1] = torch.cumsum(
1110-
self.cuda_graph_kv_last_page_len[:bs], dim=0
1111-
)
1112-
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
1113-
max_q_len = 1
1114-
1115-
if _use_mla_ps_kernel:
1116-
num_kv_splits = self.max_split_per_batch
1117-
1118-
self.make_mla_meta_data(
1119-
qo_indptr,
1120-
kv_indptr,
1121-
kv_last_page_len,
1122-
self.work_metadata,
1123-
self.work_info_set,
1124-
self.work_indptr,
1125-
self.reduce_indptr,
1126-
self.reduce_final_map,
1127-
self.reduce_partial_map,
1128-
max_q_len,
1129-
fast_mode=fast_mode,
1130-
max_split_per_batch=num_kv_splits,
1131-
intra_batch_mode=intra_batch_mode,
1132-
)
1133-
1134-
work_metadata = self.work_metadata
1135-
work_info_set = self.work_info_set
1136-
work_indptr = self.work_indptr
1137-
1138-
reduce_indptr = self.reduce_indptr
1139-
reduce_final_map = self.reduce_final_map
1140-
reduce_partial_map = self.reduce_partial_map
1141-
1142-
self.forward_metadata = ForwardMetadata(
1143-
kv_indptr,
1144-
kv_indices,
1145-
qo_indptr,
1146-
kv_last_page_len,
1147-
max_q_len,
1148-
kv_indptr[-1].item(),
1149-
work_metadata=work_metadata,
1150-
work_info_set=work_info_set,
1151-
work_indptr=work_indptr,
1152-
reduce_indptr=reduce_indptr,
1153-
reduce_final_map=reduce_final_map,
1154-
reduce_partial_map=reduce_partial_map,
1155-
num_kv_splits=num_kv_splits,
1156-
# num_kv_splits_indptr=num_kv_splits_indptr,
1157-
)
1158-
11591096
elif forward_mode.is_target_verify():
11601097
bs = len(req_pool_indices)
11611098
qo_indptr = self.qo_indptr[: bs + 1]
@@ -1180,57 +1117,7 @@ def init_forward_metadata_replay_cuda_graph(
11801117
self.req_to_token.stride(0),
11811118
)
11821119

1183-
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
1184-
max_q_len = self.num_draft_tokens
1185-
1186-
# if self.kv_cache_dtype == fp8_dtype:
1187-
if _use_mla_ps_kernel:
1188-
1189-
num_kv_splits = self.max_split_per_batch
1190-
1191-
self.make_mla_meta_data(
1192-
qo_indptr,
1193-
kv_indptr,
1194-
kv_last_page_len,
1195-
self.work_metadata,
1196-
self.work_info_set,
1197-
self.work_indptr,
1198-
self.reduce_indptr,
1199-
self.reduce_final_map,
1200-
self.reduce_partial_map,
1201-
max_q_len,
1202-
fast_mode=fast_mode,
1203-
max_split_per_batch=num_kv_splits,
1204-
intra_batch_mode=intra_batch_mode,
1205-
)
1206-
1207-
work_metadata = self.work_metadata
1208-
work_info_set = self.work_info_set
1209-
work_indptr = self.work_indptr
1210-
1211-
reduce_indptr = self.reduce_indptr
1212-
reduce_final_map = self.reduce_final_map
1213-
reduce_partial_map = self.reduce_partial_map
1214-
1215-
self.forward_metadata = ForwardMetadata(
1216-
kv_indptr,
1217-
kv_indices,
1218-
qo_indptr,
1219-
kv_last_page_len,
1220-
max_q_len,
1221-
kv_indptr[-1].item(),
1222-
work_metadata=work_metadata,
1223-
work_info_set=work_info_set,
1224-
work_indptr=work_indptr,
1225-
reduce_indptr=reduce_indptr,
1226-
reduce_final_map=reduce_final_map,
1227-
reduce_partial_map=reduce_partial_map,
1228-
num_kv_splits=num_kv_splits,
1229-
# num_kv_splits_indptr=num_kv_splits_indptr,
1230-
)
1231-
12321120
elif forward_mode.is_draft_extend():
1233-
num_tokens_per_bs = self.speculative_num_steps + 1
12341121
seq_lens = seq_lens[:bs]
12351122
accept_lens = spec_info.accept_length[:bs]
12361123
qo_indptr = self.qo_indptr[: bs + 1]
@@ -1248,54 +1135,6 @@ def init_forward_metadata_replay_cuda_graph(
12481135
self.req_to_token.stride(0),
12491136
)
12501137

1251-
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
1252-
max_q_len = num_tokens_per_bs
1253-
1254-
if _use_mla_ps_kernel:
1255-
1256-
num_kv_splits = self.max_split_per_batch
1257-
1258-
self.make_mla_meta_data(
1259-
qo_indptr,
1260-
kv_indptr,
1261-
kv_last_page_len,
1262-
self.work_metadata,
1263-
self.work_info_set,
1264-
self.work_indptr,
1265-
self.reduce_indptr,
1266-
self.reduce_final_map,
1267-
self.reduce_partial_map,
1268-
max_q_len,
1269-
fast_mode=fast_mode,
1270-
max_split_per_batch=num_kv_splits,
1271-
intra_batch_mode=intra_batch_mode,
1272-
)
1273-
1274-
work_metadata = self.work_metadata
1275-
work_info_set = self.work_info_set
1276-
work_indptr = self.work_indptr
1277-
1278-
reduce_indptr = self.reduce_indptr
1279-
reduce_final_map = self.reduce_final_map
1280-
reduce_partial_map = self.reduce_partial_map
1281-
1282-
self.forward_metadata = ForwardMetadata(
1283-
kv_indptr,
1284-
kv_indices,
1285-
qo_indptr,
1286-
kv_last_page_len,
1287-
max_q_len,
1288-
kv_indptr[-1].item(),
1289-
work_metadata=work_metadata,
1290-
work_info_set=work_info_set,
1291-
work_indptr=work_indptr,
1292-
reduce_indptr=reduce_indptr,
1293-
reduce_final_map=reduce_final_map,
1294-
reduce_partial_map=reduce_partial_map,
1295-
num_kv_splits=num_kv_splits,
1296-
# num_kv_splits_indptr=num_kv_splits_indptr,
1297-
)
1298-
12991138
else:
13001139
raise ValueError("Invalid forward mode")
13011140

@@ -1527,6 +1366,23 @@ def forward_extend(
15271366

15281367
num_kv_splits = self.forward_metadata.num_kv_splits
15291368

1369+
if layer.layer_id == 0 and _use_mla_ps_kernel:
1370+
self.make_mla_meta_data(
1371+
self.forward_metadata.qo_indptr,
1372+
self.forward_metadata.kv_indptr,
1373+
self.forward_metadata.kv_last_page_len,
1374+
work_metadata,
1375+
work_info_set,
1376+
work_indptr,
1377+
reduce_indptr,
1378+
reduce_final_map,
1379+
reduce_partial_map,
1380+
self.forward_metadata.max_q_len,
1381+
fast_mode=fast_mode,
1382+
max_split_per_batch=num_kv_splits,
1383+
intra_batch_mode=intra_batch_mode,
1384+
)
1385+
15301386
mla_decode_fwd(
15311387
q,
15321388
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
@@ -1562,6 +1418,23 @@ def forward_extend(
15621418

15631419
num_kv_splits = self.forward_metadata.num_kv_splits
15641420

1421+
if layer.layer_id == 0 and _use_mla_ps_kernel:
1422+
self.make_mla_meta_data(
1423+
self.forward_metadata.qo_indptr,
1424+
self.forward_metadata.kv_indptr,
1425+
self.forward_metadata.kv_last_page_len,
1426+
work_metadata,
1427+
work_info_set,
1428+
work_indptr,
1429+
reduce_indptr,
1430+
reduce_final_map,
1431+
reduce_partial_map,
1432+
self.forward_metadata.max_q_len,
1433+
fast_mode=fast_mode,
1434+
max_split_per_batch=num_kv_splits,
1435+
intra_batch_mode=intra_batch_mode,
1436+
)
1437+
15651438
if self.forward_metadata.run_graph is not True:
15661439

15671440
bs, q_pad, q_mask = pad_sequence_with_mask(
@@ -1704,6 +1577,23 @@ def forward_decode(
17041577

17051578
num_kv_splits = self.forward_metadata.num_kv_splits
17061579

1580+
if layer.layer_id == 0 and _use_mla_ps_kernel:
1581+
self.make_mla_meta_data(
1582+
self.forward_metadata.qo_indptr,
1583+
self.forward_metadata.kv_indptr,
1584+
self.forward_metadata.kv_last_page_len,
1585+
work_metadata,
1586+
work_info_set,
1587+
work_indptr,
1588+
reduce_indptr,
1589+
reduce_final_map,
1590+
reduce_partial_map,
1591+
self.forward_metadata.max_q_len,
1592+
fast_mode=fast_mode,
1593+
max_split_per_batch=num_kv_splits,
1594+
intra_batch_mode=intra_batch_mode,
1595+
)
1596+
17071597
mla_decode_fwd(
17081598
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
17091599
k_buffer.view(-1, 1, 1, layer.qk_head_dim),

0 commit comments

Comments
 (0)