@@ -25,9 +25,13 @@ class ScriptArgs(U.ExecuteTrainConfig):
2525 enable_eval : bool = True
2626 extra_args : str = ""
2727 rollout_fp8 : bool = False
28+ rollout_attn_fp8 : bool = False
2829 enable_mtp : bool = False # TODO enable by default
2930 dynamic_sampling : bool = False
3031 enable_benchmark : bool = False
32+ enable_mis : bool = False
33+ # TODO: improve; this should be easier to override
34+ tis_use_rs : bool = True
3135 task : Literal ["dapo_aime" , "gsm8k" ] = "dapo_aime"
3236
3337
@@ -243,9 +247,11 @@ def train(args: ScriptArgs):
243247 # """--sglang-json-model-override-args '{"num_hidden_layers": 5}' """
244248 )
245249 sglang_extra_env_vars = {}
250+ if U .GENERATION_HARDWARE [args .hardware ] == "Blackwell" :
251+ sglang_args += "--sglang-attention-backend trtllm_mha "
246252 if args .rollout_fp8 :
247253 sglang_decode_max_bs = 256
248- sglang_attn_tp_size = 8
254+ sglang_attn_tp_size = min ( 8 , sglang_world_size )
249255 sglang_attn_dp_size = sglang_world_size // sglang_attn_tp_size
250256 sglang_args += (
251257 f"--sglang-ep-size { sglang_world_size } "
@@ -306,6 +312,35 @@ def train(args: ScriptArgs):
306312 if args .enable_benchmark :
307313 misc_args += (
308314 "--custom-generate-function-path slime.rollout.generate_hub.benchmarkers.generate_with_random_osl "
315+ "--rollout-batch-size 128 "
316+ "--n-samples-per-prompt 8 "
317+ "--use-distributed-post "
318+ "--router-policy round_robin "
319+ "--sglang-server-concurrency 10000 "
320+ # GB200 with mem-frac 0.8 currently leads to OOM in long jobs, but here we use a large value to make the baseline fairer
321+ f"--sglang-mem-fraction-static { 0.8 if args .hardware == 'GB300' else 0.75 } "
322+ )
323+
324+ if args .rollout_attn_fp8 :
325+ sglang_args += "--sglang-kv-cache-dtype fp8_e4m3 "
326+
327+ if args .enable_mis :
328+ config_text = f"""
329+ use_tis: true
330+ use_rs: { "true" if args .tis_use_rs else "false" }
331+ tis_level: "token"
332+ rs_level: "token"
333+ tis_mode: "truncate"
334+ tis_lower_bound: 0.5
335+ tis_upper_bound: 2.0
336+ rs_lower_bound: null
337+ rs_upper_bound: null
338+ rs_veto_threshold: 1.0e-4
339+ tis_batch_normalize: true
340+ """ .strip ()
341+ misc_args += (
342+ f"--custom-config-path { U .save_to_temp_file (config_text , 'yaml' )} "
343+ "--custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_with_cp "
309344 )
310345
311346 train_args = (
0 commit comments