Commit cecba06

[docker] support mtp for r3 (#1131)

1 parent 2924d85 commit cecba06

File tree

2 files changed: +75 −38 lines changed

docker/patch/latest/sglang.patch

Lines changed: 74 additions & 37 deletions
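Taken together, the patch changes rework how routed-expert IDs are captured so the feature coexists with MTP (presumably multi-token-prediction speculative decoding) draft workers: the per-step CPU write path is replaced by one batched device-to-host sync after each forward pass, a CPU mirror of `out_cache_loc` is threaded through the batch structures and the EAGLE verify path to drive that sync, draft workers skip capture entirely, and leftover debug lines are removed. Note that the hunks below are a diff of a diff: the outer `-`/`+` column edits the patch file itself, while the inner `-`/`+` markers belong to the embedded sglang changes.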
@@ -301,10 +301,10 @@ index e7d5a67cc..639e47163 100644
                  out_hidden_states[begin_chunk_idx:end_chunk_idx],
 diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
 new file mode 100644
-index 000000000..7369f9dc9
+index 000000000..11adcaa77
 --- /dev/null
 +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
-@@ -0,0 +1,308 @@
+@@ -0,0 +1,305 @@
 +import logging
 +from abc import ABC
 +from contextlib import contextmanager
@@ -402,9 +402,6 @@ index 000000000..7369f9dc9
 +        assert hasattr(self, "buffer")
 +        return get_tensor_size_bytes(self.buffer)
 +
-+    def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
-+        self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True)
-+
 +    def _finalize_allocation_log(self):
 +        """Common logging and memory usage computation for captured experts buffers."""
 +        buffer_size_GB = self.get_buffer_size_bytes() / _GB
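The deleted `set_experts_buffer` wrote each step's top-k routed-expert IDs into a per-layer CPU buffer, one copy per layer per forward pass; the `model_runner.py` hunks later in this patch replace that with a single post-forward device-to-host sync. Below is a minimal sketch of the buffer the removed method maintained — the method body matches the patch, while the class shape, dtype, and sizes are illustrative assumptions:

```python
import torch

class RoutedExpertsBufferSketch:
    """Hypothetical stand-in for the capturer's CPU-side buffer."""

    def __init__(self, num_layers: int, num_cache_slots: int, top_k: int):
        # One row per KV-cache slot, one column per routed-expert choice.
        self.buffer = torch.zeros(num_layers, num_cache_slots, top_k, dtype=torch.int32)

    def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
        # loc: cache slots of this step's tokens; top_k: [num_tokens, k] expert IDs.
        self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True)

buf = RoutedExpertsBufferSketch(num_layers=2, num_cache_slots=16, top_k=4)
buf.set_experts_buffer(0, torch.tensor([3, 5]), torch.randint(0, 8, (2, 4), dtype=torch.int32))
```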
@@ -903,7 +900,7 @@ index e34736cc4..5e5997a1a 100644
          # idx is the index of the token in the prompt after expansion.
          # val is the length of padded tokens after expansion.
 diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
-index c4c5a9ebb..1450c5fd8 100644
+index c4c5a9ebb..3650ba881 100644
 --- a/python/sglang/srt/managers/schedule_batch.py
 +++ b/python/sglang/srt/managers/schedule_batch.py
 @@ -450,6 +450,7 @@ class Req:
@@ -953,15 +950,23 @@ index c4c5a9ebb..1450c5fd8 100644
              is_prefill_only=all(req.is_prefill_only for req in reqs),
              chunked_req=chunked_req,
              dllm_config=dllm_config,
-@@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+@@ -1282,6 +1294,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+             )
+         else:
+             self.out_cache_loc = torch.cat(decoder_out_cache_loc)
++            self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)
+
+         if not encoder_out_cache_loc:
+             self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
+@@ -1457,6 +1470,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.req_pool_indices = req_pool_indices_tensor
          self.orig_seq_lens = orig_seq_lens_tensor
          self.out_cache_loc = out_cache_loc
 +        self.out_cache_loc_cpu = out_cache_loc.cpu()
          self.input_embeds = (
              torch.tensor(input_embeds).to(self.device, non_blocking=True)
              if input_embeds
-@@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+@@ -1508,10 +1522,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
          input_ids = torch.cat([self.input_ids, running_batch.input_ids])
          out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
@@ -976,47 +981,47 @@ index c4c5a9ebb..1450c5fd8 100644
 
          # For overlap scheduler, the output_ids has one step delay
          delta = 0 if self.enable_overlap else -1
-@@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+@@ -1677,6 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
          self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
          self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
 +        self.out_cache_loc_cpu = torch.empty(0, dtype=torch.int64, device="cpu")
          self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
          self.seq_lens_sum = 0
          self.extend_num_tokens = 0
-@@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+@@ -1736,6 +1755,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
          # Allocate memory
          self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
 +        self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)
 
          # Update req-level memory management fields
          for req in self.reqs:
-@@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+@@ -1807,6 +1827,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
          self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
          self.out_cache_loc = None
 +        self.out_cache_loc_cpu = None
          self.seq_lens_sum = self.seq_lens.sum().item()
          self.output_ids = self.output_ids[keep_indices_device]
          self.return_logprob = any(req.return_logprob for req in self.reqs)
-@@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+@@ -1852,6 +1873,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
          self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
          self.out_cache_loc = None
 +        self.out_cache_loc_cpu = None
          self.seq_lens_sum += other.seq_lens_sum
          if self.output_ids is not None:
              self.output_ids = torch.cat([self.output_ids, other.output_ids])
-@@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+@@ -1903,6 +1925,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              seq_lens=self.seq_lens,
              orig_seq_lens=self.orig_seq_lens,
              out_cache_loc=self.out_cache_loc,
 +            out_cache_loc_cpu=self.out_cache_loc_cpu,
              seq_lens_cpu=seq_lens_cpu,
              seq_lens_sum=self.seq_lens_sum,
              return_logprob=self.return_logprob,
-@@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+@@ -1983,7 +2006,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      def __str__(self):
          return (
              f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -1026,7 +1031,7 @@ index c4c5a9ebb..1450c5fd8 100644
          )
 
 
-@@ -2038,6 +2061,9 @@ class ModelWorkerBatch:
+@@ -2038,6 +2062,9 @@ class ModelWorkerBatch:
      # Sampling info
      sampling_info: SamplingBatchInfo
 
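Every `schedule_batch.py` hunk above enforces one invariant: wherever `ScheduleBatch.out_cache_loc` is allocated, concatenated, filtered, merged, or cleared, a CPU twin `out_cache_loc_cpu` is updated in lock step, so the post-forward experts sync can index host memory without a blocking `.cpu()` call on the hot path. A minimal sketch of the pattern follows; note that `non_blocking=True` only overlaps the copy with compute when the destination is pinned, which the patch does not arrange:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
out_cache_loc = torch.arange(8, dtype=torch.int64, device=device)

# Kick off the device-to-host copy as soon as the device tensor exists.
out_cache_loc_cpu = out_cache_loc.to("cpu", non_blocking=True)

# ... enqueue the forward pass and other GPU work here ...

# Before host code reads the mirror, the copy must have completed; with a
# pageable destination the copy above was effectively synchronous anyway.
if device == "cuda":
    torch.cuda.synchronize()
assert out_cache_loc_cpu.tolist() == list(range(8))
```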
@@ -1194,7 +1199,7 @@ index edbc52526..2cdc42755 100644
 
          # This means that weight sync
 diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
-index b90cf0616..9b0992655 100644
+index b90cf0616..8a5cbdbed 100644
 --- a/python/sglang/srt/managers/tokenizer_manager.py
 +++ b/python/sglang/srt/managers/tokenizer_manager.py
 @@ -20,6 +20,7 @@ import logging
@@ -1213,14 +1218,12 @@ index b90cf0616..9b0992655 100644
              data_parallel_rank=obj.data_parallel_rank,
              priority=obj.priority,
              extra_key=obj.extra_key,
-@@ -1621,6 +1623,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
+@@ -1621,6 +1623,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
          if getattr(recv_obj, "output_hidden_states", None):
              meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
 +        if getattr(recv_obj, "output_routed_experts", None):
 +            if recv_obj.output_routed_experts[i] is not None:
-+                # print(f"{recv_obj.output_routed_experts[i].shape=}, {recv_obj.output_routed_experts[i].dtype=}")
-+                # torch.save(recv_obj.output_routed_experts[i], f"/root/{recv_obj.output_routed_experts[i].shape[0]}.pt")
 +                meta_info["routed_experts"] = pybase64.b64encode(
 +                    recv_obj.output_routed_experts[i].contiguous().numpy().tobytes(order="C")
 +                ).decode("ascii")
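The `tokenizer_manager.py` hunk returns each request's captured expert IDs to the client as base64-encoded raw bytes in `meta_info["routed_experts"]`, and deletes two commented-out debug lines from the patch. A round-trip sketch: the encode side follows the patch, the decode side is my assumption, and a consumer must know the tensor's shape and dtype out of band since `tobytes()` discards both:

```python
import numpy as np
import pybase64  # API-compatible, faster drop-in for the stdlib base64 module
import torch

experts = torch.randint(0, 256, (7, 8), dtype=torch.int32)  # [num_tokens, top_k]

# Encode (as in the patch).
encoded = pybase64.b64encode(
    experts.contiguous().numpy().tobytes(order="C")
).decode("ascii")

# Decode (assumed client side): shape and dtype must be agreed upon separately.
decoded = torch.from_numpy(
    np.frombuffer(pybase64.b64decode(encoded), dtype=np.int32).reshape(7, 8).copy()
)
assert torch.equal(experts, decoded)
```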
@@ -1230,7 +1233,7 @@ index b90cf0616..9b0992655 100644
              if isinstance(recv_obj, BatchStrOutput):
                  state.text += recv_obj.output_strs[i]
                  if self.server_args.stream_output and state.obj.stream:
-@@ -1747,12 +1759,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
+@@ -1747,12 +1757,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
              return
 
          if len(recv_obj.input_token_logprobs_val) > 0:
@@ -1251,7 +1254,7 @@ index b90cf0616..9b0992655 100644
              recv_obj.output_token_logprobs_val[recv_obj_index]
          )
 diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
-index 3a85e6a7e..2859dafa1 100644
+index 3a85e6a7e..d2560e79b 100644
 --- a/python/sglang/srt/model_executor/forward_batch_info.py
 +++ b/python/sglang/srt/model_executor/forward_batch_info.py
 @@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
@@ -1315,8 +1318,16 @@ index 3a85e6a7e..2859dafa1 100644
          if self.encoder_lens is not None:
              self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
          self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
+@@ -906,6 +921,7 @@ class ForwardBatch:
+             self.spec_info.hidden_states = self.hidden_states_backup
+             if hasattr(self, "output_cache_loc_backup"):
+                 self.out_cache_loc = self.output_cache_loc_backup
++                self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)
+
+         elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
+             logits_output.next_token_logits = logits_output.next_token_logits[:bs]
 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
-index 4d58278b7..8f50dc430 100644
+index 4d58278b7..5965c481e 100644
 --- a/python/sglang/srt/model_executor/model_runner.py
 +++ b/python/sglang/srt/model_executor/model_runner.py
 @@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import (
@@ -1331,18 +1342,19 @@ index 4d58278b7..8f50dc430 100644
  from sglang.srt.layers.pooler import EmbeddingPoolerOutput
  from sglang.srt.layers.sampler import Sampler
  from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-@@ -502,6 +507,10 @@ class ModelRunner:
+@@ -502,6 +507,11 @@ class ModelRunner:
              server_args.max_running_requests,
              server_args.max_total_tokens,
          )
 +
 +        # Init routed experts capturer
-+        self.init_routed_experts_capturer()
++        if not self.is_draft_worker:
++            self.init_routed_experts_capturer()
 +
          if self.device == "cuda":
              self.init_cublas()
              self.init_attention_backend()
-@@ -545,6 +554,40 @@ class ModelRunner:
+@@ -545,6 +555,40 @@ class ModelRunner:
          # Initialize piecewise CUDA graph
          self.init_piecewise_cuda_graphs()
 
@@ -1383,7 +1395,7 @@ index 4d58278b7..8f50dc430 100644
      def model_specific_adjustment(self):
          server_args = self.server_args
 
-@@ -792,7 +835,11 @@ class ModelRunner:
+@@ -792,7 +836,11 @@ class ModelRunner:
          )
          with self.memory_saver_adapter.region(
              GPU_MEMORY_TYPE_WEIGHTS,
@@ -1396,7 +1408,7 @@ index 4d58278b7..8f50dc430 100644
          ):
              self.model = get_model(
                  model_config=self.model_config,
-@@ -2645,9 +2692,12 @@ class ModelRunner:
+@@ -2645,9 +2693,12 @@ class ModelRunner:
          ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
          self.forward_pass_id += 1
 
@@ -1412,17 +1424,18 @@ index 4d58278b7..8f50dc430 100644
          ):
              output = self._forward_raw(
                  forward_batch,
-@@ -2656,6 +2706,13 @@ class ModelRunner:
+@@ -2656,6 +2707,14 @@ class ModelRunner:
                  reinit_attn_backend,
                  split_forward_count,
              )
 +            # Copy cached routing experts' buffers back to CPU cache
-+            get_global_experts_capturer().sync_fwd_experts_buffer_DtoH(
-+                device_loc=forward_batch.out_cache_loc,
-+                cpu_loc=forward_batch.out_cache_loc_cpu,
-+                can_run_graph=output[1],
-+                cuda_graph_batch=getattr(self.graph_runner, "bs", None),
-+            )
++            if not self.is_draft_worker:
++                get_global_experts_capturer().sync_fwd_experts_buffer_DtoH(
++                    device_loc=forward_batch.out_cache_loc,
++                    cpu_loc=forward_batch.out_cache_loc_cpu,
++                    can_run_graph=output[1],
++                    cuda_graph_batch=getattr(self.graph_runner, "bs", None),
++                )
 
          if self.eplb_manager is not None:
              self.eplb_manager.on_forward_pass_end()
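After `_forward_raw`, non-draft workers now flush the expert IDs captured during the forward pass from the device buffer into the CPU cache in one batched copy, keyed by the matching device and CPU cache locations; draft workers skip the sync just as they skip capturer initialization, which is the MTP-specific fix. The real `sync_fwd_experts_buffer_DtoH` lives in the new `routed_experts_capturer.py`, whose full body this page does not show, so the following is only a hypothetical sketch of such a sync, not the actual implementation:

```python
import torch

def sync_fwd_experts_buffer_DtoH(
    device_buffer: torch.Tensor,  # [layers, batch_rows, top_k], on the GPU
    cpu_cache: torch.Tensor,      # [layers, num_cache_slots, top_k], on the CPU
    cpu_loc: torch.Tensor,        # cache slots written this step, already on CPU
) -> None:
    # One batched device-to-host copy per step instead of a per-layer,
    # per-token write during the forward pass. With a pageable destination
    # the copy is synchronous; pinned staging would be needed to overlap it.
    cpu_cache[:, cpu_loc, :] = device_buffer.to("cpu", non_blocking=True)

# Toy usage of the sketch:
dev = "cuda" if torch.cuda.is_available() else "cpu"
device_buffer = torch.randint(0, 8, (2, 3, 4), dtype=torch.int32, device=dev)
cpu_cache = torch.zeros(2, 16, 4, dtype=torch.int32)
sync_fwd_experts_buffer_DtoH(device_buffer, cpu_cache, cpu_loc=torch.tensor([3, 7, 9]))
```

The real signature at the call site additionally takes `device_loc`, `can_run_graph`, and `cuda_graph_batch`, presumably to index the fixed-size capture buffer correctly when a CUDA graph is replayed.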
@@ -1976,10 +1989,34 @@ index 8e7753dab..323788f39 100644
              "--scheduler-recv-interval",
              type=int,
 diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
-index b3d72df05..ddfe0b178 100644
+index b3d72df05..09a1634e0 100644
 --- a/python/sglang/srt/speculative/eagle_info.py
 +++ b/python/sglang/srt/speculative/eagle_info.py
-@@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
+@@ -135,6 +135,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
+             len(batch.input_ids),
+         )
+         self.last_loc = last_loc
++        batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)
+
+         bs = batch.batch_size()
+         assign_req_to_token_pool_func(
+@@ -492,6 +493,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
+         batch.out_cache_loc = tgt_cache_loc
+         batch.seq_lens.add_(accept_length + 1)
+         batch.seq_lens_cpu.add_(accept_length_cpu + 1)
++        batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)
+
+         draft_input = EagleDraftInput(
+             hidden_states=batch.spec_info.hidden_states[accept_index],
+@@ -575,6 +577,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
+             topk=self.topk,
+             capture_hidden_mode=CaptureHiddenMode.LAST,
+         )
++        batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)
+
+         return EagleVerifyOutput(
+             draft_input=draft_input,
+@@ -746,6 +749,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
          self.topk_index = self.topk_index[: len(new_indices)]
          self.hidden_states = self.hidden_states[: len(new_indices)]
          self.verified_id = self.verified_id[: len(new_indices)]
@@ -1990,7 +2027,7 @@ index b3d72df05..ddfe0b178 100644
          else:
              # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
              self.topk_p = self.topk_p[new_indices]
-@@ -777,6 +781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
+@@ -777,6 +784,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
          self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
          self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
          self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
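The `eagle_info.py` hunks apply the same mirror invariant inside EAGLE verification: each site that reassigns `batch.out_cache_loc` (preparing verify locations, committing accepted tokens, and building the verify output) immediately refreshes `batch.out_cache_loc_cpu`, so the post-forward sync in `model_runner.py` always sees a CPU mirror that matches the device tensor even on the speculative path.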

docker/version.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-nightly-dev-20251216a
+nightly-dev-20251216b
