@@ -301,10 +301,10 @@ index e7d5a67cc..639e47163 100644
301301 out_hidden_states[begin_chunk_idx:end_chunk_idx],
302302diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
303303new file mode 100644
304- index 000000000..7369f9dc9
304+ index 000000000..11adcaa77
305305--- /dev/null
306306+++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
307- @@ -0,0 +1,308 @@
307+ @@ -0,0 +1,305 @@
308308+ import logging
309309+ from abc import ABC
310310+ from contextlib import contextmanager
@@ -402,9 +402,6 @@ index 000000000..7369f9dc9
402402+ assert hasattr(self, "buffer")
403403+ return get_tensor_size_bytes(self.buffer)
404404+
405- + def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
406- + self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True)
407- +
408405+ def _finalize_allocation_log(self):
409406+ """Common logging and memory usage computation for captured experts buffers."""
410407+ buffer_size_GB = self.get_buffer_size_bytes() / _GB
@@ -903,7 +900,7 @@ index e34736cc4..5e5997a1a 100644
903900 # idx is the index of the token in the prompt after expansion.
904901 # val is the length of padded tokens after expansion.
905902diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
906- index c4c5a9ebb..1450c5fd8 100644
903+ index c4c5a9ebb..3650ba881 100644
907904--- a/python/sglang/srt/managers/schedule_batch.py
908905+++ b/python/sglang/srt/managers/schedule_batch.py
909906@@ -450,6 +450,7 @@ class Req:
@@ -953,15 +950,23 @@ index c4c5a9ebb..1450c5fd8 100644
953950 is_prefill_only=all(req.is_prefill_only for req in reqs),
954951 chunked_req=chunked_req,
955952 dllm_config=dllm_config,
956- @@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
953+ @@ -1282,6 +1294,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
954+ )
955+ else:
956+ self.out_cache_loc = torch.cat(decoder_out_cache_loc)
957+ + self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)
958+
959+ if not encoder_out_cache_loc:
960+ self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
961+ @@ -1457,6 +1470,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
957962 self.req_pool_indices = req_pool_indices_tensor
958963 self.orig_seq_lens = orig_seq_lens_tensor
959964 self.out_cache_loc = out_cache_loc
960965+ self.out_cache_loc_cpu = out_cache_loc.cpu()
961966 self.input_embeds = (
962967 torch.tensor(input_embeds).to(self.device, non_blocking=True)
963968 if input_embeds
964- @@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
969+ @@ -1508,10 +1522,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
965970
966971 input_ids = torch.cat([self.input_ids, running_batch.input_ids])
967972 out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
@@ -976,47 +981,47 @@ index c4c5a9ebb..1450c5fd8 100644
976981
977982 # For overlap scheduler, the output_ids has one step delay
978983 delta = 0 if self.enable_overlap else -1
979- @@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
984+ @@ -1677,6 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
980985 self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
981986 self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
982987 self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
983988+ self.out_cache_loc_cpu = torch.empty(0, dtype=torch.int64, device="cpu")
984989 self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
985990 self.seq_lens_sum = 0
986991 self.extend_num_tokens = 0
987- @@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
992+ @@ -1736,6 +1755,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
988993
989994 # Allocate memory
990995 self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
991996+ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)
992997
993998 # Update req-level memory management fields
994999 for req in self.reqs:
995- @@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1000+ @@ -1807,6 +1827,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
9961001 self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
9971002 self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
9981003 self.out_cache_loc = None
9991004+ self.out_cache_loc_cpu = None
10001005 self.seq_lens_sum = self.seq_lens.sum().item()
10011006 self.output_ids = self.output_ids[keep_indices_device]
10021007 self.return_logprob = any(req.return_logprob for req in self.reqs)
1003- @@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1008+ @@ -1852,6 +1873,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
10041009 self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
10051010 self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
10061011 self.out_cache_loc = None
10071012+ self.out_cache_loc_cpu = None
10081013 self.seq_lens_sum += other.seq_lens_sum
10091014 if self.output_ids is not None:
10101015 self.output_ids = torch.cat([self.output_ids, other.output_ids])
1011- @@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1016+ @@ -1903,6 +1925,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
10121017 seq_lens=self.seq_lens,
10131018 orig_seq_lens=self.orig_seq_lens,
10141019 out_cache_loc=self.out_cache_loc,
10151020+ out_cache_loc_cpu=self.out_cache_loc_cpu,
10161021 seq_lens_cpu=seq_lens_cpu,
10171022 seq_lens_sum=self.seq_lens_sum,
10181023 return_logprob=self.return_logprob,
1019- @@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1024+ @@ -1983,7 +2006,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
10201025 def __str__(self):
10211026 return (
10221027 f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -1026,7 +1031,7 @@ index c4c5a9ebb..1450c5fd8 100644
10261031 )
10271032
10281033
1029- @@ -2038,6 +2061,9 @@ class ModelWorkerBatch:
1034+ @@ -2038,6 +2062,9 @@ class ModelWorkerBatch:
10301035 # Sampling info
10311036 sampling_info: SamplingBatchInfo
10321037
@@ -1194,7 +1199,7 @@ index edbc52526..2cdc42755 100644
11941199
11951200 # This means that weight sync
11961201diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
1197- index b90cf0616..9b0992655 100644
1202+ index b90cf0616..8a5cbdbed 100644
11981203--- a/python/sglang/srt/managers/tokenizer_manager.py
11991204+++ b/python/sglang/srt/managers/tokenizer_manager.py
12001205@@ -20,6 +20,7 @@ import logging
@@ -1213,14 +1218,12 @@ index b90cf0616..9b0992655 100644
12131218 data_parallel_rank=obj.data_parallel_rank,
12141219 priority=obj.priority,
12151220 extra_key=obj.extra_key,
1216- @@ -1621,6 +1623,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1221+ @@ -1621,6 +1623,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
12171222 if getattr(recv_obj, "output_hidden_states", None):
12181223 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
12191224
12201225+ if getattr(recv_obj, "output_routed_experts", None):
12211226+ if recv_obj.output_routed_experts[i] is not None:
1222- + # print(f"{recv_obj.output_routed_experts[i].shape=}, {recv_obj.output_routed_experts[i].dtype=}")
1223- + # torch.save(recv_obj.output_routed_experts[i], f"/root/{recv_obj.output_routed_experts[i].shape[0]}.pt")
12241227+ meta_info["routed_experts"] = pybase64.b64encode(
12251228+ recv_obj.output_routed_experts[i].contiguous().numpy().tobytes(order="C")
12261229+ ).decode("ascii")
@@ -1230,7 +1233,7 @@ index b90cf0616..9b0992655 100644
12301233 if isinstance(recv_obj, BatchStrOutput):
12311234 state.text += recv_obj.output_strs[i]
12321235 if self.server_args.stream_output and state.obj.stream:
1233- @@ -1747,12 +1759,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
1236+ @@ -1747,12 +1757,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
12341237 return
12351238
12361239 if len(recv_obj.input_token_logprobs_val) > 0:
@@ -1251,7 +1254,7 @@ index b90cf0616..9b0992655 100644
12511254 recv_obj.output_token_logprobs_val[recv_obj_index]
12521255 )
12531256diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
1254- index 3a85e6a7e..2859dafa1 100644
1257+ index 3a85e6a7e..d2560e79b 100644
12551258--- a/python/sglang/srt/model_executor/forward_batch_info.py
12561259+++ b/python/sglang/srt/model_executor/forward_batch_info.py
12571260@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
@@ -1315,8 +1318,16 @@ index 3a85e6a7e..2859dafa1 100644
13151318 if self.encoder_lens is not None:
13161319 self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
13171320 self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
1321+ @@ -906,6 +921,7 @@ class ForwardBatch:
1322+ self.spec_info.hidden_states = self.hidden_states_backup
1323+ if hasattr(self, "output_cache_loc_backup"):
1324+ self.out_cache_loc = self.output_cache_loc_backup
1325+ + self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True)
1326+
1327+ elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
1328+ logits_output.next_token_logits = logits_output.next_token_logits[:bs]
13181329diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
1319- index 4d58278b7..8f50dc430 100644
1330+ index 4d58278b7..5965c481e 100644
13201331--- a/python/sglang/srt/model_executor/model_runner.py
13211332+++ b/python/sglang/srt/model_executor/model_runner.py
13221333@@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import (
@@ -1331,18 +1342,19 @@ index 4d58278b7..8f50dc430 100644
13311342 from sglang.srt.layers.pooler import EmbeddingPoolerOutput
13321343 from sglang.srt.layers.sampler import Sampler
13331344 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
1334- @@ -502,6 +507,10 @@ class ModelRunner:
1345+ @@ -502,6 +507,11 @@ class ModelRunner:
13351346 server_args.max_running_requests,
13361347 server_args.max_total_tokens,
13371348 )
13381349+
13391350+ # Init routed experts capturer
1340- + self.init_routed_experts_capturer()
1351+ + if not self.is_draft_worker:
1352+ + self.init_routed_experts_capturer()
13411353+
13421354 if self.device == "cuda":
13431355 self.init_cublas()
13441356 self.init_attention_backend()
1345- @@ -545,6 +554,40 @@ class ModelRunner:
1357+ @@ -545,6 +555,40 @@ class ModelRunner:
13461358 # Initialize piecewise CUDA graph
13471359 self.init_piecewise_cuda_graphs()
13481360
@@ -1383,7 +1395,7 @@ index 4d58278b7..8f50dc430 100644
13831395 def model_specific_adjustment(self):
13841396 server_args = self.server_args
13851397
1386- @@ -792,7 +835,11 @@ class ModelRunner:
1398+ @@ -792,7 +836,11 @@ class ModelRunner:
13871399 )
13881400 with self.memory_saver_adapter.region(
13891401 GPU_MEMORY_TYPE_WEIGHTS,
@@ -1396,7 +1408,7 @@ index 4d58278b7..8f50dc430 100644
13961408 ):
13971409 self.model = get_model(
13981410 model_config=self.model_config,
1399- @@ -2645,9 +2692,12 @@ class ModelRunner:
1411+ @@ -2645,9 +2693,12 @@ class ModelRunner:
14001412 ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
14011413 self.forward_pass_id += 1
14021414
@@ -1412,17 +1424,18 @@ index 4d58278b7..8f50dc430 100644
14121424 ):
14131425 output = self._forward_raw(
14141426 forward_batch,
1415- @@ -2656,6 +2706,13 @@ class ModelRunner:
1427+ @@ -2656,6 +2707,14 @@ class ModelRunner:
14161428 reinit_attn_backend,
14171429 split_forward_count,
14181430 )
14191431+ # Copy cached routing experts' buffers back to CPU cache
1420- + get_global_experts_capturer().sync_fwd_experts_buffer_DtoH(
1421- + device_loc=forward_batch.out_cache_loc,
1422- + cpu_loc=forward_batch.out_cache_loc_cpu,
1423- + can_run_graph=output[1],
1424- + cuda_graph_batch=getattr(self.graph_runner, "bs", None),
1425- + )
1432+ + if not self.is_draft_worker:
1433+ + get_global_experts_capturer().sync_fwd_experts_buffer_DtoH(
1434+ + device_loc=forward_batch.out_cache_loc,
1435+ + cpu_loc=forward_batch.out_cache_loc_cpu,
1436+ + can_run_graph=output[1],
1437+ + cuda_graph_batch=getattr(self.graph_runner, "bs", None),
1438+ + )
14261439
14271440 if self.eplb_manager is not None:
14281441 self.eplb_manager.on_forward_pass_end()
@@ -1976,10 +1989,34 @@ index 8e7753dab..323788f39 100644
19761989 "--scheduler-recv-interval",
19771990 type=int,
19781991diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
1979- index b3d72df05..ddfe0b178 100644
1992+ index b3d72df05..09a1634e0 100644
19801993--- a/python/sglang/srt/speculative/eagle_info.py
19811994+++ b/python/sglang/srt/speculative/eagle_info.py
1982- @@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
1995+ @@ -135,6 +135,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
1996+ len(batch.input_ids),
1997+ )
1998+ self.last_loc = last_loc
1999+ + batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)
2000+
2001+ bs = batch.batch_size()
2002+ assign_req_to_token_pool_func(
2003+ @@ -492,6 +493,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
2004+ batch.out_cache_loc = tgt_cache_loc
2005+ batch.seq_lens.add_(accept_length + 1)
2006+ batch.seq_lens_cpu.add_(accept_length_cpu + 1)
2007+ + batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)
2008+
2009+ draft_input = EagleDraftInput(
2010+ hidden_states=batch.spec_info.hidden_states[accept_index],
2011+ @@ -575,6 +577,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
2012+ topk=self.topk,
2013+ capture_hidden_mode=CaptureHiddenMode.LAST,
2014+ )
2015+ + batch.out_cache_loc_cpu = batch.out_cache_loc.to("cpu", non_blocking=True)
2016+
2017+ return EagleVerifyOutput(
2018+ draft_input=draft_input,
2019+ @@ -746,6 +749,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
19832020 self.topk_index = self.topk_index[: len(new_indices)]
19842021 self.hidden_states = self.hidden_states[: len(new_indices)]
19852022 self.verified_id = self.verified_id[: len(new_indices)]
@@ -1990,7 +2027,7 @@ index b3d72df05..ddfe0b178 100644
19902027 else:
19912028 # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
19922029 self.topk_p = self.topk_p[new_indices]
1993- @@ -777,6 +781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
2030+ @@ -777,6 +784,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
19942031 self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
19952032 self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
19962033 self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
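Note (not part of the patch): a minimal Python sketch of the capture-and-sync pattern the diff above wires up, assuming a per-layer device-side staging buffer indexed by out_cache_loc and a matching CPU buffer indexed by out_cache_loc_cpu. The class body below is illustrative only; the real RoutedExpertsCapturer in routed_experts_capturer.py and ModelRunner.sync_fwd_experts_buffer_DtoH additionally handle CUDA-graph batch sizes, draft workers, buffer allocation logging, and overlap scheduling.

import torch

class RoutedExpertsCapturerSketch:
    """Illustrative only: stages per-token routed expert ids on device during the
    forward pass, then copies the rows touched by that pass into a CPU cache."""

    def __init__(self, num_layers: int, max_tokens: int, top_k: int):
        # Device buffer written by MoE layers; CPU buffer keyed by the token's
        # KV-cache location so captured routings survive batching/filtering.
        self.device_buffer = torch.zeros(
            (num_layers, max_tokens, top_k), dtype=torch.int32, device="cuda"
        )
        self.cpu_buffer = torch.zeros(
            (num_layers, max_tokens, top_k), dtype=torch.int32, pin_memory=True
        )

    def capture(self, layer_id: int, out_cache_loc: torch.Tensor, topk_ids: torch.Tensor):
        # Called from the MoE layer: record which experts each token was routed to.
        self.device_buffer[layer_id, out_cache_loc] = topk_ids.to(torch.int32)

    def sync_fwd_experts_buffer_DtoH(self, device_loc: torch.Tensor, cpu_loc: torch.Tensor):
        # After the forward pass: copy back only the rows this batch wrote.
        rows = self.device_buffer[:, device_loc].to("cpu", non_blocking=True)
        torch.cuda.current_stream().synchronize()  # make the async D2H copy visible
        self.cpu_buffer[:, cpu_loc] = rows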