Skip to content

Commit 29e9219

Browse files
committed
update code
1 parent 60e4edc commit 29e9219

File tree

7 files changed

+50
-7
lines changed

7 files changed

+50
-7
lines changed

docker/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ RUN git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \
5353
cd Megatron-LM && git checkout ${MEGATRON_COMMIT} && \
5454
pip install -e .
5555

56-
RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@9b8b788fdeb9c2ee528183214cef65a99b71e7d5 --no-cache-dir --force-reinstall
56+
RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall
5757
RUN pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation
5858
RUN pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation
5959

scripts/run_glm45_355b_a32b.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@ class ScriptArgs(U.ExecuteTrainConfig):
2525
enable_eval: bool = True
2626
extra_args: str = ""
2727
rollout_fp8: bool = False
28+
rollout_attn_fp8: bool = False
2829
enable_mtp: bool = False # TODO enable by default
2930
dynamic_sampling: bool = False
3031
enable_benchmark: bool = False
32+
enable_mis: bool = False
33+
# TODO improve, should be able to override more easily
34+
tis_use_rs: bool = True
3135
task: Literal["dapo_aime", "gsm8k"] = "dapo_aime"
3236

3337

@@ -243,9 +247,11 @@ def train(args: ScriptArgs):
243247
# """--sglang-json-model-override-args '{"num_hidden_layers": 5}' """
244248
)
245249
sglang_extra_env_vars = {}
250+
if U.GENERATION_HARDWARE[args.hardware] == "Blackwell":
251+
sglang_args += "--sglang-attention-backend trtllm_mha "
246252
if args.rollout_fp8:
247253
sglang_decode_max_bs = 256
248-
sglang_attn_tp_size = 8
254+
sglang_attn_tp_size = min(8, sglang_world_size)
249255
sglang_attn_dp_size = sglang_world_size // sglang_attn_tp_size
250256
sglang_args += (
251257
f"--sglang-ep-size {sglang_world_size} "
@@ -306,6 +312,35 @@ def train(args: ScriptArgs):
306312
if args.enable_benchmark:
307313
misc_args += (
308314
"--custom-generate-function-path slime.rollout.generate_hub.benchmarkers.generate_with_random_osl "
315+
"--rollout-batch-size 128 "
316+
"--n-samples-per-prompt 8 "
317+
"--use-distributed-post "
318+
"--router-policy round_robin "
319+
"--sglang-server-concurrency 10000 "
320+
# GB200 w/ mem-frac 0.8 will lead to oom in long jobs currently, but here we use large value to make baseline more fair
321+
f"--sglang-mem-fraction-static {0.8 if args.hardware == 'GB300' else 0.75} "
322+
)
323+
324+
if args.rollout_attn_fp8:
325+
sglang_args += "--sglang-kv-cache-dtype fp8_e4m3 "
326+
327+
if args.enable_mis:
328+
config_text = f"""
329+
use_tis: true
330+
use_rs: {"true" if args.tis_use_rs else "false"}
331+
tis_level: "token"
332+
rs_level: "token"
333+
tis_mode: "truncate"
334+
tis_lower_bound: 0.5
335+
tis_upper_bound: 2.0
336+
rs_lower_bound: null
337+
rs_upper_bound: null
338+
rs_veto_threshold: 1.0e-4
339+
tis_batch_normalize: true
340+
""".strip()
341+
misc_args += (
342+
f"--custom-config-path {U.save_to_temp_file(config_text, 'yaml')} "
343+
"--custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_with_cp "
309344
)
310345

311346
train_args = (

slime/backends/sglang_utils/sglang_engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import requests
77
import sglang_router
88
from packaging.version import parse
9-
from sglang.srt.entrypoints.http_server import launch_server
109
from sglang.srt.server_args import ServerArgs
1110
from sglang.srt.utils import kill_process_tree
1211
from urllib3.exceptions import NewConnectionError
@@ -31,6 +30,8 @@ def get_base_gpu_id(args, rank):
3130

3231

3332
def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
33+
from sglang.srt.entrypoints.http_server import launch_server
34+
3435
multiprocessing.set_start_method("spawn", force=True)
3536
server_args.host = server_args.host.strip("[]")
3637
p = multiprocessing.Process(target=launch_server, args=(server_args,))

slime/rollout/rm_hub/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import random
23

34
import aiohttp
45

@@ -57,6 +58,8 @@ async def async_rm(args, sample: Sample, **kwargs):
5758
from .ifbench import compute_ifbench_reward
5859

5960
return compute_ifbench_reward(response, label, metadata=metadata)
61+
elif rm_type == "random":
62+
return random.randint(0, 1)
6063
elif rm_type:
6164
raise NotImplementedError(f"Rule-based RM for {rm_type} is not implemented.")
6265
else:

slime/rollout/sglang_rollout.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,8 @@ async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]:
293293
response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers")
294294
urls = [worker["url"] for worker in response["workers"]]
295295

296-
for url in urls:
297-
logger.info(f"Abort request for {url}")
298-
await post(f"{url}/abort_request", {"abort_all": True})
296+
logger.info(f"Abort request for {urls}")
297+
await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls])
299298

300299
# make sure all the pending tasks are finished
301300
count = 0

slime/utils/external_utils/command_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,9 @@ def save_to_temp_file(text: str, ext: str):
262262
"GB200": 4,
263263
"GB300": 4,
264264
}
265+
266+
GENERATION_HARDWARE = {
267+
"H100": "Hopper",
268+
"GB200": "Blackwell",
269+
"GB300": "Blackwell",
270+
}

slime/utils/tensor_backper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def get(self, tag: str):
3131
def backup(self, tag: str):
3232
raise NotImplementedError
3333

34-
@abstractmethod
3534
def copy(self, *, src_tag: str, dst_tag: str):
3635
raise NotImplementedError
3736

0 commit comments

Comments
 (0)