diff --git a/build_conda.sh b/build_conda.sh index 3baf42ed4..69ea1818e 100644 --- a/build_conda.sh +++ b/build_conda.sh @@ -12,6 +12,8 @@ source ~/.bashrc micromamba create -n slime python=3.12 pip -c conda-forge -y micromamba activate slime export CUDA_HOME="$CONDA_PREFIX" +export SGLANG_COMMIT="24c91001cf99ba642be791e099d358f4dfe955f5" +export MEGATRON_COMMIT="3714d81d418c9f1bca4594fc35f9e8289f652862" export BASE_DIR=${BASE_DIR:-"/root"} cd $BASE_DIR @@ -27,7 +29,7 @@ pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https # install sglang git clone https://github.com/sgl-project/sglang.git cd sglang -git checkout 5e2cda6158e670e64b926a9985d65826c537ac82 +git checkout ${SGLANG_COMMIT} # Install the python packages pip install -e "python[all]" @@ -53,12 +55,9 @@ pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation # megatron cd $BASE_DIR git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ - cd Megatron-LM/ && git checkout core_v0.14.0 && \ + cd Megatron-LM/ && git checkout ${MEGATRON_COMMIT} && \ pip install -e . -# https://github.com/pytorch/pytorch/issues/168167 -pip install nvidia-cudnn-cu12==9.16.0.29 - # install slime and apply patches # if slime does not exist locally, clone it @@ -73,8 +72,11 @@ else pip install -e . fi +# https://github.com/pytorch/pytorch/issues/168167 +pip install nvidia-cudnn-cu12==9.16.0.29 + # apply patch cd $BASE_DIR/sglang -git apply $SLIME_DIR/docker/patch/v0.5.6/sglang.patch +git apply $SLIME_DIR/docker/patch/v0.5.7/sglang.patch cd $BASE_DIR/Megatron-LM -git apply $SLIME_DIR/docker/patch/v0.5.6/megatron.patch \ No newline at end of file +git apply $SLIME_DIR/docker/patch/v0.5.7/megatron.patch \ No newline at end of file diff --git a/docker/README.md b/docker/README.md index 156169c72..29929d282 100644 --- a/docker/README.md +++ b/docker/README.md @@ -5,10 +5,12 @@ We will publish 2 kinds of docker images: 2. latest version, which aligns to `lmsysorg/sglang:latest`. 
current stable version is: -- sglang nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) +- sglang v0.5.7 nightly-dev-20260103-24c91001 (24c91001cf99ba642be791e099d358f4dfe955f5), megatron dev 3714d81d418c9f1bca4594fc35f9e8289f652862 history versions: +- sglang v0.5.6 nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) - sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) +- sglang v0.5.0rc0-cu126 (8ecf6b9d2480c3f600826c7d8fef6a16ed603c3f), megatron 48406695c4efcf1026a7ed70bb390793918dd97b The command to build: diff --git a/docker/patch/v0.5.7/megatron.patch b/docker/patch/v0.5.7/megatron.patch new file mode 100644 index 000000000..795ebf1de --- /dev/null +++ b/docker/patch/v0.5.7/megatron.patch @@ -0,0 +1,681 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index 5a1ea308d..aa701237f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index acb93ef78..20ee977b0 100644 +--- 
a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +index 1fd5dcfae..c9aeef1f0 100644 +--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py ++++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py +@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( + cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) + +- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads +- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads +- mask = kv_off < head_num * stride_kv_nheads +- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] +- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] +- k = tl.load(KV_ptr + k_in_off, mask=mask) +- v = tl.load(KV_ptr + v_in_off, mask=mask) ++ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ k_off = ki_range * stride_kv_nheads + kj_range ++ if v_dim > 0: ++ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ v = tl.load(KV_ptr + v_off, mask=mask_v) ++ else: ++ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) ++ k = tl.load(KV_ptr + k_off, mask=mask_k) + +- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads +- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads ++ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads ++ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads + +- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] +- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] +- tl.store(K_ptr + k_out_off, k, mask=mask) +- tl.store(V_ptr + v_out_off, v, mask=mask) ++ k_out_off = ki_range * stride_k_nheads + kj_range ++ tl.store(K_ptr + k_out_off, k, mask=mask_k) ++ if v_dim > 0: ++ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] ++ tl.store(V_ptr + v_out_off, v, mask=mask_v) + + EMB = K_POS_EMB + pid_m * stride_emb_seq + # x1 = t[..., 0::2], x2 = t[..., 1::2] +@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( + x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + ++ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ mask_x = x_range < head_num + x_left_off = ( +- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads ++ x_range * stride_k_nheads + + k_dim + + tl.arange(0, emb_dim // 2)[None, :] + ) + x_right_off = x_left_off + emb_dim // 2 +- tl.store(K_ptr + x_left_off, x_left, mask=mask) +- 
tl.store(K_ptr + x_right_off, x_right, mask=mask) ++ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) ++ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) + + + @triton.autotune( +@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( + SIN, + emb_dim: tl.constexpr, + k_dim: tl.constexpr, ++ k_dim_ceil: tl.constexpr, + v_dim: tl.constexpr, + head_num: tl.constexpr, + batch_size, +@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( + else: + token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) + +- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads +- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads +- mask = dkv_off < head_num * stride_dkv_nheads +- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] +- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] +- +- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads +- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads +- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] +- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] +- dk = tl.load(dK_ptr + dk_in_off, mask=mask) +- dv = tl.load(dV_ptr + dv_in_off, mask=mask) +- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) +- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) ++ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads ++ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H ++ kj_range = tl.arange(0, k_dim_ceil)[None, :] ++ mask_k = (ki_range < head_num) & (kj_range < k_dim) ++ mask_v = ki_range < head_num ++ dk_out_off = ki_range * stride_dkv_nheads + kj_range ++ ++ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads ++ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads ++ dk_in_off = ki_range * stride_dk_nheads + kj_range ++ ++ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) ++ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) ++ ++ if v_dim > 0: ++ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] ++ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] ++ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) ++ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) + + if pid_head == 0: + x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) + for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): +- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads +- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim ++ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads ++ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads + mask = x_off < head_num * stride_dk_nheads + x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] + x_right_off = x_left_off + emb_dim // 2 +@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + + o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) + o_value = kv.new_empty(total_seqlen, nheads, v_dim) ++ k_dim_ceil = triton.next_power_of_2(k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_fwd_kv_kernel[grid]( +@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + emb_dim, + k_dim, ++ k_dim_ceil, + v_dim, + nheads, + batch_size, +@@ -700,6 +717,7 @@ class 
ApplyMLARotaryEmbKV(torch.autograd.Function): + + d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) + d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) ++ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) + + grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) + rotary_bwd_kv_kernel[grid]( +@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): + sin, + ctx.emb_dim, + ctx.k_dim, ++ k_dim_ceil, + ctx.v_dim, + nheads, + batch_size, +diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py +index 13d74aa52..060898a7a 100644 +--- a/megatron/core/models/common/language_module/language_module.py ++++ b/megatron/core/models/common/language_module/language_module.py +@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." +- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) ++ # output ++ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} ++ output_layer_buffers = dict(column_parallel_linear.named_buffers()) ++ logits, _ = torch.func.functional_call( ++ column_parallel_linear, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden,), ++ col_linear_kwargs, ++ ) + + return self.compute_language_model_loss(labels, logits) + +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index e21127b87..712793853 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( + use_kitchen: bool = False, + use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). 
+ +@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, ++ post_self_attn_layernorm=post_self_attn_layernorm, ++ post_mlp_layernorm=post_mlp_layernorm, + ) + + +@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + +@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index a1230568c..1fd52f65a 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post +@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() +- if mtp_in_postprocess: ++ ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): + return hidden_states + + # Skip when mtp_num_layers is None or 0 +- if self.config.mtp_num_layers: +- mtp_labels = labels.clone() ++ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. 
+ mtp_labels, _ = roll_tensor( +@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ +- 'weight': output_weight, ++ 'weight': output_weight.detach() if output_weight else None, + 'runtime_gather_output': runtime_gather_output, + }, + ) +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index 6e093f96f..eac21a3ea 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + if key == 'padding': + tensors[key] = LocalNonpersistentObject(tensors[key]) + continue ++ if key == 'step': ++ continue + assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( + tensors[key].shape, + gbuf_local_start, +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a273002b9..4f821cfd5 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -11,6 +11,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index ac839c21f..f18309217 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 28cff06f5..58dc4bb70 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -587,6 +587,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from slime.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if 
use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 16fc9d9af..517944f25 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -201,6 +201,9 @@ class TopKRouter(Router): + self.global_tokens_per_expert = None + self.ga_steps = None + ++ from slime.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index a8f4abfcd..f33f6f05e 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) +- position_ids, _ = roll_tensor( +- position_ids, +- shifts=-1, +- dims=-1, +- cp_group=self.cp_group, +- packed_seq_params=packed_seq_params, +- ) ++ if position_ids is not None: ++ position_ids, _ = roll_tensor( ++ position_ids, ++ shifts=-1, ++ dims=-1, ++ cp_group=self.cp_group, ++ packed_seq_params=packed_seq_params, ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? 
++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: +@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index e2705bd9f..a0aa109b5 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): + attention_output_gate: bool = False + """Whether to apply output gate to the attention layers.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 3ea405770..5a42001b9 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = 
build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + nvtx_range_push(suffix="self_attn_bda") +@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index b267c8a81..83736acdc 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): + + kw_args['inference_sampling_seed'] = args.seed + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. 
This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 13b7526ca..6c590f653 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, +- trust_remote_code=trust_remote_code, ++ trust_remote_code=True, + **kwargs, + ) + self._vocab = self._tokenizer.get_vocab() diff --git a/docker/patch/v0.5.7/sglang.patch b/docker/patch/v0.5.7/sglang.patch new file mode 100644 index 000000000..42d23ed65 --- /dev/null +++ b/docker/patch/v0.5.7/sglang.patch @@ -0,0 +1,864 @@ +diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py +index 199885244..742ad0639 100644 +--- a/python/sglang/srt/disaggregation/decode.py ++++ b/python/sglang/srt/disaggregation/decode.py +@@ -314,6 +314,13 @@ class DecodePreallocQueue: + ) + return kv_manager + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + def add(self, req: Req, is_retracted: bool = False) -> None: + """Add a request to the pending queue.""" + if self._check_if_req_exceed_kv_capacity(req): +diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py +index 32e8c0b69..df913da7b 100644 +--- a/python/sglang/srt/disaggregation/mooncake/conn.py ++++ b/python/sglang/srt/disaggregation/mooncake/conn.py +@@ -1079,6 +1079,19 @@ class MooncakeKVManager(CommonKVManager): + f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected" + ) + ++ def close(self): ++ # Batch deregister KV data buffers ++ if self.kv_args.kv_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.kv_data_ptrs) ++ ++ # Batch deregister auxiliary data buffers ++ if self.kv_args.aux_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.aux_data_ptrs) ++ ++ # Batch deregister state/extra pool data buffers ++ if self.kv_args.state_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.state_data_ptrs) ++ + + class MooncakeKVSender(CommonKVSender): + +diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py +index ac11013f8..478e469f6 100644 +--- a/python/sglang/srt/disaggregation/prefill.py ++++ b/python/sglang/srt/disaggregation/prefill.py +@@ -309,6 +309,13 @@ class PrefillBootstrapQueue: + else: + return bootstrapped_reqs, failed_reqs + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + + class SchedulerDisaggregationPrefillMixin: + """ +diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py +index 0478526ef..cfb1aa669 100644 +--- a/python/sglang/srt/distributed/parallel_state.py ++++ b/python/sglang/srt/distributed/parallel_state.py +@@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size(): + + def get_tensor_model_parallel_rank(): + """Return my rank for 
the tensor model parallel group.""" +- return get_tp_group().rank_in_group ++ try: ++ return get_tp_group().rank_in_group ++ except Exception: ++ return 0 + + + def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py +index b07164c53..8e6722ce0 100644 +--- a/python/sglang/srt/layers/layernorm.py ++++ b/python/sglang/srt/layers/layernorm.py +@@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp): + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + cast_x_before_out_mul: bool = False, +- fp32_residual: bool = False, +- weight_dtype: Optional = None, +- override_orig_dtype: Optional = None, ++ fp32_residual: bool = True, + ) -> None: + super().__init__() + self.cast_x_before_out_mul = cast_x_before_out_mul + self.fp32_residual = fp32_residual +- self.override_orig_dtype = override_orig_dtype +- self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) ++ self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( +@@ -194,10 +191,22 @@ class RMSNorm(MultiPlatformOp): + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() +- orig_dtype = self.override_orig_dtype or x.dtype ++ orig_dtype = x.dtype + post_residual_addition = kwargs.get("post_residual_addition") ++ ++ if residual is not None and not self.fp32_residual: ++ x = ( ++ x ++ + residual ++ + ( ++ post_residual_addition ++ if post_residual_addition is not None ++ else 0.0 ++ ) ++ ) ++ residual = x.clone() + x = x.to(torch.float32) +- if residual is not None: ++ if residual is not None and self.fp32_residual: + x = ( + x + + residual.to(torch.float32) +@@ -207,10 +216,7 @@ class RMSNorm(MultiPlatformOp): + else 0.0 + ) + ) +- if self.fp32_residual: +- residual = x.clone() +- else: +- residual = x.to(orig_dtype) ++ residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: +diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py +index fa7431048..cd33ea735 100644 +--- a/python/sglang/srt/layers/logits_processor.py ++++ b/python/sglang/srt/layers/logits_processor.py +@@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module): + None, # bias + True, # is_vnni + ) +- elif get_global_server_args().rl_on_policy_target is not None: +- # Due to tie-weight, we may not be able to change lm_head's weight dtype +- logits = torch.matmul( +- hidden_states.bfloat16(), lm_head.weight.T.bfloat16() +- ) + else: + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T +diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +index a1885fade..14d692365 100644 +--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py ++++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +@@ -14,6 +14,7 @@ import torch.nn.functional as F + import triton.language as tl + + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +@@ -573,7 +574,10 @@ def fused_experts_impl( + ).squeeze(dim=1) + else: + # According to micro benchmark results, torch.compile can get better performance for small token. 
+- if tokens_in_chunk <= 32: ++ if ( ++ not get_global_server_args().enable_deterministic_inference ++ and tokens_in_chunk <= 32 ++ ): + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], +diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py +index 00bd68755..5a3ca8a67 100644 +--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py ++++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py +@@ -1,5 +1,6 @@ + import logging + from abc import ABC ++from contextlib import contextmanager + from typing import Optional + + import numpy as np +@@ -8,13 +9,18 @@ import torch + + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.layers.dp_attention import ( ++ attn_tp_all_gather_into_tensor, + get_attention_dp_rank, ++ get_attention_tp_size, + get_dp_local_info, + is_dp_attention_enabled, + ) + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.server_args import get_global_server_args ++from sglang.srt.layers.moe import ( ++ get_moe_a2a_backend, ++) + + logger = logging.getLogger(__name__) + +@@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + device=device, + ) + ++ if get_moe_a2a_backend().is_deepep(): ++ attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 ++ self.gather_buffer = torch.empty( ++ ( ++ self.device_cache.buffer.shape[0] * attn_tp_size, ++ self.device_cache.buffer.shape[2], ++ ), ++ dtype=torch.int32, ++ device=device, ++ ) ++ + def _sync_fwd_experts_buffer_DtoH( + self, + forward_batch: ForwardBatch, + can_run_graph: bool, + cuda_graph_batch: int, + ): +- if is_dp_attention_enabled(): ++ # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer ++ # contains data from all DP ranks. We should not slice by DP rank in this case. 
++ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): + local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) + # handle with cuda graph padding + if can_run_graph: +@@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): + ].cpu() + + def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ if get_moe_a2a_backend().is_deepep(): ++ local_topk_ids = topk_ids ++ topk_ids = self.gather_buffer[ ++ : local_topk_ids.size(0) * get_attention_tp_size() ++ ] ++ attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) + self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) + + def get_routed_experts( +diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py +index 56516b41b..cb2ebca60 100644 +--- a/python/sglang/srt/layers/rotary_embedding.py ++++ b/python/sglang/srt/layers/rotary_embedding.py +@@ -135,9 +135,7 @@ class RotaryEmbedding(MultiPlatformOp): + + if get_global_server_args().rl_on_policy_target is not None: + self._forward_method = self.forward_native +- self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( +- self._apply_rotary_emb_wrapped +- ) ++ + self.position_cos, self.position_sin = None, None + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: +@@ -1577,6 +1575,9 @@ class MRotaryEmbedding(RotaryEmbedding): + key: torch.Tensor, + fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: ++ assert ( ++ fused_set_kv_buffer_arg is None ++ ), "fused_set_kv_buffer_arg is not supported for npu implementation" + # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 + assert ( + fused_set_kv_buffer_arg is None +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index 55bef5652..35ad68b1c 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ b/python/sglang/srt/layers/sampler.py +@@ -108,16 +108,11 @@ class Sampler(nn.Module): + if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: + probs_without_temp_scaling = torch.softmax(logits, dim=-1) + +- if get_global_server_args().rl_on_policy_target is not None: +- logits_div_temperature = ( +- logits.bfloat16().div(sampling_info.temperatures).bfloat16() +- ) +- logprobs_via_logsoftmax_kernel = torch.log_softmax( +- logits_div_temperature, dim=-1 +- ) +- + # Post process logits + logits.div_(sampling_info.temperatures) ++ if get_global_server_args().rl_on_policy_target is not None: ++ logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) ++ + # For ascend backend, softmax is not needed before sampling + if not get_global_server_args().sampling_backend == "ascend" or ( + return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB +diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py +index 468d8fb8a..229a9a2dc 100644 +--- a/python/sglang/srt/managers/schedule_batch.py ++++ b/python/sglang/srt/managers/schedule_batch.py +@@ -2181,7 +2181,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + def __str__(self): + return ( + f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " +- f"#req={(len(self.reqs))})" ++ f"#req={(len(self.reqs))}), " ++ f"#out_cache_loc={self.out_cache_loc})" + ) + + +diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +index e40586c24..32d98aee4 100644 +--- 
a/python/sglang/srt/managers/scheduler_output_processor_mixin.py ++++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +@@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode + from sglang.srt.environ import envs + from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer ++ + from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOutput, +diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +index 293a84350..0947f77e0 100644 +--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py ++++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +@@ -1,6 +1,7 @@ + from __future__ import annotations + + import logging ++import os + import traceback + from typing import TYPE_CHECKING, Tuple + +@@ -12,6 +13,9 @@ from sglang.srt.constants import ( + GPU_MEMORY_TYPE_KV_CACHE, + GPU_MEMORY_TYPE_WEIGHTS, + ) ++from sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group ++from sglang.srt.layers.dp_attention import get_attention_tp_group + from sglang.srt.managers.io_struct import ( + CheckWeightsReqInput, + CheckWeightsReqOutput, +@@ -137,6 +141,13 @@ class SchedulerUpdateWeightsMixin: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.release_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.release_memory_occupation() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.stashed_model_static_state = _export_static_state( + self.tp_worker.model_runner.model +@@ -177,6 +188,13 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.resume_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.resume_memory_occupation() ++ + return ResumeMemoryOccupationReqOutput() + + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index f4fc29e29..5ef12cca6 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -1652,12 +1652,13 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + return + + if len(recv_obj.input_token_logprobs_val) > 0: +- state.input_token_logprobs_val.extend( +- recv_obj.input_token_logprobs_val[recv_obj_index] +- ) +- state.input_token_logprobs_idx.extend( +- recv_obj.input_token_logprobs_idx[recv_obj_index] +- ) ++ if recv_obj.input_token_logprobs_val[recv_obj_index]: ++ state.input_token_logprobs_val.extend( ++ recv_obj.input_token_logprobs_val[recv_obj_index] ++ ) ++ state.input_token_logprobs_idx.extend( ++ recv_obj.input_token_logprobs_idx[recv_obj_index] ++ ) + 
state.output_token_logprobs_val.extend( + recv_obj.output_token_logprobs_val[recv_obj_index] + ) +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index 1d69c0582..9027374be 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): + ) + + # Init routed experts capturer +- self.init_routed_experts_capturer() ++ if not self.is_draft_worker: ++ self.init_routed_experts_capturer() + + if self.device == "cuda": + self.init_cublas() +@@ -2224,11 +2225,12 @@ class ModelRunner(ModelRunnerKVCacheMixin): + output.expert_distribution_metrics = recorder_outputs.get("metrics") + + # Copy cached routing experts' buffers back to CPU cache +- get_global_experts_capturer().on_forward_end( +- forward_batch=forward_batch, +- can_run_graph=output.can_run_graph, +- cuda_graph_batch=getattr(self.graph_runner, "bs", None), +- ) ++ if not self.is_draft_worker: ++ get_global_experts_capturer().on_forward_end( ++ forward_batch=forward_batch, ++ can_run_graph=output.can_run_graph, ++ cuda_graph_batch=getattr(self.graph_runner, "bs", None), ++ ) + + if self.eplb_manager is not None: + self.eplb_manager.on_forward_pass_end() +diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py +index 2918461d3..2bcc67087 100644 +--- a/python/sglang/srt/models/deepseek_v2.py ++++ b/python/sglang/srt/models/deepseek_v2.py +@@ -2704,7 +2704,11 @@ class DeepseekV2AttentionMLA(nn.Module): + ): + k = k_nope.new_empty(*k_shape) + concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe) +- elif _is_cuda: ++ elif _is_cuda and all( ++ # (i.bit_count() == 1) == (is_power_of_two(i)) ++ i.bit_count() == 1 ++ for i in (k_shape[1], k_nope.shape[-1], k_pe.shape[-1]) ++ ): + # fa3 mha support fp8 inputs + if ( + self.current_attention_backend == "fa3" +diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py +index a7dbadec6..c83a41338 100644 +--- a/python/sglang/srt/models/qwen2.py ++++ b/python/sglang/srt/models/qwen2.py +@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): + self.act_fn = SiluAndMul() + + def forward(self, x): +- if get_global_server_args().rl_on_policy_target is not None: +- x = x.bfloat16() +- + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) +@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module): + quant_config=quant_config, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), +- params_dtype=( +- torch.float32 +- if get_global_server_args().rl_on_policy_target is not None +- else None +- ), + ) + else: + self.embed_tokens = PPMissingLayer() +@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module): + if self.pp_group.is_last_rank: + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py +index 3ad9f6736..0b9c7f499 100644 +--- a/python/sglang/srt/models/qwen2_moe.py ++++ b/python/sglang/srt/models/qwen2_moe.py +@@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module): + prefix=add_prefix("layers", prefix), + ) + if self.pp_group.is_last_rank: +- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ 
cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.norm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + else: + self.norm = PPMissingLayer(return_tuple=True) + +diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py +index 9220831f6..47a1a4e4c 100644 +--- a/python/sglang/srt/models/qwen3.py ++++ b/python/sglang/srt/models/qwen3.py +@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +@@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py +index e11678a9e..e277d46f2 100644 +--- a/python/sglang/srt/models/qwen3_moe.py ++++ b/python/sglang/srt/models/qwen3_moe.py +@@ -22,6 +22,7 @@ import math + from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar + + import torch ++import torch.nn.functional as F + from torch import nn + from transformers import PretrainedConfig + +@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import ( + ) + from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +-from sglang.srt.layers.moe.topk import TopK ++from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK + from sglang.srt.layers.moe.utils import RoutingMethodType + from sglang.srt.layers.quantization.base_config import QuantizationConfig + from sglang.srt.layers.radix_attention import RadixAttention +@@ -229,6 +230,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + use_grouped_topk=False, + layer_id=layer_id, + ) ++ self.top_k = config.num_experts_per_tok + + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.num_experts +@@ -294,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) +- topk_output = self.topk(hidden_states, router_logits) ++ ++ if get_global_server_args().rl_on_policy_target is not None: ++ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) ++ routing_weights, selected_experts = torch.topk( ++ routing_weights, self.top_k, dim=-1 ++ ) ++ routing_weights /= routing_weights.sum(dim=-1, keepdim=True) ++ routing_weights = routing_weights.to(hidden_states.dtype) ++ topk_output = StandardTopKOutput( ++ topk_weights=routing_weights, ++ topk_ids=selected_experts, ++ router_logits=router_logits, ++ ) ++ else: ++ topk_output = self.topk(hidden_states, router_logits) ++ + final_hidden_states = self.experts(hidden_states, topk_output) + if ( + self.tp_size > 1 +@@ -475,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): + ) + self.compatible_with_fused_kv_buffer = ( + False if isinstance(self.rotary_emb, MRotaryEmbedding) else True +- ) ++ ) and (get_global_server_args().rl_on_policy_target is None) + self.compatible_with_fused_qk_norm_rope = ( + not isinstance(self.rotary_emb, MRotaryEmbedding) + ) and self.head_dim in (64, 128, 256) + self.use_fused_qk_norm_rope = ( + 
get_global_server_args().enable_fused_qk_norm_rope + and self.compatible_with_fused_qk_norm_rope ++ and (get_global_server_args().rl_on_policy_target is None) + ) + self._used_fused_qk_norm_rope_last_call = False + +@@ -494,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): + prefix=add_prefix("attn", prefix), + ) + +- self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) +- self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) ++ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) + self.alt_stream = alt_stream + + def op_prepare(self, state): +@@ -736,9 +762,19 @@ class Qwen3MoeDecoderLayer(nn.Module): + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) +- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.input_layernorm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + self.post_attention_layernorm = RMSNorm( +- config.hidden_size, eps=config.rms_norm_eps ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) + + self.layer_communicator = LayerCommunicator( +diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py +index 891913078..c9dbecd23 100644 +--- a/python/sglang/srt/models/qwen3_vl.py ++++ b/python/sglang/srt/models/qwen3_vl.py +@@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): + return cos_combined, sin_combined + + def fast_pos_embed_interpolate(self, grid_thw): +- patch_pos_embeds_permute = [] +- m_size = self.spatial_merge_size ++ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] ++ num_grid_per_side = int(self.num_position_embeddings**0.5) ++ device = self.pos_embed.weight.device ++ ++ idx_list = [[] for _ in range(4)] ++ weight_list = [[] for _ in range(4)] ++ ++ for t, h, w in zip(grid_ts, grid_hs, grid_ws): ++ h_idxs = torch.linspace(0, num_grid_per_side - 1, h) ++ w_idxs = torch.linspace(0, num_grid_per_side - 1, w) ++ ++ h_idxs_floor = h_idxs.int() ++ w_idxs_floor = w_idxs.int() ++ h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ ++ dh = h_idxs - h_idxs_floor ++ dw = w_idxs - w_idxs_floor ++ ++ base_h = h_idxs_floor * num_grid_per_side ++ base_h_ceil = h_idxs_ceil * num_grid_per_side ++ ++ indices = [ ++ (base_h[None].T + w_idxs_floor[None]).flatten(), ++ (base_h[None].T + w_idxs_ceil[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), ++ ] ++ ++ weights = [ ++ ((1 - dh)[None].T * (1 - dw)[None]).flatten(), ++ ((1 - dh)[None].T * dw[None]).flatten(), ++ (dh[None].T * (1 - dw)[None]).flatten(), ++ (dh[None].T * dw[None]).flatten(), ++ ] + +- embeds = torch.arange(self.num_grid, device=self.pos_embed.weight.device) +- embeds = ( +- self.pos_embed(embeds) +- .permute(1, 0) +- .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side) ++ for i in range(4): ++ idx_list[i].extend(indices[i].tolist()) ++ weight_list[i].extend(weights[i].tolist()) ++ ++ idx_tensor = torch.tensor(idx_list, dtype=torch.long, 
++ weight_tensor = torch.tensor(
++ weight_list, dtype=self.pos_embed.weight.dtype, device=device
+ )
+- for t, h, w in grid_thw:
+- pos_embed = torch.nn.functional.interpolate(
+- embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners
+- )
+- pos_embed = pos_embed.reshape(
+- -1,
+- h // self.spatial_merge_size,
+- self.spatial_merge_size,
+- w // self.spatial_merge_size,
+- self.spatial_merge_size,
++ pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
++ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
++
++ patch_pos_embeds = patch_pos_embeds.split(
++ [h * w for h, w in zip(grid_hs, grid_ws)]
++ )
++
++ patch_pos_embeds_permute = []
++ merge_size = self.spatial_merge_size
++ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
++ pos_embed = pos_embed.repeat(t, 1)
++ pos_embed = (
++ pos_embed.view(
++ t, h // merge_size, merge_size, w // merge_size, merge_size, -1
++ )
++ .permute(0, 1, 3, 2, 4, 5)
++ .flatten(0, 4)
+ )
+- pos_embed = pos_embed.permute(1, 3, 2, 4, 0)
+- pos_embed = pos_embed.flatten(0, 3).repeat(t, 1)
+ patch_pos_embeds_permute.append(pos_embed)
+ return torch.cat(patch_pos_embeds_permute)
+
+@@ -607,14 +647,19 @@ class Qwen3LLMModel(Qwen3Model):
+ hidden_states + residual if residual is not None else hidden_states
+ )
+
++ deepstack_embeds = None
++ if input_deepstack_embeds is not None:
++ prev_layer_idx = layer_idx - 1
++ if prev_layer_idx in self.deepstack_embed_to_decoder_layer:
++ sep = self.hidden_size * prev_layer_idx
++ deepstack_embeds = input_deepstack_embeds[
++ :, sep : sep + self.hidden_size
++ ]
++
+ # SGLang applies residual at the START of the next layer, not at the END like HuggingFace.
+ # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549
+ # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack
+ # The order matters because addition with different tensors is not associative in practice.
+- # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
+- deepstack_embeds = self.get_deepstack_embeds(
+- layer_idx - 1, input_deepstack_embeds
+- )
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
+index 54d4e415a..de7620c20 100644
+--- a/python/sglang/srt/server_args.py
++++ b/python/sglang/srt/server_args.py
+@@ -523,6 +523,7 @@ class ServerArgs:
+ cuda_graph_max_bs: Optional[int] = None
+ cuda_graph_bs: Optional[List[int]] = None
+ disable_cuda_graph: bool = False
++ disable_draft_cuda_graph: bool = False
+ disable_cuda_graph_padding: bool = False
+ enable_profile_cuda_graph: bool = False
+ enable_cudagraph_gc: bool = False
+@@ -3951,6 +3952,11 @@ class ServerArgs:
+ action="store_true",
+ help="Disable cuda graph.",
+ )
++ parser.add_argument(
++ "--disable-draft-cuda-graph",
++ action="store_true",
++ help="Disable cuda graph for draft model in speculative decoding.",
++ )
+ parser.add_argument(
+ "--disable-cuda-graph-padding",
+ action="store_true",
+diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+index 5fe45086c..c95fbd0f6 100644
+--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
++++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+@@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner:
+ self.seq_lens.fill_(self.seq_len_fill_value)
+ self.out_cache_loc.zero_()
+ self.positions.zero_()
+-
++ self.topk_p.zero_()
++ self.topk_index.zero_()
++ self.hidden_states.zero_()
++ self.req_pool_indices.zero_()
+ num_tokens = bs * self.num_tokens_per_bs
+
+ # Common inputs
+@@ -350,8 +353,8 @@
+ forward_batch.out_cache_loc
+ )
+ self.positions[:raw_num_token].copy_(forward_batch.positions)
+- self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
+- self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
++ self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1))
++ self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1))
+ self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
+ self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+
+diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
+index 1bf3816e9..b5b41dba4 100644
+--- a/python/sglang/srt/speculative/eagle_info.py
++++ b/python/sglang/srt/speculative/eagle_info.py
+@@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
+ self.topk_index = self.topk_index[: len(new_indices)]
+ self.hidden_states = self.hidden_states[: len(new_indices)]
+ self.verified_id = self.verified_id[: len(new_indices)]
++ if self.accept_length is not None:
++ self.accept_length = self.accept_length[: len(new_indices)]
++ if self.accept_length_cpu is not None:
++ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)]
+ else:
+ # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
+ self.topk_p = self.topk_p[new_indices]
+@@ -809,6 +813,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
+ self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
+ self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
+ self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
++ if self.accept_length is not None and spec_info.accept_length is not None:
++ self.accept_length = torch.cat(
++ [self.accept_length, spec_info.accept_length]
++ )
++ self.accept_length_cpu = self.accept_length.tolist()
++ elif self.accept_length is not None:
++ zeros = torch.zeros(
++ [spec_info.verified_id.shape[0]],
++ dtype=self.accept_length.dtype,
++ device=self.accept_length.device,
++ )
++ self.accept_length = torch.cat([self.accept_length, zeros])
++ self.accept_length_cpu = self.accept_length.tolist()
++ elif spec_info.accept_length is not None:
++ zeros = torch.zeros(
++ [self.verified_id.shape[0]],
++ dtype=spec_info.accept_length.dtype,
++ device=spec_info.accept_length.device,
++ )
++ self.accept_length = torch.cat([zeros, spec_info.accept_length])
++ self.accept_length_cpu = self.accept_length.tolist()
+
+
+ @dataclass
+diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
+index a702df4f8..61d9ae366 100644
+--- a/python/sglang/srt/speculative/eagle_worker.py
++++ b/python/sglang/srt/speculative/eagle_worker.py
+@@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker):
+ self.cuda_graph_runner = None
+ self.cuda_graph_runner_for_draft_extend = None
+
+- if self.server_args.disable_cuda_graph:
++ if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph:
+ return
+
+ Device2DraftCudaGraphRunner = {
diff --git a/docker/version.txt b/docker/version.txt
index 4cc406789..7ebac80c9 100644
--- a/docker/version.txt
+++ b/docker/version.txt
@@ -1 +1 @@
-nightly-dev-20260105a
\ No newline at end of file
+nightly-dev-20260106a
\ No newline at end of file